/* src/hamt.c */
#include <stdio.h>
#include <stdlib.h>
#include <stdint.h>
#include <string.h>
#include <assert.h>
#include "hamt.h"

#include "mempool.h"
/*
 * One memory pool instantiation per internal object type. The rest of this
 * file allocates and releases these objects through
 * hamt_mempool_allocate()/hamt_mempool_free() (struct hamt),
 * hl_mempool_allocate()/hl_mempool_free() (struct hamt_nodelist) and
 * hi_mempool_allocate()/hi_mempool_free() (struct hamt_item), which are
 * presumably produced by these macro invocations.
 * NOTE(review): the third argument looks like a pool/chunk size - confirm
 * against mempool.h.
 */
MEMPOOL_GENERATE(hamt, struct hamt, 16)
MEMPOOL_GENERATE(hl,   struct hamt_nodelist, 64)
MEMPOOL_GENERATE(hi,   struct hamt_item, 32)

/*
 * Pointer tagging helpers.
 *
 * A hamtptr_t is a pointer-sized value whose lowest bit records what it
 * points to: bit set -> struct hamt_item, bit clear -> struct hamt_nodelist.
 * This assumes both object types are at least 2-byte aligned so that bit 0
 * of a real address is always zero.
 *
 * Every expansion is fully parenthesized so the macros compose safely inside
 * larger expressions.
 */
#define tag_ptr(ptr, tag)   ((uintptr_t)(ptr) | (uintptr_t)(tag))
#define untag_ptr(ptr, tag) ((uintptr_t)(ptr) & ~(uintptr_t)(tag))
#define is_tagged(ptr, tag) ((uintptr_t)(ptr) & (uintptr_t)(tag))

/* Tag bit marking a hamtptr_t that refers to a struct hamt_item. */
#define HAMT_ITEM_TAG 0x1

/* Build a hamtptr_t from a raw nodelist pointer (tag bit cleared). */
#define TAG_NODELIST(ptr) \
    ((hamtptr_t)untag_ptr(ptr, HAMT_ITEM_TAG))
/* Build a hamtptr_t from a raw item pointer (tag bit set). */
#define TAG_ITEM(ptr)     \
    ((hamtptr_t)tag_ptr(ptr, HAMT_ITEM_TAG))

/* Strip the tag and view the value as the corresponding raw pointer. */
#define AS_NODELIST(ptr) \
    ((struct hamt_nodelist *)untag_ptr(ptr, HAMT_ITEM_TAG))
#define AS_ITEM(ptr)     \
    ((struct hamt_item *)untag_ptr(ptr, HAMT_ITEM_TAG))
#define AS_VOIDPTR(ptr)  \
    ((void *)untag_ptr(ptr, HAMT_ITEM_TAG))

/* Type tests; both yield a value usable as a truth value. */
#define IS_NODELIST(hamtptr) \
    (!is_tagged(hamtptr, HAMT_ITEM_TAG))
#define IS_ITEM(hamtptr)     \
    (is_tagged(hamtptr, HAMT_ITEM_TAG))

/*
 * Walk a singly linked chain of struct hamt_item beginning at `head`.
 * The successor is stored in the hidden loop variable `next` before the
 * body runs, so the body may release `item` without breaking the walk.
 * Both `item` and `next` are declared in the for statement itself.
 */
#define for_each_item(item, head)                        \
    for(struct hamt_item *item = (head), *next = NULL;   \
        item && (next = item->next, 1); item = next)

/*
 * Number of set bits in `i`, via the compiler builtin (GCC/Clang), which
 * operates on an unsigned int.
 */
#define popcount(i) __builtin_popcount(i)

