osiris/types/
bitset.rs

//! A simple bitset allocator that can be used to allocate contiguous runs of bits.
2
3use crate::types::array::Vec;
4
5pub struct BitAlloc<const N: usize> {
6    l1: Vec<usize, N>,
7}
8
9impl<const N: usize> BitAlloc<N> {
10    pub const BITS_PER_WORD: usize = usize::BITS as usize;
11
12    pub fn new(free_count: usize) -> Option<Self> {
13        let mut l1 = Vec::new();
14        let words = free_count.div_ceil(Self::BITS_PER_WORD);
15
16        for i in 0..words {
17            let rem = free_count.saturating_sub(i * Self::BITS_PER_WORD);
18            if rem >= Self::BITS_PER_WORD {
19                l1.push(!0usize).ok()?;
20            } else {
21                l1.push((!0usize).unbounded_shl((Self::BITS_PER_WORD - rem) as u32))
22                    .ok()?;
23            }
24        }
25
26        Some(Self { l1 })
27    }
28
29    pub const fn from_array(arr: [usize; N]) -> Self {
30        Self {
31            l1: Vec::from_array(arr),
32        }
33    }
34
35    pub fn alloc(&mut self, bit_count: usize) -> Option<usize> {
36        // If a bit is 1 the bit is free. If a bit is 0 the bit is allocated.
37        let mut start = 0;
38        let mut len = 0usize;
39
40        let rem = bit_count.saturating_sub(Self::BITS_PER_WORD);
41        let mask = (!0usize).unbounded_shl((Self::BITS_PER_WORD.saturating_sub(bit_count)) as u32);
42
43        for idx in 0..N {
44            if self.l1[idx] == 0 {
45                len = 0;
46                continue;
47            }
48
49            let mut byte = self.l1[idx];
50
51            let mut shift = if len > 0 {
52                0usize
53            } else {
54                byte.leading_zeros() as usize
55            };
56
57            byte <<= shift;
58
59            while shift < Self::BITS_PER_WORD {
60                // Make the mask smaller if we already have some contiguous bits.
61                let mask = if rem.saturating_sub(len) == 0 {
62                    mask << (len - rem)
63                } else {
64                    mask
65                };
66
67                // We shifted byte to MSB, mask is already aligned to the left.
68                // We compare them via and and shift to the right to shift out extra bits from the mask that would overflow into the next word.
69                let mut found = (byte & mask) >> shift;
70
71                // We also need to shift the mask to the right so that we can compare mask and found.
72                if found == (mask >> shift) {
73                    if len == 0 {
74                        start = idx * Self::BITS_PER_WORD + shift;
75                    }
76
77                    // Shift completely to the right.
78                    found >>= found.trailing_zeros();
79
80                    // As all found bits are now on the right we can just count them to get the amount we found.
81                    len += found.trailing_ones() as usize;
82                    // Continue to the next word if we haven't found enough bits yet.
83                    break;
84                } else {
85                    len = 0;
86                }
87
88                shift += 1;
89                byte <<= 1;
90            }
91
92            if len >= bit_count {
93                // Mark the allocated pages as used.
94                let mut idx = start / Self::BITS_PER_WORD;
95
96                // Mark all bits in the first word as used.
97                {
98                    let skip = start % Self::BITS_PER_WORD;
99                    let rem = (Self::BITS_PER_WORD - skip).min(len);
100
101                    self.l1[idx] &=
102                        !((!0usize).unbounded_shl((Self::BITS_PER_WORD - rem) as u32) >> skip);
103
104                    if len <= rem {
105                        return Some(start);
106                    }
107
108                    len -= rem;
109                    idx += 1;
110                }
111
112                // Mark all bits in the middle words as used.
113                {
114                    let mid_cnt = len / Self::BITS_PER_WORD;
115
116                    for i in 0..mid_cnt {
117                        self.l1[idx + i] = 0;
118                    }
119
120                    idx += mid_cnt;
121                }
122
123                // Mark the remaining bits in the last word as used.
124                self.l1[idx] &= !((!0usize)
125                    .unbounded_shl((Self::BITS_PER_WORD - (len % Self::BITS_PER_WORD)) as u32));
126                return Some(start);
127            }
128        }
129
130        None
131    }
132
133    pub fn free(&mut self, bit: usize, bit_count: usize) {
134        let mut idx = bit / Self::BITS_PER_WORD;
135        let mut bit_idx = bit % Self::BITS_PER_WORD;
136
137        // TODO: slow
138        for _ in 0..bit_count {
139            self.l1[idx] |= 1 << (Self::BITS_PER_WORD - 1 - bit_idx);
140
141            bit_idx += 1;
142
143            if bit_idx == Self::BITS_PER_WORD {
144                bit_idx = 0;
145                idx += 1;
146            }
147        }
148    }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the least significant bit (the last bit of word 0) is free.
    #[test]
    fn lsb_no_underflow_works() {
        let mut alloc = BitAlloc::<1>::new(1).unwrap();
        alloc.l1[0] = 1;

        assert_eq!(alloc.alloc(1), Some(BitAlloc::<1>::BITS_PER_WORD - 1));
        assert_eq!(alloc.l1[0], 0);
    }

    /// Only the most significant bit (bit 0 of word 0) is free.
    #[test]
    fn msb_no_underflow_works() {
        let mut alloc = BitAlloc::<1>::new(1).unwrap();
        alloc.l1[0] = 1 << (BitAlloc::<1>::BITS_PER_WORD - 1);

        assert_eq!(alloc.alloc(1), Some(0));
        assert_eq!(alloc.l1[0], 0);
    }

    /// A request longer than every free run fails and leaves the bitmap untouched.
    #[test]
    fn alloc_fails_without_large_enough_run() {
        const BITS: usize = BitAlloc::<1>::BITS_PER_WORD;
        let mut alloc = BitAlloc::<1>::new(BITS / 2).unwrap();
        let before = alloc.l1[0];

        assert_eq!(alloc.alloc(BITS / 2 + 1), None);
        assert_eq!(alloc.l1[0], before);
        assert_eq!(alloc.alloc(BITS / 2), Some(0));
    }

    /// A run that starts mid-word 0 and ends mid-word 1 is allocated, freed and reallocated.
    #[test]
    fn alloc_spanning_words_and_free_roundtrip() {
        const BITS: usize = BitAlloc::<2>::BITS_PER_WORD;
        let mut alloc = BitAlloc::<2>::new(2 * BITS).unwrap();

        assert_eq!(alloc.alloc(BITS / 2), Some(0));
        assert_eq!(alloc.alloc(BITS), Some(BITS / 2));
        assert_eq!(alloc.l1[0], 0);
        assert_eq!(alloc.l1[1], !0usize >> (BITS / 2));

        alloc.free(BITS / 2, BITS);
        assert_eq!(alloc.l1[0], !0usize >> (BITS / 2));
        assert_eq!(alloc.l1[1], !0usize);
        assert_eq!(alloc.alloc(BITS), Some(BITS / 2));
    }

    #[test]
    fn test_random_pattern() {
        const ITERATIONS: usize = 10000;

        for _ in 0..ITERATIONS {
            // `N` is used both as the word capacity and as the number of bits under test; only
            // the first `N` bits (N / BITS words) carry the random pattern.
            const N: usize = 1024;
            const BITS: usize = BitAlloc::<N>::BITS_PER_WORD;

            let alloc_size = rand::random::<usize>() % (N / 2) + 1;

            let mut alloc = BitAlloc::<N>::new(N).unwrap();

            // Generate a random bit pattern.
            for i in 0..N {
                let is_zero = rand::random::<bool>();

                if is_zero {
                    alloc.l1[i / BITS] &= !(1 << ((BITS - 1) - (i % BITS)));
                }
            }

            // Place a run of alloc_size contiguous bits set to 1 at a random position.
            let start = rand::random::<usize>() % (N - alloc_size);
            for i in start..(start + alloc_size) {
                alloc.l1[i / BITS] |= 1 << ((BITS - 1) - (i % BITS));
            }

            let pre = alloc.l1.clone();
            let idx = alloc.alloc(alloc_size).expect("Failed to allocate bits");

            // Check that the bits in returned indices is all ones in pre.
            for i in 0..alloc_size {
                let bit = (pre[(idx + i) / BITS] >> ((BITS - 1) - ((idx + i) % BITS))) & 1;
                assert_eq!(bit, 1, "Bit at index {} is not set", idx + i);
            }

            // Check that the bits in returned indices is all zeros in allocator.l1.
            for i in 0..alloc_size {
                let bit = (alloc.l1[(idx + i) / BITS] >> ((BITS - 1) - ((idx + i) % BITS))) & 1;
                assert_eq!(bit, 0, "Bit at index {} is not cleared", idx + i);
            }

            // Check that the bits in other indices are unchanged.
            for i in 0..N {
                if (idx..idx + alloc_size).contains(&i) {
                    continue;
                }
                let pre_bit = (pre[i / BITS] >> ((BITS - 1) - (i % BITS))) & 1;
                let post_bit = (alloc.l1[i / BITS] >> ((BITS - 1) - (i % BITS))) & 1;
                assert_eq!(pre_bit, post_bit, "Bit at index {} was modified", i);
            }
        }
    }
}