1use crate::types::array::Vec;
4
5pub struct BitAlloc<const N: usize> {
6 l1: Vec<usize, N>,
7}
8
9impl<const N: usize> BitAlloc<N> {
10 pub const BITS_PER_WORD: usize = usize::BITS as usize;
11
12 pub fn new(free_count: usize) -> Option<Self> {
13 let mut l1 = Vec::new();
14 let words = free_count.div_ceil(Self::BITS_PER_WORD);
15
16 for i in 0..words {
17 let rem = free_count.saturating_sub(i * Self::BITS_PER_WORD);
18 if rem >= Self::BITS_PER_WORD {
19 l1.push(!0usize).ok()?;
20 } else {
21 l1.push((!0usize).unbounded_shl((Self::BITS_PER_WORD - rem) as u32))
22 .ok()?;
23 }
24 }
25
26 Some(Self { l1 })
27 }
28
29 pub const fn from_array(arr: [usize; N]) -> Self {
30 Self {
31 l1: Vec::from_array(arr),
32 }
33 }
34
35 pub fn alloc(&mut self, bit_count: usize) -> Option<usize> {
36 let mut start = 0;
38 let mut len = 0usize;
39
40 let rem = bit_count.saturating_sub(Self::BITS_PER_WORD);
41 let mask = (!0usize).unbounded_shl((Self::BITS_PER_WORD.saturating_sub(bit_count)) as u32);
42
43 for idx in 0..N {
44 if self.l1[idx] == 0 {
45 len = 0;
46 continue;
47 }
48
49 let mut byte = self.l1[idx];
50
51 let mut shift = if len > 0 {
52 0usize
53 } else {
54 byte.leading_zeros() as usize
55 };
56
57 byte <<= shift;
58
59 while shift < Self::BITS_PER_WORD {
60 let mask = if rem.saturating_sub(len) == 0 {
62 mask << (len - rem)
63 } else {
64 mask
65 };
66
67 let mut found = (byte & mask) >> shift;
70
71 if found == (mask >> shift) {
73 if len == 0 {
74 start = idx * Self::BITS_PER_WORD + shift;
75 }
76
77 found >>= found.trailing_zeros();
79
80 len += found.trailing_ones() as usize;
82 break;
84 } else {
85 len = 0;
86 }
87
88 shift += 1;
89 byte <<= 1;
90 }
91
92 if len >= bit_count {
93 let mut idx = start / Self::BITS_PER_WORD;
95
96 {
98 let skip = start % Self::BITS_PER_WORD;
99 let rem = (Self::BITS_PER_WORD - skip).min(len);
100
101 self.l1[idx] &=
102 !((!0usize).unbounded_shl((Self::BITS_PER_WORD - rem) as u32) >> skip);
103
104 if len <= rem {
105 return Some(start);
106 }
107
108 len -= rem;
109 idx += 1;
110 }
111
112 {
114 let mid_cnt = len / Self::BITS_PER_WORD;
115
116 for i in 0..mid_cnt {
117 self.l1[idx + i] = 0;
118 }
119
120 idx += mid_cnt;
121 }
122
123 self.l1[idx] &= !((!0usize)
125 .unbounded_shl((Self::BITS_PER_WORD - (len % Self::BITS_PER_WORD)) as u32));
126 return Some(start);
127 }
128 }
129
130 None
131 }
132
133 pub fn free(&mut self, bit: usize, bit_count: usize) {
134 let mut idx = bit / Self::BITS_PER_WORD;
135 let mut bit_idx = bit % Self::BITS_PER_WORD;
136
137 for _ in 0..bit_count {
139 self.l1[idx] |= 1 << (Self::BITS_PER_WORD - 1 - bit_idx);
140
141 bit_idx += 1;
142
143 if bit_idx == Self::BITS_PER_WORD {
144 bit_idx = 0;
145 idx += 1;
146 }
147 }
148 }
149}
150
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the least significant bit of the only word is free; it must still be found.
    #[test]
    fn lsb_no_underflow_works() {
        let mut alloc = BitAlloc::<1>::new(1).unwrap();
        alloc.l1[0] = 1;

        let result = alloc.alloc(1);

        assert!(result.is_some());
    }

    /// Only the most significant bit of the only word is free; it must still be found.
    #[test]
    fn msb_no_underflow_works() {
        let mut alloc = BitAlloc::<1>::new(1).unwrap();
        alloc.l1[0] = 1 << (BitAlloc::<1>::BITS_PER_WORD - 1);

        let result = alloc.alloc(1);

        assert!(result.is_some());
    }

    /// Allocates runs that span word boundaries until exhaustion, frees one, and checks
    /// that first-fit reuses the freed space before the untouched tail.
    #[test]
    fn alloc_free_round_trip() {
        let mut alloc = BitAlloc::<4>::new(256).unwrap();

        assert_eq!(alloc.alloc(100), Some(0));
        assert_eq!(alloc.alloc(100), Some(100));
        // Only 56 bits remain free, so a third 100-bit run cannot fit.
        assert_eq!(alloc.alloc(100), None);

        alloc.free(0, 100);
        assert_eq!(alloc.alloc(50), Some(0));
        assert_eq!(alloc.alloc(50), Some(50));
        assert_eq!(alloc.alloc(1), Some(200));
    }

    /// Plants a guaranteed free run in a random bitmap, then checks that `alloc` returns a
    /// run that was free, clears exactly that run, and leaves every other bit untouched.
    #[test]
    fn test_random_pattern() {
        const ITERATIONS: usize = 10000;

        for _ in 0..ITERATIONS {
            const N: usize = 1024;
            const BITS: usize = BitAlloc::<N>::BITS_PER_WORD;

            let alloc_size = rand::random::<usize>() % (N / 2) + 1;

            // Capacity of `N` words, of which the first `N` bits start free.
            let mut alloc = BitAlloc::<N>::new(N).unwrap();

            // Randomly mark roughly half of the first `N` bits as allocated.
            for i in 0..N {
                let is_zero = rand::random::<bool>();

                if is_zero {
                    alloc.l1[i / BITS] &= !(1 << ((BITS - 1) - (i % BITS)));
                }
            }

            // Force a free run of `alloc_size` bits so the allocation must succeed.
            let start = rand::random::<usize>() % (N - alloc_size);
            for i in start..(start + alloc_size) {
                alloc.l1[i / BITS] |= 1 << ((BITS - 1) - (i % BITS));
            }

            let pre = alloc.l1.clone();
            let idx = alloc.alloc(alloc_size).expect("Failed to allocate bits");

            // Every returned bit must have been free before the call...
            for i in 0..alloc_size {
                let bit = (pre[(idx + i) / BITS] >> ((BITS - 1) - ((idx + i) % BITS))) & 1;
                assert_eq!(bit, 1, "Bit at index {} is not set", idx + i);
            }

            // ...and allocated after it.
            for i in 0..alloc_size {
                let bit = (alloc.l1[(idx + i) / BITS] >> ((BITS - 1) - ((idx + i) % BITS))) & 1;
                assert_eq!(bit, 0, "Bit at index {} is not cleared", idx + i);
            }

            // No bit outside the returned range may change.
            for i in 0..N {
                if i >= idx && i < idx + alloc_size {
                    continue;
                }
                let pre_bit = (pre[i / BITS] >> ((BITS - 1) - (i % BITS))) & 1;
                let post_bit = (alloc.l1[i / BITS] >> ((BITS - 1) - (i % BITS))) & 1;
                assert_eq!(pre_bit, post_bit, "Bit at index {} was modified", i);
            }
        }
    }
}