/*
 * Each trie level consumes BITS bits of the hash; BITS_MASK extracts them,
 * so a slot index (`hash & BITS_MASK`) lies in 0..63.
 * NOTE(review): the code in this file builds slot masks with `1 << rawidx`
 * (int arithmetic) and counts them with the unsigned-int popcount above.
 * Shifting an int by 31 or more is undefined, and bits above 31 are not
 * counted. The width of `bitmask` is declared in hamt.h and is not visible
 * here - confirm whether it is 64-bit (then 64-bit shifts and
 * __builtin_popcountll are needed everywhere at once) or 32-bit (then BITS
 * would presumably need to be 5).
 */
#define BITS 6
#define BITS_MASK 0x3f

/* Allocators returning a freshly initialized node/item with refs == 1. */
static inline struct hamt_nodelist *hamt_nodelist_alloc(void);
static inline struct hamt_item *hamt_item_alloc(void);

/*
 * Recursive reference-count helpers: drop (and release at zero) or add one
 * reference on everything reachable from `hamtptr`.
 */
static void hamtptr_destroy(hamtptr_t hamtptr);
static inline void hamtptr_add_ref(hamtptr_t hamtptr);

// static void hamt_print_hamtptr(hamtptr_t hamtptr, int depth);

/*
 * Create an empty table.
 *
 * equal_fn: key comparison callback stored in the table.
 * hash_fn:  key hashing callback stored in the table.
 *
 * The table object comes from the hamt memory pool and starts out with a
 * fresh, empty root nodelist. Returns the new table.
 */
hamt_t hamt_create(hamt_equal_fn equal_fn, hamt_hash_fn hash_fn)
{
    hamt_t table = hamt_mempool_allocate();

    table->equal_fn = equal_fn;
    table->hash_fn  = hash_fn;
    table->root     = TAG_NODELIST(hamt_nodelist_alloc());
    return table;
}

/*
 * Release a table: drop one reference on everything reachable from its root
 * and hand the table object back to the hamt memory pool.
 * Passing NULL is a no-op.
 */
void hamt_destroy(hamt_t hamt)
{
    if(hamt) {
        hamtptr_destroy(hamt->root);
        hamt_mempool_free(hamt);
    }
}

/*
 * Create a second table that shares the whole tree of `src`.
 *
 * The callbacks and the root are copied, and hamtptr_add_ref() adds one
 * reference to every nodelist and item reachable from the root; later writes
 * through either table copy the shared nodes they touch instead of modifying
 * them in place (see the refs checks in hamt_insert_hamtptr()).
 *
 * If `src` is NULL a table object is still returned, as before, but it is
 * now fully initialized: no callbacks and an empty root nodelist. Previously
 * its fields were left indeterminate, so even hamt_destroy() on the result
 * read garbage. The callbacks are NULL in that case, so the result must not
 * be used for lookups or insertions until they are set.
 *
 * Returns the new table.
 */
hamt_t hamt_clone(hamt_t src)
{
    hamt_t hamt = hamt_mempool_allocate();

    if(!src) {
        /* Nothing to share: hand back a well-defined, empty table. */
        hamt->equal_fn = NULL;
        hamt->hash_fn  = NULL;
        hamt->root     = TAG_NODELIST(hamt_nodelist_alloc());
        return hamt;
    }

    hamt->equal_fn = src->equal_fn;
    hamt->hash_fn  = src->hash_fn;
    hamt->root     = src->root;
    hamtptr_add_ref(hamt->root);

    return hamt;
}

/*
 * Follow `*hash` down the trie starting at `root`.
 *
 * root: nodelist to start from.
 * ret:  receives the last hamtptr reached. On success this is the entry
 *       found below the deepest nodelist; on failure it is the nodelist
 *       whose slot for the current hash bits is empty.
 * hash: in/out. BITS bits are shifted out for every level that was
 *       descended, so on failure the low bits still select the missing slot.
 *
 * Returns 0 when every level was present, otherwise the number of levels
 * that are still missing (counting the level where the walk stopped).
 *
 * NOTE(review): `1 << slot` is evaluated as int while slot can be as large
 * as BITS_MASK (63); confirm the intended bitmask width in hamt.h.
 */
