Skip to main content

osiris/types/
bitset.rs

1//! A simple bitset allocator that can be used to allocate contiguous runs of bits.
2
3use crate::types::array::Vec;
4
5pub struct BitAlloc<const N: usize> {
6    l1: Vec<usize, N>,
7}
8
9impl<const N: usize> BitAlloc<N> {
10    pub const BITS_PER_WORD: usize = usize::BITS as usize;
11
12    pub fn new(free_count: usize) -> Option<Self> {
13        let mut l1 = Vec::new();
14        let words = free_count.div_ceil(Self::BITS_PER_WORD);
15
16        for i in 0..words {
17            let rem = free_count.saturating_sub(i * Self::BITS_PER_WORD);
18            if rem >= Self::BITS_PER_WORD {
19                l1.push(!0usize).ok()?;
20            } else {
21                l1.push((!0usize).unbounded_shl((Self::BITS_PER_WORD - rem) as u32))
22                    .ok()?;
23            }
24        }
25
26        Some(Self { l1 })
27    }
28
29    pub const fn from_array(arr: [usize; N]) -> Self {
30        Self {
31            l1: Vec::from_array(arr),
32        }
33    }
34
35    pub fn alloc(&mut self, bit_count: usize) -> Option<usize> {
36        // If a bit is 1 the bit is free. If a bit is 0 the bit is allocated.
37        let mut start = 0;
38        let mut len = 0usize;
39
40        let rem = bit_count.saturating_sub(Self::BITS_PER_WORD);
41        let mask = (!0usize).unbounded_shl((Self::BITS_PER_WORD.saturating_sub(bit_count)) as u32);
42
43        for idx in 0..N {
44            if self.l1[idx] == 0 {
45                len = 0;
46                continue;
47            }
48
49            let mut byte = self.l1[idx];
50
51            let mut shift = if len > 0 {
52                0usize
53            } else {
54                byte.leading_zeros() as usize
55            };
56
57            byte <<= shift;
58
59            while shift < Self::BITS_PER_WORD {
60                // Make the mask smaller if we already have some contiguous bits.
61                let mask = if rem.saturating_sub(len) == 0 {
62                    mask << (len - rem)
63                } else {
64                    mask
65                };
66
67                // We shifted byte to MSB, mask is already aligned to the left.
68                // We compare them via and and shift to the right to shift out extra bits from the mask that would overflow into the next word.
69                let mut found = (byte & mask) >> shift;
70
71                // We also need to shift the mask to the right so that we can compare mask and found.
72                if found == (mask >> shift) {
73                    if len == 0 {
74                        start = idx * Self::BITS_PER_WORD + shift;
75                    }
76
77                    // Shift completely to the right.
78                    found >>= found.trailing_zeros();
79
80                    // As all found bits are now on the right we can just count them to get the amount we found.
81                    len += found.trailing_ones() as usize;
82                    // Continue to the next word if we haven't found enough bits yet.
83                    break;
84                } else {
85                    len = 0;
86                }
87
88                shift += 1;
89                byte <<= 1;
90            }
91
92            if len >= bit_count {
93                // Mark the allocated pages as used.
94                let mut idx = start / Self::BITS_PER_WORD;
95
96                // Mark all bits in the first word as used.
97                {
98                    let skip = start % Self::BITS_PER_WORD;
99                    let rem = (Self::BITS_PER_WORD - skip).min(len);
100
101                    self.l1[idx] &=
102                        !((!0usize).unbounded_shl((Self::BITS_PER_WORD - rem) as u32) >> skip);
103
104                    if len <= rem {
105                        return Some(start);
106                    }
107
108                    len -= rem;
109                    idx += 1;
110                }
111
112                // Mark all bits in the middle words as used.
113                {
114                    let mid_cnt = len / Self::BITS_PER_WORD;
115
116                    for i in 0..mid_cnt {
117                        self.l1[idx + i] = 0;
118                    }
119
120                    idx += mid_cnt;
121                }
122
123                // Mark the remaining bits in the last word as used.
124                // Guard against `len % BITS_PER_WORD == 0`, which means the run ended
125                // exactly on a word boundary — there is no trailing partial word, and
126                // unguarded `self.l1[idx]` would index one past the last word.
127                if len % Self::BITS_PER_WORD != 0 {
128                    self.l1[idx] &= !((!0usize)
129                        .unbounded_shl((Self::BITS_PER_WORD - (len % Self::BITS_PER_WORD)) as u32));
130                }
131                return Some(start);
132            }
133        }
134
135        None
136    }
137
138    pub fn free(&mut self, bit: usize, bit_count: usize) {
139        let mut idx = bit / Self::BITS_PER_WORD;
140        let mut bit_idx = bit % Self::BITS_PER_WORD;
141
142        // TODO: slow
143        for _ in 0..bit_count {
144            self.l1[idx] |= 1 << (Self::BITS_PER_WORD - 1 - bit_idx);
145
146            bit_idx += 1;
147
148            if bit_idx == Self::BITS_PER_WORD {
149                bit_idx = 0;
150                idx += 1;
151            }
152        }
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Bits per bitmap word (independent of the allocator's `N`).
    const W: usize = BitAlloc::<1>::BITS_PER_WORD;

    /// Returns bit `i` (MSB-first numbering) given the word that contains it.
    fn bit(word: usize, i: usize) -> usize {
        (word >> (W - 1 - i % W)) & 1
    }

    #[test]
    fn lsb_no_underflow_works() {
        let mut alloc = BitAlloc::<1>::new(1).unwrap();
        // Only the LSB in word 0 is free; it is bit W - 1 in MSB-first numbering.
        alloc.l1[0] = 1;

        assert_eq!(alloc.alloc(1), Some(W - 1));
        assert_eq!(alloc.l1[0], 0);
    }

    #[test]
    fn msb_no_underflow_works() {
        let mut alloc = BitAlloc::<1>::new(1).unwrap();
        // Only the MSB in word 0 is free; it is bit 0 in MSB-first numbering.
        alloc.l1[0] = 1 << (W - 1);

        assert_eq!(alloc.alloc(1), Some(0));
        assert_eq!(alloc.l1[0], 0);
    }

    #[test]
    fn alloc_full_two_words() {
        let mut alloc = BitAlloc::<2>::new(2 * W).unwrap();
        let r = alloc.alloc(2 * W);
        assert_eq!(r, Some(0));
        assert_eq!(alloc.l1[0], 0);
        assert_eq!(alloc.l1[1], 0);
    }

    #[test]
    fn alloc_full_three_words() {
        let mut alloc = BitAlloc::<3>::new(3 * W).unwrap();
        let r = alloc.alloc(3 * W);
        assert_eq!(r, Some(0));
        assert_eq!(alloc.l1[0], 0);
        assert_eq!(alloc.l1[1], 0);
        assert_eq!(alloc.l1[2], 0);
    }

    #[test]
    fn alloc_run_crossing_word_boundary() {
        let mut alloc = BitAlloc::<2>::new(2 * W).unwrap();
        // Free: the last bit of word 0 and the first bit of word 1.
        alloc.l1[0] = 1;
        alloc.l1[1] = 1 << (W - 1);

        assert_eq!(alloc.alloc(2), Some(W - 1));
        assert_eq!(alloc.l1[0], 0);
        assert_eq!(alloc.l1[1], 0);
    }

    #[test]
    fn alloc_is_first_fit() {
        let mut alloc = BitAlloc::<1>::new(W).unwrap();
        assert_eq!(alloc.alloc(3), Some(0));
        assert_eq!(alloc.alloc(2), Some(3));
    }

    #[test]
    fn alloc_fails_when_everything_is_allocated() {
        let mut alloc = BitAlloc::<1>::from_array([0]);
        assert_eq!(alloc.alloc(1), None);
    }

    #[test]
    fn free_restores_allocated_bits() {
        let mut alloc = BitAlloc::<2>::new(2 * W).unwrap();
        assert_eq!(alloc.alloc(W + 8), Some(0));
        assert_eq!(alloc.l1[0], 0);
        assert_eq!(alloc.l1[1], !0usize >> 8);

        alloc.free(0, W + 8);
        assert_eq!(alloc.l1[0], !0usize);
        assert_eq!(alloc.l1[1], !0usize);
    }

    #[test]
    fn test_random_pattern() {
        const ITERATIONS: usize = 10000;
        const N: usize = 1024;

        for _ in 0..ITERATIONS {
            let alloc_size = rand::random::<usize>() % (N / 2) + 1;
            let mut alloc = BitAlloc::<N>::new(N).unwrap();

            // Randomly mark roughly half of the first N bits as allocated.
            for i in 0..N {
                if rand::random::<bool>() {
                    alloc.l1[i / W] &= !(1 << (W - 1 - i % W));
                }
            }

            // Guarantee at least one free run of `alloc_size` bits at a random position.
            let start = rand::random::<usize>() % (N - alloc_size);
            for i in start..(start + alloc_size) {
                alloc.l1[i / W] |= 1 << (W - 1 - i % W);
            }

            let pre = alloc.l1.clone();
            let idx = alloc.alloc(alloc_size).expect("Failed to allocate bits");

            // Every returned bit was free before and is allocated now.
            for i in idx..(idx + alloc_size) {
                assert_eq!(bit(pre[i / W], i), 1, "Bit at index {} is not set", i);
                assert_eq!(bit(alloc.l1[i / W], i), 0, "Bit at index {} is not cleared", i);
            }

            // No bit outside the returned run changed.
            for i in (0..N).filter(|i| !(idx..idx + alloc_size).contains(i)) {
                assert_eq!(
                    bit(pre[i / W], i),
                    bit(alloc.l1[i / W], i),
                    "Bit at index {} was modified",
                    i
                );
            }
        }
    }
}