impatient.lua 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. -- modified version from https://github.com/lewis6991/impatient.nvim
  2. local vim = vim
  3. local api = vim.api
  4. local uv = vim.loop
  5. local _loadfile = loadfile
  6. local get_runtime = api.nvim__get_runtime
  7. local fs_stat = uv.fs_stat
  8. local mpack = vim.mpack
  9. local loadlib = package.loadlib
  10. local std_cache = vim.fn.stdpath "cache"
  11. local sep
  12. if jit.os == "Windows" then
  13. sep = "\\"
  14. else
  15. sep = "/"
  16. end
  17. local std_dirs = {
  18. ["<APPDIR>"] = os.getenv "APPDIR",
  19. ["<VIMRUNTIME>"] = os.getenv "VIMRUNTIME",
  20. ["<STD_DATA>"] = vim.fn.stdpath "data",
  21. ["<STD_CONFIG>"] = vim.fn.stdpath "config",
  22. ["<LVIM_BASE>"] = get_lvim_base_dir(),
  23. ["<LVIM_RUNTIME>"] = get_runtime_dir(),
  24. ["<LVIM_CONFIG>"] = get_config_dir(),
  25. }
  26. local function modpath_mangle(modpath)
  27. for name, dir in pairs(std_dirs) do
  28. modpath = modpath:gsub(dir, name)
  29. end
  30. return modpath
  31. end
  32. local function modpath_unmangle(modpath)
  33. for name, dir in pairs(std_dirs) do
  34. modpath = modpath:gsub(name, dir)
  35. end
  36. return modpath
  37. end
  38. -- Overridable by user
  39. local default_config = {
  40. chunks = {
  41. enable = true,
  42. path = std_cache .. sep .. "luacache_chunks",
  43. },
  44. modpaths = {
  45. enable = true,
  46. path = std_cache .. sep .. "luacache_modpaths",
  47. },
  48. }
  49. -- State used internally
  50. local default_state = {
  51. chunks = {
  52. cache = {},
  53. profile = nil,
  54. dirty = false,
  55. get = function(self, path)
  56. return self.cache[modpath_mangle(path)]
  57. end,
  58. set = function(self, path, chunk)
  59. self.cache[modpath_mangle(path)] = chunk
  60. end,
  61. },
  62. modpaths = {
  63. cache = {},
  64. profile = nil,
  65. dirty = false,
  66. get = function(self, mod)
  67. if self.cache[mod] then
  68. return modpath_unmangle(self.cache[mod])
  69. end
  70. end,
  71. set = function(self, mod, path)
  72. self.cache[mod] = modpath_mangle(path)
  73. end,
  74. },
  75. log = {},
  76. }
  77. ---@diagnostic disable-next-line: undefined-field
  78. local M = vim.tbl_deep_extend("keep", _G.__luacache_config or {}, default_config, default_state)
  79. _G.__luacache = M
  80. local function log(...)
  81. M.log[#M.log + 1] = table.concat({ string.format(...) }, " ")
  82. end
  83. local function print_log()
  84. for _, l in ipairs(M.log) do
  85. print(l)
  86. end
  87. end
  88. local function hash(modpath)
  89. local stat = fs_stat(modpath)
  90. if stat then
  91. return stat.mtime.sec .. stat.mtime.nsec .. stat.size
  92. end
  93. error("Could not hash " .. modpath)
  94. end
  95. local function profile(m, entry, name, loader)
  96. if m.profile then
  97. local mp = m.profile
  98. mp[entry] = mp[entry] or {}
  99. if not mp[entry].loader and loader then
  100. mp[entry].loader = loader
  101. end
  102. if not mp[entry][name] then
  103. mp[entry][name] = uv.hrtime()
  104. end
  105. end
  106. end
  107. local function mprofile(mod, name, loader)
  108. profile(M.modpaths, mod, name, loader)
  109. end
  110. local function cprofile(path, name, loader)
  111. if M.chunks.profile then
  112. path = modpath_mangle(path)
  113. end
  114. profile(M.chunks, path, name, loader)
  115. end
  116. function M.enable_profile()
  117. local P = require "lvim.impatient.profile"
  118. M.chunks.profile = {}
  119. M.modpaths.profile = {}
  120. loadlib = function(path, fun)
  121. cprofile(path, "load_start")
  122. local f, err = package.loadlib(path, fun)
  123. cprofile(path, "load_end", "standard")
  124. return f, err
  125. end
  126. P.setup(M.modpaths.profile)
  127. api.nvim_create_user_command("LuaCacheProfile", function()
  128. P.print_profile(M, std_dirs)
  129. end, {})
  130. end
  131. local function get_runtime_file_from_parent(basename, paths)
  132. -- Look in the cache to see if we have already loaded a parent module.
  133. -- If we have then try looking in the parents directory first.
  134. local parents = vim.split(basename, sep)
  135. for i = #parents, 1, -1 do
  136. local parent = table.concat(vim.list_slice(parents, 1, i), sep)
  137. local ppath = M.modpaths:get(parent)
  138. if ppath then
  139. if ppath:sub(-9) == (sep .. "init.lua") then
  140. ppath = ppath:sub(1, -10) -- a/b/init.lua -> a/b
  141. else
  142. ppath = ppath:sub(1, -5) -- a/b.lua -> a/b
  143. end
  144. for _, path in ipairs(paths) do
  145. -- path should be of form 'a/b/c.lua' or 'a/b/c/init.lua'
  146. local modpath = ppath .. sep .. path:sub(#("lua" .. sep .. parent) + 2)
  147. if fs_stat(modpath) then
  148. return modpath, "cache(p)"
  149. end
  150. end
  151. end
  152. end
  153. end
  154. local rtp = vim.split(vim.o.rtp, ",")
  155. -- Make sure modpath is in rtp and that modpath is in paths.
  156. local function validate_modpath(modpath, paths)
  157. local match = false
  158. for _, p in ipairs(paths) do
  159. if vim.endswith(modpath, p) then
  160. match = true
  161. break
  162. end
  163. end
  164. if not match then
  165. return false
  166. end
  167. for _, dir in ipairs(rtp) do
  168. if vim.startswith(modpath, dir) then
  169. return fs_stat(modpath) ~= nil
  170. end
  171. end
  172. return false
  173. end
  174. local function get_runtime_file_cached(basename, paths)
  175. local modpath, loader
  176. local mp = M.modpaths
  177. if mp.enable then
  178. local modpath_cached = mp:get(basename)
  179. if modpath_cached then
  180. modpath, loader = modpath_cached, "cache"
  181. else
  182. modpath, loader = get_runtime_file_from_parent(basename, paths)
  183. end
  184. if modpath and not validate_modpath(modpath, paths) then
  185. modpath = nil
  186. -- Invalidate
  187. mp.cache[basename] = nil
  188. mp.dirty = true
  189. end
  190. end
  191. if not modpath then
  192. -- What Neovim does by default; slowest
  193. modpath, loader = get_runtime(paths, false, { is_lua = true })[1], "standard"
  194. end
  195. if modpath then
  196. mprofile(basename, "resolve_end", loader)
  197. if mp.enable and loader ~= "cache" then
  198. log("Creating cache for module %s", basename)
  199. mp:set(basename, modpath)
  200. mp.dirty = true
  201. end
  202. end
  203. return modpath
  204. end
  205. local function extract_basename(pats)
  206. local basename
  207. -- Deconstruct basename from pats
  208. for _, pat in ipairs(pats) do
  209. for i, npat in ipairs {
  210. -- Ordered by most specific
  211. "lua"
  212. .. sep
  213. .. "(.*)"
  214. .. sep
  215. .. "init%.lua",
  216. "lua" .. sep .. "(.*)%.lua",
  217. } do
  218. local m = pat:match(npat)
  219. if i == 2 and m and m:sub(-4) == "init" then
  220. m = m:sub(0, -6)
  221. end
  222. if not basename then
  223. if m then
  224. basename = m
  225. end
  226. elseif m and m ~= basename then
  227. -- matches are inconsistent
  228. return
  229. end
  230. end
  231. end
  232. return basename
  233. end
  234. local function get_runtime_cached(pats, all, opts)
  235. local fallback = false
  236. if all or not opts or not opts.is_lua then
  237. -- Fallback
  238. fallback = true
  239. end
  240. local basename
  241. if not fallback then
  242. basename = extract_basename(pats)
  243. end
  244. if fallback or not basename then
  245. return get_runtime(pats, all, opts)
  246. end
  247. return { get_runtime_file_cached(basename, pats) }
  248. end
-- Copied from neovim/src/nvim/lua/vim.lua with two lines changed
--
-- Package searcher installed by setup() in place of vim._load_package.
-- Resolves module `name` (e.g. "foo.bar") on 'runtimepath' through the
-- module-path cache. It first looks for a Lua file, then for a native
-- library. Returns the loader function, raises the load error when a file
-- was found but failed to load, or returns nil when nothing was found.
-- Kept line-for-line close to upstream so it can be diffed against it.
local function load_package(name)
  -- "foo.bar" -> "foo/bar"; the single-target assignment drops gsub's count.
  local basename = name:gsub("%.", sep)
  local paths = { "lua" .. sep .. basename .. ".lua", "lua" .. sep .. basename .. sep .. "init.lua" }
  -- Original line:
  -- local found = vim.api.nvim__get_runtime(paths, false, {is_lua=true})
  local found = { get_runtime_file_cached(basename, paths) }
  if #found > 0 then
    -- `loadfile` is the global, which setup() replaces with loadfile_cached.
    local f, err = loadfile(found[1])
    return f or error(err)
  end
  local so_paths = {}
  for _, trail in ipairs(vim._so_trails) do
    -- `..` keeps only gsub's first result (the substituted string).
    local path = "lua" .. trail:gsub("?", basename) -- so_trails contains a leading slash
    table.insert(so_paths, path)
  end
  -- Original line:
  -- found = vim.api.nvim__get_runtime(so_paths, false, {is_lua=true})
  -- NOTE(review): this reuses `basename` as the cache key, the same key the
  -- Lua lookup above uses. A cached native-library path is therefore
  -- rejected by validate_modpath on the next Lua lookup of this module and
  -- re-resolved here, so the cache is rewritten every session.
  found = { get_runtime_file_cached(basename, so_paths) }
  if #found > 0 then
    -- Making function name in Lua 5.1 (see src/loadlib.c:mkfuncname) is
    -- a) strip prefix up to and including the first dash, if any
    -- b) replace all dots by underscores
    -- c) prepend "luaopen_"
    -- So "foo-bar.baz" should result in "luaopen_bar_baz"
    local dash = name:find("-", 1, true)
    local modname = dash and name:sub(dash + 1) or name
    -- `loadlib` is an upvalue that M.enable_profile may swap for a timing wrapper.
    local f, err = loadlib(found[1], "luaopen_" .. modname:gsub("%.", "_"))
    return f or error(err)
  end
  return nil