static int hamt_find_hamtptr(hamtptr_t root, hamtptr_t *ret, uint32_t *hash)
{
    const size_t levels = sizeof(*hash)*8/BITS;
    size_t depth = 0;

    *ret = root;

    while(depth < levels) {
        assert(IS_NODELIST(*ret));
        struct hamt_nodelist *node = AS_NODELIST(*ret);

        size_t slot = *hash & BITS_MASK;
        if((node->bitmask & (1 << slot)) == 0)
            return levels - depth;

        /* consume this level's bits before stepping into the child */
        *hash >>= BITS;

        /* children are stored densely: position = set bits below `slot` */
        size_t pos = popcount(node->bitmask & ((1 << slot) - 1));
        *ret = node->list[pos];

        depth++;
    }

    return 0;
}

/*
 * Look up `key` in `hamt`.
 *
 * On a hit the stored value is written to `*data` and 0 is returned.
 * Returns 1 when the key is absent, either because the trie path for its
 * hash does not exist or because no item in the chain at the end of that
 * path compares equal.
 */
int hamt_get(hamt_t hamt, void *key, void **data)
{
    uint32_t hash = hamt->hash_fn(key);
    hamtptr_t leaf;

    if(hamt_find_hamtptr(hamt->root, &leaf, &hash) != 0)
        return 1;

    assert(IS_ITEM(leaf));

    /* items sharing the same trie path form a linked chain */
    struct hamt_item *cur = AS_ITEM(leaf);
    do {
        if(hamt->equal_fn(key, cur->key)) {
            *data = cur->data;
            return 0;
        }
        cur = cur->next;
    } while(cur != NULL);

    return 1;
}

static hamtptr_t hamt_build(uint32_t hash, size_t iter, struct hamt_item **item)
{
      if(iter == 1) {
        *item = hamt_item_alloc();
        return TAG_ITEM(*item);
    }

    hash >>= BITS;
    size_t next_idx = hash & BITS_MASK;

    hamtptr_t next = hamt_build(hash, iter-1, item);

    struct hamt_nodelist *nodelist = hamt_nodelist_alloc();

    // nodelist->list = &next;
    nodelist->list = calloc(1, sizeof(*nodelist->list));
    nodelist->list[0] = next;

    nodelist->bitmask = 1 << next_idx;

    return TAG_NODELIST(nodelist);
}

/*
 * Insert `hamtptr` below `root` along the path selected by `hash` and return
 * the hamtptr that must replace `root` in its parent.
 *
 * root:     subtree to insert into (nodelist, or an item chain at the leaf).
 * hamtptr:  subtree/item to insert; when an item chain is reached it is
 *           treated as an item.
 * hash:     hash bits for this level in the low BITS; shifted on descent.
 * equal_fn: key comparison used inside item chains.
 *
 * Sharing rule used throughout: a node or item with refs == 1 is modified in
 * place; otherwise its refs is decremented and a fresh copy (refs == 1) is
 * put on the path instead, leaving the shared original untouched.
 *
 * Allocation failure of a child array terminates the process: reference
 * counts may already have been adjusted further down, and the result was
 * previously written through a NULL pointer.
 */
