macros/
syscall.rs

1use proc_macro2::TokenStream;
2use quote::{ToTokens, format_ident};
3
/// Largest number of parameters a `#[syscall_handler]` function may declare.
///
/// The generated entry wrapper reads one argument word per parameter, so it
/// never reads more than this many words.
pub const MAX_ARGS: usize = 4;
5
6pub fn valid_ret_type_check(item: &syn::ItemFn) -> Result<proc_macro2::TokenStream, syn::Error> {
7    let ret_ty = match &item.sig.output {
8        syn::ReturnType::Default => {
9            // no "-> Type" present
10            return Err(syn::Error::new_spanned(
11                &item.sig.output,
12                "syscall_handler: missing return type; expected a register‐sized type",
13            ));
14        }
15        syn::ReturnType::Type(_, ty) => (*ty).clone(),
16    };
17
18    Ok(quote::quote! {
19        const _: () = {
20            if core::mem::size_of::<#ret_ty>() > core::mem::size_of::<usize>() {
21                panic!("syscall_handler: the return type is bigger than usize. return type must fit in a register.");
22            }
23        };
24    })
25}
26
27pub fn valid_arg_types_check(item: &syn::ItemFn) -> Result<Vec<syn::Type>, syn::Error> {
28    let types: Vec<Result<syn::Type, syn::Error>> = item
29        .sig
30        .inputs
31        .iter()
32        .map(|arg| {
33            if let syn::FnArg::Typed(pat_type) = arg {
34                Ok((*pat_type.ty).clone())
35            } else {
36                Err(syn::Error::new(
37                    item.sig.ident.span(),
38                    format!(
39                        "argument {} is invalid. expected a typed argument.\n",
40                        arg.to_token_stream()
41                    ),
42                ))
43            }
44        })
45        .collect();
46
47    let concat_errors: Vec<_> = types
48        .iter()
49        .filter_map(|arg0: &std::result::Result<syn::Type, syn::Error>| Result::err(arg0.clone()))
50        .collect();
51
52    if !concat_errors.is_empty() {
53        return Err(syn::Error::new(
54            item.sig.ident.span(),
55            format!(
56                "syscall_handler: function {} has invalid arguments: {}",
57                item.sig.ident,
58                concat_errors
59                    .iter()
60                    .map(|e| e.to_string())
61                    .collect::<Vec<_>>()
62                    .join(", ")
63            ),
64        ));
65    }
66
67    Ok(types.into_iter().map(Result::unwrap).collect())
68}
69
70pub fn syscall_handler_fn(item: &syn::ItemFn) -> TokenStream {
71    let name = item.sig.ident.to_string().to_uppercase();
72    let num_args = item.sig.inputs.len();
73
74    // Check if the function has a valid signature. So args <= 4 and return type is u32.
75    if num_args > MAX_ARGS {
76        return syn::Error::new(
77            item.sig.ident.span(),
78            format!("syscall_handler: function {name} has too many arguments (max is {MAX_ARGS})",),
79        )
80        .to_compile_error();
81    }
82
83    let ret_check = match valid_ret_type_check(item) {
84        Ok(check) => check,
85        Err(e) => return e.to_compile_error(),
86    };
87
88    let types = match valid_arg_types_check(item) {
89        Ok(types) => {
90            if types.len() > MAX_ARGS {
91                return syn::Error::new(
92                    item.sig.ident.span(),
93                    format!(
94                        "syscall_handler: function {name} has too many arguments (max is {MAX_ARGS})",
95                    ),
96                )
97                .to_compile_error();
98            }
99            types
100        }
101        Err(e) => return e.to_compile_error(),
102    };
103
104    // Check if each argument type is valid and fits in a register.
105    let size_checks: Vec<TokenStream> = types.iter().map(|ty| {
106        quote::quote! {
107            const _: () = {
108                if core::mem::size_of::<#ty>() > core::mem::size_of::<usize>() {
109                    panic!("syscall_handler: an argument type is bigger than usize. arguments must fit in a register.");
110                }
111            };
112        }
113    }).collect();
114
115    let unpack = types.iter().enumerate().map(|(i, ty)| {
116        quote::quote! {
117            unsafe { *(args.add(#i) as *const #ty) }
118        }
119    });
120
121    let wrapper_name = format_ident!("entry_{}", item.sig.ident.clone());
122    let func_name = item.sig.ident.clone();
123
124    let call = quote::quote! {
125        #func_name( #(#unpack),* )
126    };
127
128    let wrapper = quote::quote! {
129        #[unsafe(no_mangle)]
130        pub extern "C" fn  #wrapper_name(svc_args: *const core::ffi::c_uint) -> core::ffi::c_int {
131            // This function needs to extract the arguments from the pointer and call the original function by passing the arguments as actual different parameters.
132            let args = unsafe { svc_args as *const usize };
133            // Call the original function with the extracted arguments.
134            #call
135        }
136    };
137
138    quote::quote! {
139        #wrapper
140        #item
141        #ret_check
142        #(#size_checks)*
143    }
144}