Skip to main content

mlua/luau/
require.rs

1use std::cell::RefCell;
2use std::ffi::CStr;
3use std::io::Result as IoResult;
4use std::ops::{Deref, DerefMut};
5use std::os::raw::{c_char, c_int, c_void};
6use std::result::Result as StdResult;
7use std::{fmt, mem, ptr};
8
9use crate::error::{Error, Result};
10use crate::function::Function;
11use crate::state::{Lua, callback_error_ext};
12use crate::table::Table;
13use crate::types::MaybeSend;
14
15pub use fs::FsRequirer;
16
/// An error that can occur during navigation in the Luau `require-by-string` system.
#[derive(Debug, Clone)]
pub enum NavigateError {
    /// Reported back to Luau as `luarequire_NavigateResult::Ambiguous`.
    Ambiguous,
    /// Reported back to Luau as `luarequire_NavigateResult::NotFound`.
    NotFound,
    /// Any other failure; the wrapped [`Error`] is returned as an error rather than
    /// being mapped to a navigation result.
    Other(Error),
}
24
/// Converts the outcome of a [`Require`] navigation step into the value handed back to Luau.
#[cfg(feature = "luau")]
trait IntoNavigateResult {
    /// Maps `self` onto a `luarequire_NavigateResult`, or yields the underlying [`Error`].
    fn into_nav_result(self) -> Result<ffi::luarequire_NavigateResult>;
}

#[cfg(feature = "luau")]
impl IntoNavigateResult for StdResult<(), NavigateError> {
    fn into_nav_result(self) -> Result<ffi::luarequire_NavigateResult> {
        // A successful step carries no payload, so it can be answered right away.
        let Err(failure) = self else {
            return Ok(ffi::luarequire_NavigateResult::Success);
        };

        // `Ambiguous` and `NotFound` are regular navigation outcomes; only `Other`
        // becomes a real error.
        match failure {
            NavigateError::Ambiguous => Ok(ffi::luarequire_NavigateResult::Ambiguous),
            NavigateError::NotFound => Ok(ffi::luarequire_NavigateResult::NotFound),
            NavigateError::Other(cause) => Err(cause),
        }
    }
}
41
42impl From<Error> for NavigateError {
43    fn from(err: Error) -> Self {
44        NavigateError::Other(err)
45    }
46}
47
// Short alias for the FFI result type returned by the buffer-writing callbacks.
#[cfg(feature = "luau")]
type WriteResult = ffi::luarequire_WriteResult;

// Short alias for the FFI status type returned by the `get_config_status` callback.
#[cfg(feature = "luau")]
type ConfigStatus = ffi::luarequire_ConfigStatus;
53
/// A trait for handling module loading and navigation in the Luau `require-by-string` system.
pub trait Require {
    /// Returns `true` if "require" is permitted for the given chunk name.
    fn is_require_allowed(&self, chunk_name: &str) -> bool;

    /// Resets the internal state to point at the requirer module.
    fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError>;

    /// Resets the internal state to point at an aliased module.
    ///
    /// This function receives an exact path from a configuration file.
    /// It's only called when an alias's path cannot be resolved relative to its
    /// configuration file.
    fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError>;

    /// Navigate to the parent directory.
    fn to_parent(&mut self) -> StdResult<(), NavigateError>;

    /// Navigate to the given child directory.
    fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError>;

    /// Returns whether the context is currently pointing at a module.
    fn has_module(&self) -> bool;

    /// Provides a cache key representing the current module.
    ///
    /// This function is only called if `has_module` returns true.
    fn cache_key(&self) -> String;

    /// Returns whether a configuration is present in the current context.
    fn has_config(&self) -> bool;

    /// Returns the contents of the configuration file in the current context.
    ///
    /// This function is only called if `has_config` returns true.
    fn config(&self) -> IoResult<Vec<u8>>;

