1use std::cell::RefCell;
2use std::ffi::CStr;
3use std::io::Result as IoResult;
4use std::ops::{Deref, DerefMut};
5use std::os::raw::{c_char, c_int, c_void};
6use std::result::Result as StdResult;
7use std::{fmt, mem, ptr};
8
9use crate::error::{Error, Result};
10use crate::function::Function;
11use crate::state::{Lua, callback_error_ext};
12use crate::table::Table;
13use crate::types::MaybeSend;
14
15pub use fs::FsRequirer;
16
17#[derive(Debug, Clone)]
19pub enum NavigateError {
20 Ambiguous,
21 NotFound,
22 Other(Error),
23}
24
/// Converts a navigation outcome into the status code expected by the Luau require API.
#[cfg(feature = "luau")]
trait IntoNavigateResult {
    /// Maps success and recoverable navigation failures to a `luarequire_NavigateResult`,
    /// returning `Err` for failures that must be raised as errors instead.
    fn into_nav_result(self) -> Result<ffi::luarequire_NavigateResult>;
}
29
/// Maps a navigation outcome onto Luau's status codes: success and the recoverable
/// `Ambiguous`/`NotFound` failures become a `luarequire_NavigateResult`, while
/// `NavigateError::Other` is handed back as a hard error.
#[cfg(feature = "luau")]
impl IntoNavigateResult for StdResult<(), NavigateError> {
    fn into_nav_result(self) -> Result<ffi::luarequire_NavigateResult> {
        let nav_err = match self {
            Ok(()) => return Ok(ffi::luarequire_NavigateResult::Success),
            Err(nav_err) => nav_err,
        };
        let status = match nav_err {
            NavigateError::Ambiguous => ffi::luarequire_NavigateResult::Ambiguous,
            NavigateError::NotFound => ffi::luarequire_NavigateResult::NotFound,
            NavigateError::Other(err) => return Err(err),
        };
        Ok(status)
    }
}
41
42impl From<Error> for NavigateError {
43 fn from(err: Error) -> Self {
44 NavigateError::Other(err)
45 }
46}
47
/// Status returned by the callbacks that fill caller-provided buffers (see `write_to_buffer`).
#[cfg(feature = "luau")]
type WriteResult = ffi::luarequire_WriteResult;

/// Configuration presence/format reported by the `get_config_status` callback.
#[cfg(feature = "luau")]
type ConfigStatus = ffi::luarequire_ConfigStatus;
53
/// Module resolver backing the Luau `require` function built by `create_require_function`.
///
/// The navigation methods (`reset`, `jump_to_alias`, `to_parent`, `to_child`) take
/// `&mut self` and are invoked from the `luarequire_Configuration` callbacks installed by
/// `init_config`. The remaining methods take `&self` and are presumably expected to describe
/// the location reached by the most recent navigation (NOTE(review): confirm with implementors).
pub trait Require {
    /// Returns whether the chunk named `chunk_name` may call `require`.
    ///
    /// A null chunk name coming from Luau is rejected before this method is reached.
    fn is_require_allowed(&self, chunk_name: &str) -> bool;

    /// Called from the `reset` callback with the requiring chunk's name; presumably
    /// re-positions navigation at that chunk before a new resolution begins.
    fn reset(&mut self, chunk_name: &str) -> StdResult<(), NavigateError>;

    /// Called from the `jump_to_alias` callback with the `path` string supplied by Luau
    /// (presumably the alias' configured target — confirm against the Luau require API).
    fn jump_to_alias(&mut self, path: &str) -> StdResult<(), NavigateError>;

    /// Navigates to the parent of the current location.
    fn to_parent(&mut self) -> StdResult<(), NavigateError>;

    /// Navigates to the child `name` of the current location.
    fn to_child(&mut self, name: &str) -> StdResult<(), NavigateError>;

    /// Answers Luau's `is_module_present` query for the current location.
    fn has_module(&self) -> bool;

    /// Key under which the loaded module value is cached in the shared loader cache;
    /// a later `require` yielding an equal key returns the cached value without loading again.
    fn cache_key(&self) -> String;