static hamtptr_t hamt_insert_hamtptr(hamtptr_t root, hamtptr_t hamtptr, uint32_t hash, hamt_equal_fn equal_fn)
{
    if(IS_ITEM(root)) {
        struct hamt_item *item = AS_ITEM(root);

        if(equal_fn(item->key, AS_ITEM(hamtptr)->key)) {
            /*
             * Same key: the new item takes this item's place in the chain.
             * NOTE(review): the old item is not released if refs drops to 0
             * here; hamt_set() in this file updates in place when
             * refs == 1, so it appears to get here only for shared items.
             */
            item->refs--;
            AS_ITEM(hamtptr)->next = item->next;
            return hamtptr;
        }

        /* different key: continue down the chain, or append at its end */
        if(item->next) {
            hamtptr = hamt_insert_hamtptr(TAG_ITEM(item->next), hamtptr, hash, equal_fn);
        }

        struct hamt_item *new;

        if(item->refs == 1) {
            new = item;
        } else {
            item->refs--;
            new = hamt_item_alloc();
            new->key = item->key;
            new->data = item->data;
        }

        new->next = AS_ITEM(hamtptr);
        return TAG_ITEM(new);
    }

    struct hamt_nodelist *nodelist = AS_NODELIST(root);

    /* NOTE(review): int shifts while rawidx may reach 63 - see BITS */
    size_t rawidx = hash & BITS_MASK;
    size_t idx = popcount(nodelist->bitmask & ((1 << rawidx) - 1));

    size_t list_len = popcount(nodelist->bitmask);
    size_t newlist_len = list_len;

    if(nodelist->bitmask & (1 << rawidx)) {
        /* slot occupied: insert into the child first */
        hamtptr = hamt_insert_hamtptr(nodelist->list[idx], hamtptr, hash >> BITS, equal_fn);

        if(nodelist->refs == 1) {
            nodelist->list[idx] = hamtptr;
            return TAG_NODELIST(nodelist);
        }
    } else {
        /* slot empty: the child array grows by one entry */
        newlist_len++;
    }

    hamtptr_t *newlist = calloc(newlist_len, sizeof(*newlist));
    if(!newlist) {
        fprintf(stderr, "hamt: out of memory\n");
        abort();
    }
    newlist[idx] = hamtptr;

    for(size_t i = 0; i < idx; i++)
        newlist[i] = nodelist->list[i];

    /*
     * Entries after idx come from the same position when the slot was
     * replaced (shift 0) and from one position earlier when it was added
     * (shift 1).
     */
    size_t shift = newlist_len - list_len;
    for(size_t i = idx+1; i < newlist_len; i++)
        newlist[i] = nodelist->list[i - shift];

    struct hamt_nodelist *new;

    if(nodelist->refs == 1) {
        new = nodelist;
        free(nodelist->list);
    } else {
        nodelist->refs--;
        new = hamt_nodelist_alloc();
    }

    new->list = newlist;
    new->bitmask = nodelist->bitmask | (1 << rawidx);
    return TAG_NODELIST(new);
}

/*
 * Store `data` under `key`.
 *
 * keyptr / prevdata: optional (may be NULL). When an item with an equal key
 * already exists they receive that item's key pointer and previous value;
 * otherwise they are left untouched.
 *
 * Three cases:
 *  - the trie path for the hash is incomplete: the missing levels and a new
 *    item are built and linked in;
 *  - an equal key exists in an unshared item (refs == 1): its value is
 *    overwritten in place;
 *  - otherwise a new item is inserted through hamt_insert_hamtptr(), reusing
 *    the existing key pointer when an equal key was found.
 *
 * Always returns 0.
 */
int hamt_set(hamt_t hamt, void *key, void *data, void **keyptr, void **prevdata)
{
    const uint32_t full_hash = hamt->hash_fn(key);
    uint32_t walk_hash = full_hash;
    hamtptr_t leaf;

    size_t missing = hamt_find_hamtptr(hamt->root, &leaf, &walk_hash);
    if(missing != 0) {
        struct hamt_item *fresh;
        hamtptr_t subtree = hamt_build(walk_hash, missing, &fresh);

        hamt->root = hamt_insert_hamtptr(hamt->root, subtree, full_hash, hamt->equal_fn);

        fresh->key  = key;
        fresh->data = data;
        return 0;
    }

    assert(IS_ITEM(leaf));

    for(struct hamt_item *cur = AS_ITEM(leaf); cur != NULL; cur = cur->next) {
        if(!hamt->equal_fn(cur->key, key))
            continue;

        if(keyptr != NULL)
            *keyptr = cur->key;
        if(prevdata != NULL)
            *prevdata = cur->data;

        if(cur->refs == 1) {
            /* not shared with another table: update in place */
            cur->data = data;
            return 0;
        }

        /* shared item: keep using the key pointer that is already stored */
        key = cur->key;
        break;
    }

    struct hamt_item *replacement = hamt_item_alloc();
    replacement->key  = key;
    replacement->data = data;

    hamt->root = hamt_insert_hamtptr(hamt->root, TAG_ITEM(replacement), full_hash, hamt->equal_fn);
    return 0;
}