    /// Returns a loader function for the current module, that when called, loads the module
    /// and returns the result.
    ///
    /// Loader can be sync or async.
    /// This function is only called if `has_module` returns true.
    fn loader(&self, lua: &Lua) -> Result<Function>;
}
98
99impl fmt::Debug for dyn Require {
100    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101        write!(f, "<dyn Require>")
102    }
103}
104
// State shared by the `require` callbacks: the user-supplied `Require` implementation
// plus a slot for a configuration read.
struct Context {
    // The user-provided navigation/loading implementation.
    require: Box<dyn Require>,
    // Result of `Require::config()` stored by the `get_config_status` callback and
    // taken (emptied) by the `get_config` callback.
    config_cache: Option<IoResult<Vec<u8>>>,
}
109
110impl Deref for Context {
111    type Target = dyn Require;
112
113    fn deref(&self) -> &Self::Target {
114        &*self.require
115    }
116}
117
118impl DerefMut for Context {
119    fn deref_mut(&mut self) -> &mut Self::Target {
120        &mut *self.require
121    }
122}
123
124impl Context {
125    fn new(require: impl Require + MaybeSend + 'static) -> Self {
126        Context {
127            require: Box::new(require),
128            config_cache: None,
129        }
130    }
131}
132
// Immutably borrows the `RefCell<Context>` that `$ptr` points to.
// If the cell cannot be borrowed, a Lua error is raised on `$lua` instead.
macro_rules! try_borrow {
    ($lua:expr, $ptr:expr) => {{
        let cell = &*($ptr as *const RefCell<Context>);
        match cell.try_borrow() {
            Ok(guard) => guard,
            Err(_) => ffi::luaL_error($lua, cstr!("require context is already borrowed")),
        }
    }};
}
141
// Mutably borrows the `RefCell<Context>` that `$ptr` points to.
// If the cell cannot be borrowed, a Lua error is raised on `$lua` instead.
macro_rules! try_borrow_mut {
    ($lua:expr, $ptr:expr) => {{
        let cell = &*($ptr as *const RefCell<Context>);
        match cell.try_borrow_mut() {
            Ok(guard) => guard,
            Err(_) => ffi::luaL_error($lua, cstr!("require context is already borrowed")),
        }
    }};
}
150
// Fills in the callback table of a `luarequire_Configuration`.
//
// Every callback receives an opaque `ctx` pointer which the `try_borrow!` /
// `try_borrow_mut!` macros cast to `*const RefCell<Context>`; the callbacks then forward
// to the `Require` implementation stored in that context.
//
// Does nothing when `config` is null.
//
// NOTE(review): `callback_error_ext` is defined elsewhere; it presumably runs the closure
// and turns a returned `Err` into a Lua error raised on `state` — confirm at its definition.
#[cfg(feature = "luau")]
pub(super) unsafe extern "C-unwind" fn init_config(config: *mut ffi::luarequire_Configuration) {
    if config.is_null() {
        return;
    }

    // Forwards to `Require::is_require_allowed`; a null chunk name is never allowed.
    unsafe extern "C-unwind" fn is_require_allowed(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        requirer_chunkname: *const c_char,
    ) -> bool {
        if requirer_chunkname.is_null() {
            return false;
        }

        let this = try_borrow!(state, ctx);
        // Lossy conversion: invalid UTF-8 in the chunk name is replaced, not rejected.
        let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy();
        this.is_require_allowed(&chunk_name)
    }

    // Forwards to `Require::reset` with the requirer's chunk name.
    unsafe extern "C-unwind" fn reset(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        requirer_chunkname: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.reset(&chunk_name).into_nav_result()
        })
    }

    // Forwards to `Require::jump_to_alias` with the alias path.
    unsafe extern "C-unwind" fn jump_to_alias(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        path: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let path = CStr::from_ptr(path).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.jump_to_alias(&path).into_nav_result()
        })
    }

    // Forwards to `Require::to_parent`.
    unsafe extern "C-unwind" fn to_parent(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.to_parent().into_nav_result()
        })
    }

    // Forwards to `Require::to_child` with the child name.
    unsafe extern "C-unwind" fn to_child(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        name: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let name = CStr::from_ptr(name).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.to_child(&name).into_nav_result()
        })
    }

    // Forwards to `Require::has_module`.
    unsafe extern "C-unwind" fn is_module_present(state: *mut ffi::lua_State, ctx: *mut c_void) -> bool {
        let this = try_borrow!(state, ctx);
        this.has_module()
    }

    // Always writes an empty value (only the NUL terminator); the context is not consulted.
    unsafe extern "C-unwind" fn get_chunkname(
        _state: *mut ffi::lua_State,
        _ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        write_to_buffer(buffer, buffer_size, size_out, &[])
    }

    // Always writes an empty value (only the NUL terminator); the context is not consulted.
    unsafe extern "C-unwind" fn get_loadname(
        _state: *mut ffi::lua_State,
        _ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        write_to_buffer(buffer, buffer_size, size_out, &[])
    }

    // Writes the bytes of `Require::cache_key` into the supplied buffer.
    unsafe extern "C-unwind" fn get_cache_key(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        let this = try_borrow!(state, ctx);
        let cache_key = this.cache_key();
        write_to_buffer(buffer, buffer_size, size_out, cache_key.as_bytes())
    }

    // Reports whether a config exists and, if so, which format it appears to be in.
    // The config is read here and kept in `config_cache` so that `get_config` can reuse it.
    // A failed read is reported as `Absent`.
    unsafe extern "C-unwind" fn get_config_status(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
    ) -> ConfigStatus {
        let mut this = try_borrow_mut!(state, ctx);
        if this.has_config() {
            this.config_cache = Some(this.config());
            if let Some(Ok(data)) = &this.config_cache {
                return detect_config_format(data);
            }
        }
        ConfigStatus::Absent
    }

    // Writes the config contents into the supplied buffer.
    // Takes the cached read if there is one, otherwise calls `Require::config` again;
    // a read error is propagated out of the closure with `?`.
    unsafe extern "C-unwind" fn get_config(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        let mut this = try_borrow_mut!(state, ctx);
        let config = callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            Ok(this.config_cache.take().unwrap_or_else(|| this.config())?)
        });
        write_to_buffer(buffer, buffer_size, size_out, &config)
    }

    // Pushes the loader function obtained from `Require::loader` and reports one pushed value.
    // The path/chunkname/loadname arguments are ignored.
    unsafe extern "C-unwind" fn load(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        _path: *const c_char,
        _chunkname: *const c_char,
        _loadname: *const c_char,
    ) -> c_int {
        let this = try_borrow!(state, ctx);
        callback_error_ext(state, ptr::null_mut(), true, move |extra, _| {
            let rawlua = (*extra).raw_lua();
            let loader = this.loader(rawlua.lua())?;
            rawlua.push(loader)?;
            Ok(1)
        })
    }

    // Install the callbacks. Of the optional (`Option`-typed) slots only `get_config` is
    // provided; the alias override/fallback and `get_alias` hooks are left unset.
    (*config).is_require_allowed = is_require_allowed;
    (*config).reset = reset;
    (*config).jump_to_alias = jump_to_alias;
    (*config).to_alias_override = None;
    (*config).to_alias_fallback = None;
    (*config).to_parent = to_parent;
    (*config).to_child = to_child;
    (*config).is_module_present = is_module_present;
    (*config).get_chunkname = get_chunkname;
    (*config).get_loadname = get_loadname;
    (*config).get_cache_key = get_cache_key;
    (*config).get_config_status = get_config_status;
    (*config).get_alias = None;
    (*config).get_config = Some(get_config);
    (*config).load = load;
}
314
/// Guesses whether configuration bytes are JSON or Luau.
///
/// After trimming ASCII whitespace, the data counts as JSON when it begins with `{`
/// followed (ignoring whitespace) by a `"` or by nothing but a closing `}`. Anything
/// else is reported as Luau.
#[cfg(feature = "luau")]
fn detect_config_format(data: &[u8]) -> ConfigStatus {
    let looks_like_json = match data.trim_ascii().split_first() {
        Some((b'{', rest)) => {
            let rest = rest.trim_ascii_start();
            rest.first() == Some(&b'"') || rest == b"}"
        }
        _ => false,
    };

    if looks_like_json {
        ConfigStatus::PresentJson
    } else {
        ConfigStatus::PresentLuau
    }
}
327
/// Copies `data` into a caller-provided buffer, making sure the result ends with a NUL byte.
///
/// `*size_out` is always set to the number of bytes required (including the terminator
/// when one has to be appended). If that exceeds `buffer_size`, nothing is copied and
/// `BufferTooSmall` is returned.
#[cfg(feature = "luau")]
unsafe fn write_to_buffer(
    buffer: *mut c_char,
    buffer_size: usize,
    size_out: *mut usize,
    data: &[u8],
) -> WriteResult {
    // the buffer must be null terminated as it's a c++ `std::string` data() buffer
    let needs_terminator = !matches!(data.last(), Some(0));
    let required = data.len() + usize::from(needs_terminator);

    *size_out = required;
    if required > buffer_size {
        return WriteResult::BufferTooSmall;
    }

    ptr::copy_nonoverlapping(data.as_ptr(), buffer.cast::<u8>(), data.len());
    if needs_terminator {
        buffer.add(data.len()).write(0);
    }
    WriteResult::Success
}
348
// Builds the Lua-side `require` function backed by the given `Require` implementation.
//
// The returned function is a small Lua chunk whose environment contains only the helpers
// created below. It first checks the registered-modules table, then resolves the path
// through `proxyrequire`, and finally caches the loader's result under the module's
// cache key.
//
// Returns an error if any of the helper functions, the environment table, or the chunk
// itself cannot be created.
#[cfg(feature = "luau")]
pub(super) fn create_require_function<R: Require + MaybeSend + 'static>(
    lua: &Lua,
    require: R,
) -> Result<Function> {
    // Pushes the `source` of the first stack frame, starting at level 2, whose `what` is
    // not "C". Raises a Lua error if the stack runs out before such a frame is found.
    unsafe extern "C-unwind" fn find_current_file(state: *mut ffi::lua_State) -> c_int {
        let mut ar: ffi::lua_Debug = mem::zeroed();
        for level in 2.. {
            if ffi::lua_getinfo(state, level, cstr!("s"), &mut ar) == 0 {
                ffi::luaL_error(state, cstr!("require is not supported in this context"));
            }
            if CStr::from_ptr(ar.what) != c"C" {
                break;
            }
        }
        ffi::lua_pushstring(state, ar.source);
        1
    }

    // Pushes `Require::cache_key()` of the context stored in upvalue 1.
    unsafe extern "C-unwind" fn get_cache_key(state: *mut ffi::lua_State) -> c_int {
        let ctx = ffi::lua_touserdata(state, ffi::lua_upvalueindex(1));
        let ctx = try_borrow!(state, ctx);
        let cache_key = ctx.cache_key();
        ffi::lua_pushlstring(state, cache_key.as_ptr() as *const _, cache_key.len());
        1
    }

    // The closure leaves five values on the stack; they come back as the typed tuple.
    // NOTE(review): assumes `exec_raw` returns the pushed values in push order — confirm.
    let (get_cache_key, find_current_file, proxyrequire, registered_modules, loader_cache) = unsafe {
        lua.exec_raw::<(Function, Function, Function, Table, Table)>((), move |state| {
            let context = Context::new(require);
            // The context userdata becomes the single upvalue of `get_cache_key`, and the
            // same pointer is handed to the proxy require as its callback context.
            let context_ptr = ffi::lua_newuserdata_t(state, RefCell::new(context));
            ffi::lua_pushcclosured(state, get_cache_key, cstr!("get_cache_key"), 1);
            ffi::lua_pushcfunctiond(state, find_current_file, cstr!("find_current_file"));
            ffi::luarequire_pushproxyrequire(state, init_config, context_ptr as *mut _);
            ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_REGISTERED_MODULES_TABLE);
            ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("__MLUA_LOADER_CACHE"));
        })
    }?;

    // Minimal `error`: prefixes argument 1 with the `luaL_where(state, 1)` position and raises it.
    unsafe extern "C-unwind" fn error(state: *mut ffi::lua_State) -> c_int {
        ffi::luaL_where(state, 1);
        ffi::lua_pushvalue(state, 1);
        ffi::lua_concat(state, 2);
        ffi::lua_error(state);
    }

    // Minimal `type`: pushes the type name of argument 1.
    unsafe extern "C-unwind" fn r#type(state: *mut ffi::lua_State) -> c_int {
        ffi::lua_pushstring(state, ffi::lua_typename(state, ffi::lua_type(state, 1)));
        1
    }

    // ASCII-lowercases the string in argument 1.
    unsafe extern "C-unwind" fn to_lowercase(state: *mut ffi::lua_State) -> c_int {
        let s = ffi::luaL_checkstring(state, 1);
        let s = CStr::from_ptr(s);
        if !s.to_bytes().iter().any(|&c| c.is_ascii_uppercase()) {
            // If the string does not contain any uppercase ASCII letters, return it as is
            return 1;
        }
        callback_error_ext(state, ptr::null_mut(), true, |extra, _| {
            let s = (s.to_bytes().iter())
                .map(|&c| c.to_ascii_lowercase())
                .collect::<bstr::BString>();
            (*extra).raw_lua().push(s).map(|_| 1)
        })
    }

    let (error, r#type, to_lowercase) = unsafe {
        lua.exec_raw::<(Function, Function, Function)>((), move |state| {
            ffi::lua_pushcfunctiond(state, error, cstr!("error"));
            ffi::lua_pushcfunctiond(state, r#type, cstr!("type"));
            ffi::lua_pushcfunctiond(state, to_lowercase, cstr!("to_lowercase"));
        })
    }?;

    // Prepare environment for the "require" function.
    // The capacity hint matches the eight entries inserted below.
    let env = lua.create_table_with_capacity(0, 8)?;
    env.raw_set("get_cache_key", get_cache_key)?;
    env.raw_set("find_current_file", find_current_file)?;
    env.raw_set("proxyrequire", proxyrequire)?;
    env.raw_set("REGISTERED_MODULES", registered_modules)?;
    env.raw_set("LOADER_CACHE", loader_cache)?;
    env.raw_set("error", error)?;
    env.raw_set("type", r#type)?;
    env.raw_set("to_lowercase", to_lowercase)?;

    lua.load(
        r#"
        local path = ...
        if type(path) ~= "string" then
            error("bad argument #1 to 'require' (string expected, got " .. type(path) .. ")")
        end

        -- Check if the module (path) is explicitly registered
        local maybe_result = REGISTERED_MODULES[to_lowercase(path)]
        if maybe_result ~= nil then
            return maybe_result
        end

        local loader = proxyrequire(path, find_current_file())
        local cache_key = get_cache_key()
        -- Check if the loader result is already cached
        local result = LOADER_CACHE[cache_key]
        if result ~= nil then
            return result
        end

        -- Call the loader function and cache the result
        result = loader()
        if result == nil then
            result = true
        end
        LOADER_CACHE[cache_key] = result
        return result
        "#,
    )
    .try_cache()
    .set_name("=__mlua_require")
    .set_environment(env)
    .into_function()
}
469
470mod fs;