    /// Returns whether a configuration file is present at the current location.
    fn has_config(&self) -> bool;

    /// Reads the raw contents of the configuration at the current location.
    ///
    /// The format (JSON or Luau) is detected from these bytes. If reading fails during the
    /// status query, the configuration is reported to Luau as absent.
    fn config(&self) -> IoResult<Vec<u8>>;

    /// Returns the function that produces the module value.
    ///
    /// The function is called with no arguments; a `nil` result is stored as `true`, and the
    /// result is cached under [`Require::cache_key`].
    fn loader(&self, lua: &Lua) -> Result<Function>;
}
98
99impl fmt::Debug for dyn Require {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101 write!(f, "<dyn Require>")
102 }
103}
104
/// State handed to the Luau require callbacks (as a `RefCell<Context>` userdata) through
/// their `ctx` pointer.
struct Context {
    /// The user-supplied resolver.
    require: Box<dyn Require>,
    /// Result of the last `config()` read made by `get_config_status`, taken by `get_config`
    /// so the configuration is not read twice.
    config_cache: Option<IoResult<Vec<u8>>>,
}
109
110impl Deref for Context {
111 type Target = dyn Require;
112
113 fn deref(&self) -> &Self::Target {
114 &*self.require
115 }
116}
117
118impl DerefMut for Context {
119 fn deref_mut(&mut self) -> &mut Self::Target {
120 &mut *self.require
121 }
122}
123
124impl Context {
125 fn new(require: impl Require + MaybeSend + 'static) -> Self {
126 Context {
127 require: Box::new(require),
128 config_cache: None,
129 }
130 }
131}
132
/// Immutably borrows the `RefCell<Context>` behind the raw pointer `$ctx`.
///
/// If the cell is currently mutably borrowed, raises a Lua error on `$state` via
/// `luaL_error` rather than panicking. Must be expanded in an `unsafe` context, and `$ctx`
/// must point to a live `RefCell<Context>`.
macro_rules! try_borrow {
    ($state:expr, $ctx:expr) => {
        match (*($ctx as *const RefCell<Context>)).try_borrow() {
            Ok(ctx) => ctx,
            Err(_) => ffi::luaL_error($state, cstr!("require context is already borrowed")),
        }
    };
}
141
/// Mutably borrows the `RefCell<Context>` behind the raw pointer `$ctx`.
///
/// If the cell is already borrowed (mutably or immutably), raises a Lua error on `$state`
/// via `luaL_error` rather than panicking. Must be expanded in an `unsafe` context, and
/// `$ctx` must point to a live `RefCell<Context>`.
macro_rules! try_borrow_mut {
    ($state:expr, $ctx:expr) => {
        match (*($ctx as *const RefCell<Context>)).try_borrow_mut() {
            Ok(ctx) => ctx,
            Err(_) => ffi::luaL_error($state, cstr!("require context is already borrowed")),
        }
    };
}
150
/// Fills a `luarequire_Configuration` with callbacks that forward to the [`Require`]
/// implementation stored in the `RefCell<Context>` passed to Luau as `ctx`.
///
/// Alias override/fallback and `get_alias` are left unset. Chunk names and load names are
/// reported as empty strings.
///
/// # Safety
///
/// `config` must be null (in which case nothing happens) or point to a writable
/// `luarequire_Configuration`. Every `ctx` later handed to the callbacks must point to a live
/// `RefCell<Context>`.
#[cfg(feature = "luau")]
pub(super) unsafe extern "C-unwind" fn init_config(config: *mut ffi::luarequire_Configuration) {
    if config.is_null() {
        return;
    }

    unsafe extern "C-unwind" fn is_require_allowed(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        requirer_chunkname: *const c_char,
    ) -> bool {
        if requirer_chunkname.is_null() {
            return false;
        }

        let this = try_borrow!(state, ctx);
        let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy();
        this.is_require_allowed(&chunk_name)
    }

    // The navigation callbacks move the `RefMut` into the closure so it is dropped together
    // with the closure instead of living on this frame while an error may be raised.
    unsafe extern "C-unwind" fn reset(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        requirer_chunkname: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let chunk_name = CStr::from_ptr(requirer_chunkname).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.reset(&chunk_name).into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn jump_to_alias(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        path: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let path = CStr::from_ptr(path).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.jump_to_alias(&path).into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn to_parent(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.to_parent().into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn to_child(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        name: *const c_char,
    ) -> ffi::luarequire_NavigateResult {
        let mut this = try_borrow_mut!(state, ctx);
        let name = CStr::from_ptr(name).to_string_lossy();
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            this.to_child(&name).into_nav_result()
        })
    }

    unsafe extern "C-unwind" fn is_module_present(state: *mut ffi::lua_State, ctx: *mut c_void) -> bool {
        let this = try_borrow!(state, ctx);
        this.has_module()
    }

    unsafe extern "C-unwind" fn get_chunkname(
        _state: *mut ffi::lua_State,
        _ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        write_to_buffer(buffer, buffer_size, size_out, &[])
    }

    unsafe extern "C-unwind" fn get_loadname(
        _state: *mut ffi::lua_State,
        _ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        write_to_buffer(buffer, buffer_size, size_out, &[])
    }

    unsafe extern "C-unwind" fn get_cache_key(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        let this = try_borrow!(state, ctx);
        let cache_key = this.cache_key();
        write_to_buffer(buffer, buffer_size, size_out, cache_key.as_bytes())
    }

    unsafe extern "C-unwind" fn get_config_status(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
    ) -> ConfigStatus {
        let mut this = try_borrow_mut!(state, ctx);
        // Drop anything read for a previous location so a stale config is never kept around.
        this.config_cache = None;
        if this.has_config() {
            this.config_cache = Some(this.config());
            // A config that cannot be read is reported as absent (the error stays cached).
            if let Some(Ok(data)) = &this.config_cache {
                return detect_config_format(data);
            }
        }
        ConfigStatus::Absent
    }

    unsafe extern "C-unwind" fn get_config(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        buffer: *mut c_char,
        buffer_size: usize,
        size_out: *mut usize,
    ) -> WriteResult {
        let mut this = try_borrow_mut!(state, ctx);
        callback_error_ext(state, ptr::null_mut(), true, move |_, _| {
            // Prefer the bytes already read by `get_config_status`.
            let config = match this.config_cache.take() {
                Some(cached) => cached?,
                None => this.config()?,
            };
            let result = write_to_buffer(buffer, buffer_size, size_out, &config);
            if matches!(result, WriteResult::BufferTooSmall) {
                // The caller can retry with a buffer of `*size_out` bytes; keep the data so
                // that retry does not read (and possibly observe a different) config again.
                this.config_cache = Some(Ok(config));
            }
            Ok(result)
        })
    }

    unsafe extern "C-unwind" fn load(
        state: *mut ffi::lua_State,
        ctx: *mut c_void,
        _path: *const c_char,
        _chunkname: *const c_char,
        _loadname: *const c_char,
    ) -> c_int {
        let this = try_borrow!(state, ctx);
        callback_error_ext(state, ptr::null_mut(), true, move |extra, _| {
            let rawlua = (*extra).raw_lua();
            let loader = this.loader(rawlua.lua())?;
            rawlua.push(loader)?;
            Ok(1)
        })
    }

    (*config).is_require_allowed = is_require_allowed;
    (*config).reset = reset;
    (*config).jump_to_alias = jump_to_alias;
    (*config).to_alias_override = None;
    (*config).to_alias_fallback = None;
    (*config).to_parent = to_parent;
    (*config).to_child = to_child;
    (*config).is_module_present = is_module_present;
    (*config).get_chunkname = get_chunkname;
    (*config).get_loadname = get_loadname;
    (*config).get_cache_key = get_cache_key;
    (*config).get_config_status = get_config_status;
    (*config).get_alias = None;
    (*config).get_config = Some(get_config);
    (*config).load = load;
}
314
/// Classifies configuration bytes as JSON or Luau.
///
/// After trimming ASCII whitespace, the content counts as JSON when it opens with `{` and
/// the next non-whitespace byte is either `"` or the closing `}` of an empty object.
/// Everything else is treated as Luau.
#[cfg(feature = "luau")]
fn detect_config_format(data: &[u8]) -> ConfigStatus {
    let Some(body) = data.trim_ascii().strip_prefix(b"{") else {
        return ConfigStatus::PresentLuau;
    };
    let body = body.trim_ascii_start();
    if body == b"}" || body.starts_with(b"\"") {
        ConfigStatus::PresentJson
    } else {
        ConfigStatus::PresentLuau
    }
}
327
/// Copies `data` into a caller-provided C buffer as a NUL-terminated string.
///
/// `*size_out` always receives the number of bytes required, including the terminator. The
/// terminator is not added again if `data` already ends with a zero byte. When the required
/// size exceeds `buffer_size`, nothing is copied and `BufferTooSmall` is returned.
///
/// # Safety
///
/// `size_out` must be valid for writes, and `buffer` must be valid for writes of
/// `buffer_size` bytes.
#[cfg(feature = "luau")]
unsafe fn write_to_buffer(
    buffer: *mut c_char,
    buffer_size: usize,
    size_out: *mut usize,
    data: &[u8],
) -> WriteResult {
    let already_terminated = data.ends_with(&[0]);
    let required = data.len() + usize::from(!already_terminated);
    *size_out = required;
    if buffer_size < required {
        return WriteResult::BufferTooSmall;
    }

    let dst = buffer.cast::<u8>();
    dst.copy_from_nonoverlapping(data.as_ptr(), data.len());
    if !already_terminated {
        dst.add(data.len()).write(0);
    }
    WriteResult::Success
}
348
/// Builds the Luau `require` function backed by `require`.
///
/// The result is a small Luau chunk (`=__mlua_require`) that:
/// 1. rejects non-string paths,
/// 2. returns a module registered in the `LUA_REGISTERED_MODULES_TABLE` registry table, looked
///    up by the ASCII-lowercased path,
/// 3. otherwise resolves the path through `luarequire_pushproxyrequire` with the callbacks
///    installed by `init_config`, and
/// 4. caches each loader result (a `nil` result is stored as `true`) under
///    [`Require::cache_key`] in the registry table `__MLUA_LOADER_CACHE`.
///
/// # Errors
///
/// Returns an error if any helper value cannot be created or the chunk fails to load.
#[cfg(feature = "luau")]
pub(super) fn create_require_function<R: Require + MaybeSend + 'static>(
    lua: &Lua,
    require: R,
) -> Result<Function> {
    /// Pushes the `source` of the nearest non-C frame, starting the walk at stack level 2
    /// (presumably skipping this function and the require chunk itself).
    unsafe extern "C-unwind" fn find_current_file(state: *mut ffi::lua_State) -> c_int {
        let mut ar: ffi::lua_Debug = mem::zeroed();
        for level in 2.. {
            if ffi::lua_getinfo(state, level, cstr!("s"), &mut ar) == 0 {
                ffi::luaL_error(state, cstr!("require is not supported in this context"));
            }
            if CStr::from_ptr(ar.what) != c"C" {
                break;
            }
        }
        ffi::lua_pushstring(state, ar.source);
        1
    }

    /// Pushes the resolver's current cache key; upvalue 1 is the `RefCell<Context>` userdata.
    unsafe extern "C-unwind" fn get_cache_key(state: *mut ffi::lua_State) -> c_int {
        let ctx = ffi::lua_touserdata(state, ffi::lua_upvalueindex(1));
        let ctx = try_borrow!(state, ctx);
        let cache_key = ctx.cache_key();
        ffi::lua_pushlstring(state, cache_key.as_ptr() as *const _, cache_key.len());
        1
    }

    let (get_cache_key, find_current_file, proxyrequire, registered_modules, loader_cache) = unsafe {
        lua.exec_raw::<(Function, Function, Function, Table, Table)>((), move |state| {
            let context = Context::new(require);
            let context_ptr = ffi::lua_newuserdata_t(state, RefCell::new(context));
            ffi::lua_pushcclosured(state, get_cache_key, cstr!("get_cache_key"), 1);
            ffi::lua_pushcfunctiond(state, find_current_file, cstr!("find_current_file"));
            ffi::luarequire_pushproxyrequire(state, init_config, context_ptr as *mut _);
            ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, ffi::LUA_REGISTERED_MODULES_TABLE);
            ffi::luaL_getsubtable(state, ffi::LUA_REGISTRYINDEX, cstr!("__MLUA_LOADER_CACHE"));
        })
    }?;

    /// Raises argument 1 as an error, prefixed with the caller's location (`luaL_where(L, 1)`).
    unsafe extern "C-unwind" fn error(state: *mut ffi::lua_State) -> c_int {
        ffi::luaL_where(state, 1);
        ffi::lua_pushvalue(state, 1);
        ffi::lua_concat(state, 2);
        ffi::lua_error(state);
    }

    /// Pushes the type name of argument 1.
    unsafe extern "C-unwind" fn r#type(state: *mut ffi::lua_State) -> c_int {
        ffi::lua_pushstring(state, ffi::lua_typename(state, ffi::lua_type(state, 1)));
        1
    }

    /// ASCII-lowercases the string argument; used to normalize registered-module lookups.
    unsafe extern "C-unwind" fn to_lowercase(state: *mut ffi::lua_State) -> c_int {
        // Read the string with its explicit length. Going through a C string would stop at
        // the first embedded NUL and could map distinct paths onto the same registered module.
        let mut len: usize = 0;
        let data = ffi::luaL_checklstring(state, 1, &mut len);
        let bytes = std::slice::from_raw_parts(data as *const u8, len);
        if !bytes.iter().any(u8::is_ascii_uppercase) {
            // Nothing to change: return the argument itself, independent of extra arguments.
            ffi::lua_pushvalue(state, 1);
            return 1;
        }
        callback_error_ext(state, ptr::null_mut(), true, |extra, _| {
            let lowered = bytes
                .iter()
                .map(u8::to_ascii_lowercase)
                .collect::<bstr::BString>();
            (*extra).raw_lua().push(lowered).map(|_| 1)
        })
    }

    let (error, r#type, to_lowercase) = unsafe {
        lua.exec_raw::<(Function, Function, Function)>((), move |state| {
            ffi::lua_pushcfunctiond(state, error, cstr!("error"));
            ffi::lua_pushcfunctiond(state, r#type, cstr!("type"));
            ffi::lua_pushcfunctiond(state, to_lowercase, cstr!("to_lowercase"));
        })
    }?;

    // One hash slot per entry inserted below.
    let env = lua.create_table_with_capacity(0, 8)?;
    env.raw_set("get_cache_key", get_cache_key)?;
    env.raw_set("find_current_file", find_current_file)?;
    env.raw_set("proxyrequire", proxyrequire)?;
    env.raw_set("REGISTERED_MODULES", registered_modules)?;
    env.raw_set("LOADER_CACHE", loader_cache)?;
    env.raw_set("error", error)?;
    env.raw_set("type", r#type)?;
    env.raw_set("to_lowercase", to_lowercase)?;

    lua.load(
        r#"
        local path = ...
        if type(path) ~= "string" then
            error("bad argument #1 to 'require' (string expected, got " .. type(path) .. ")")
        end

        -- Check if the module (path) is explicitly registered
        local maybe_result = REGISTERED_MODULES[to_lowercase(path)]
        if maybe_result ~= nil then
            return maybe_result
        end

        local loader = proxyrequire(path, find_current_file())
        local cache_key = get_cache_key()
        -- Check if the loader result is already cached
        local result = LOADER_CACHE[cache_key]
        if result ~= nil then
            return result
        end

        -- Call the loader function and cache the result
        result = loader()
        if result == nil then
            result = true
        end
        LOADER_CACHE[cache_key] = result
        return result
        "#,
    )
    .try_cache()
    .set_name("=__mlua_require")
    .set_environment(env)
    .into_function()
}
469
470mod fs;