end
  281. local function load_from_cache(path)
  282. local mc = M.chunks
  283. local cache = mc:get(path)
  284. if not cache then
  285. return nil, string.format("No cache for path %s", path)
  286. end
  287. local mhash, codes = unpack(cache)
  288. if mhash ~= hash(path) then
  289. mc:set(path)
  290. mc.dirty = true
  291. return nil, string.format("Stale cache for path %s", path)
  292. end
  293. local chunk = loadstring(codes)
  294. if not chunk then
  295. mc:set(path)
  296. mc.dirty = true
  297. return nil, string.format("Cache error for path %s", path)
  298. end
  299. return chunk
  300. end
  301. local function loadfile_cached(path)
  302. cprofile(path, "load_start")
  303. local chunk, err
  304. if M.chunks.enable then
  305. chunk, err = load_from_cache(path)
  306. if chunk and not err then
  307. log("Loaded cache for path %s", path)
  308. cprofile(path, "load_end", "cache")
  309. return chunk
  310. end
  311. log(err)
  312. end
  313. chunk, err = _loadfile(path)
  314. if not err and M.chunks.enable then
  315. log("Creating cache for path %s", path)
  316. M.chunks:set(path, { hash(path), string.dump(chunk) })
  317. M.chunks.dirty = true
  318. end
  319. cprofile(path, "load_end", "standard")
  320. return chunk, err
  321. end
  322. function M.save_cache()
  323. local function _save_cache(t)
  324. if not t.enable then
  325. return
  326. end
  327. if t.dirty then
  328. log("Updating chunk cache file: %s", t.path)
  329. local f = assert(io.open(t.path, "w+b"))
  330. f:write(mpack.encode(t.cache))
  331. f:flush()
  332. t.dirty = false
  333. end
  334. end
  335. _save_cache(M.chunks)
  336. _save_cache(M.modpaths)
  337. end
  338. local function clear_cache()
  339. local function _clear_cache(t)
  340. t.cache = {}
  341. os.remove(t.path)
  342. end
  343. _clear_cache(M.chunks)
  344. _clear_cache(M.modpaths)
  345. end
  346. local function init_cache()
  347. local function _init_cache(t)
  348. if not t.enable then
  349. return
  350. end
  351. if fs_stat(t.path) then
  352. log("Loading cache file %s", t.path)
  353. local f = assert(io.open(t.path, "rb"))
  354. local ok
  355. ok, t.cache = pcall(function()
  356. return mpack.decode(f:read "*a")
  357. end)
  358. if not ok then
  359. log("Corrupted cache file, %s. Invalidating...", t.path)
  360. os.remove(t.path)
  361. t.cache = {}
  362. end
  363. t.dirty = not ok
  364. end
  365. end
  366. if not uv.fs_stat(std_cache) then
  367. vim.fn.mkdir(std_cache, "p")
  368. end
  369. _init_cache(M.chunks)
  370. _init_cache(M.modpaths)
  371. end
  372. local function setup()
  373. init_cache()
  374. -- Usual package loaders
  375. -- 1. package.preload
  376. -- 2. vim._load_package
  377. -- 3. package.path
  378. -- 4. package.cpath
  379. -- 5. all-in-one
  380. -- Override default functions
  381. for i, loader in ipairs(package.loaders) do
  382. if loader == vim._load_package then
  383. package.loaders[i] = load_package
  384. break
  385. end
  386. end
  387. vim._load_package = load_package
  388. vim.api.nvim__get_runtime = get_runtime_cached
  389. loadfile = loadfile_cached
  390. local augroup = api.nvim_create_augroup("impatient", {})
  391. api.nvim_create_user_command("LuaCacheClear", clear_cache, {})
  392. api.nvim_create_user_command("LuaCacheLog", print_log, {})
  393. api.nvim_create_autocmd({ "VimEnter", "VimLeave" }, {
  394. group = augroup,
  395. callback = M.save_cache,
  396. })
  397. api.nvim_create_autocmd("OptionSet", {
  398. group = augroup,
  399. pattern = "runtimepath",
  400. callback = function()
  401. rtp = vim.split(vim.o.rtp, ",")
  402. end,
  403. })
  404. end
  405. setup()
  406. return M