/*
 * Take a nodelist from the hl memory pool and reset it to the empty state:
 * one reference, no child array, no occupied slots.
 */
static inline struct hamt_nodelist *hamt_nodelist_alloc(void)
{
    struct hamt_nodelist *fresh = hl_mempool_allocate();

    fresh->bitmask = 0;
    fresh->list    = NULL;
    fresh->refs    = 1;

    return fresh;
}

/*
 * Take an item from the hi memory pool and reset it: one reference, no key,
 * no data, not linked to a successor.
 */
static inline struct hamt_item *hamt_item_alloc(void)
{
    struct hamt_item *fresh = hi_mempool_allocate();

    fresh->next = NULL;
    fresh->data = NULL;
    fresh->key  = NULL;
    fresh->refs = 1;

    return fresh;
}

/*
 * Drop one reference on everything reachable from `hamtptr`.
 *
 * Item chain: every item loses one reference and goes back to the hi pool
 * when its count reaches zero.
 * Nodelist: all children are processed first, then the nodelist itself loses
 * one reference; at zero its child array is freed and the node goes back to
 * the hl pool.
 */
static inline void hamtptr_destroy(hamtptr_t hamtptr)
{
    if(IS_ITEM(hamtptr)) {
        struct hamt_item *cur = AS_ITEM(hamtptr);

        while(cur) {
            /* read the link before the item may be released */
            struct hamt_item *following = cur->next;

            if(--cur->refs == 0)
                hi_mempool_free(cur);

            cur = following;
        }
        return;
    }

    struct hamt_nodelist *node = AS_NODELIST(hamtptr);
    size_t children = popcount(node->bitmask);

    for(size_t i = 0; i < children; i++)
        hamtptr_destroy(node->list[i]);

    node->refs--;
    if(node->refs == 0) {
        free(node->list);
        hl_mempool_free(node);
    }
}

/*
 * Add one reference to everything reachable from `hamtptr`: each item of an
 * item chain, or all children of a nodelist followed by the nodelist itself.
 */
static inline void hamtptr_add_ref(hamtptr_t hamtptr)
{
    if(IS_ITEM(hamtptr)) {
        for(struct hamt_item *cur = AS_ITEM(hamtptr); cur != NULL; cur = cur->next)
            cur->refs++;
        return;
    }

    struct hamt_nodelist *node = AS_NODELIST(hamtptr);
    size_t children = popcount(node->bitmask);

    for(size_t i = 0; i < children; i++)
        hamtptr_add_ref(node->list[i]);

    node->refs++;
}

// static void hamt_print_hamtptr(hamtptr_t hamtptr, int depth)
// {
//     for(int i = 0; i < depth; i++) printf("  ");

//     if(IS_NODELIST(hamtptr)) {
//         printf("%d, MASK %b\n", AS_NODELIST(hamtptr)->refs, AS_NODELIST(hamtptr)->bitmask);
//         for(size_t i = 0; i < popcount(AS_NODELIST(hamtptr)->bitmask); i++) {
//             hamt_print_hamtptr(AS_NODELIST(hamtptr)->list[i], depth+1);
//         }
//     } else {
//         printf("%d, %s: %s\n", AS_ITEM(hamtptr)->refs, (char *)AS_ITEM(hamtptr)->key, (char *)AS_ITEM(hamtptr)->data);
//         if(AS_ITEM(hamtptr)->next)
//             hamt_print_hamtptr(TAG_ITEM(AS_ITEM(hamtptr)->next), depth+1);
//     }
// }