-- dmlib/sourceoptimizer.mlua
-- Automatic source optimizer.
-- Inlining of functions and partial evaluation.
--
-- For example
--   y = (function(x) return x*x end)(2)
-- simplifies to
--   local __v1x=2; y=__v1x*__v1x .
--
-- Warning: may not work correctly for odd cases where inlined function
-- uses setfenv or debug functions that utilize stack levels.
--
-- Warning: code is beta quality.  See test suite.
--
-- Uses metalua 0.5 git branch (20090228)
--
-- (c) 2009 David Manura, licensed under the same terms as Lua (MIT license).
--

require 'metalua.walk.id'
require 'metalua.ast_to_string'

local M = {}

M.internal = {}

-- Print values, for debugging only.
-- Each argument up to the first nil is rendered: tables through
-- table.tostring(v, 'nohash'), anything else through tostring.  The pieces
-- are joined with ', ' and printed after a 'DEBUG:' prefix.
-- NOTE(review): table.tostring is not a stock Lua function; presumably it is
-- supplied by the metalua runtime -- confirm.
local function DEBUG(...)
  local parts = {}
  for idx, value in ipairs{...} do
    local text
    if type(value) == 'table' then
      text = table.tostring(value, 'nohash') or tostring(value)
    else
      text = tostring(value)
    end
    parts[idx] = text
  end
  print('DEBUG:' .. table.concat(parts, ', '))
end

-- tinsertvalues(t, [pos,] values)
-- This is similar to table.insert but inserts values from given table "values",
-- not the object itself, into table "t" at position "pos".
-- When "pos" is omitted the values are appended after the last element.
-- note: an optional extension is to allow selection of a slice of values:
--   tinsertvalues(t, [pos, [vpos1, [vpos2, ]]] values)
-- http://lua-users.org/wiki/TableUtils
-- NOTE(review): this function is defined as a global (there is no "local"
-- in front of it).  It is left global here so that anything relying on the
-- global keeps working -- confirm whether that is intentional.
function tinsertvalues(t, ...)
  local pos, values
  if select('#', ...) ~= 1 then
    pos, values = ...
  else
    values = ...
    pos = #t + 1
  end
  if #values == 0 then return end
  -- open a gap of #values slots by moving the tail of t upwards,
  -- working from the end so that nothing is overwritten
  for i = #t, pos, -1 do
    t[i + #values] = t[i]
  end
  -- drop the new values into the gap
  for k = 1, #values do
    t[pos + k - 1] = values[k]
  end
end
--[[ tests:
  local t = {5,6,7}
  tinsertvalues(t, {8,9})
  tinsertvalues(t, {})
  tinsertvalues(t, 1, {1,4})
  tinsertvalues(t, 2, {2,3})
  assert(table.concat(t, '') == '123456789')
--]]

-- Returns i such that t[i] == v.  nil on no match
-- (the first matching index in 1..#t is the one returned).
local function tindex(t, v)
  local i, n = 1, #t
  while i <= n do
    if t[i] == v then return i end
    i = i + 1
  end
end

-- Similar to table.override, but clears hash part not in src
-- (bug in table.override?
-- FIX)
--
-- toverride(dst, src): makes table "dst" hold exactly the fields of "src".
-- Every key currently in dst is first overwritten with src's value for that
-- key (which removes it when src has none), then every key of src is copied.
-- dst keeps its identity, so existing references to it see the new content.
-- Returns dst.
local function toverride(dst, src)
  for key in pairs(dst) do
    dst[key] = src[key]
  end
  for key, value in pairs(src) do
    dst[key] = value
  end
  return dst
end

-- Returns t[k].  If undefined, first sets t[k] to vdef.
local function tdefault(t, k, vdef)
  local current = t[k]
  if current ~= nil then
    return current
  end
  t[k] = vdef
  return vdef
end

-- Converts Lua source string to Lua AST (via mlp/gg).
-- "filename" is optional and defaults to '(string)'.
local function string_to_ast(src, filename)
  local stream = mlp.lexer:newstream(src, filename or '(string)')
  local chunk_ast = mlp.chunk(stream)
  return chunk_ast
end

-- Gets last statement in block AST.
-- Ignores redundant nested {}: while the last element is itself an
-- untagged table, the search continues with that table's last element.
local function block_last(ast)
  local node = ast[#ast]
  while node and node.tag == nil do
    node = node[#node]
  end
  return node
end

-- Marks local variables that are constant (not assigned to again).
-- Returns table mapping of used variable ID AST to variable table.
-- The variable table's field 'const' starts out true and is set to false
-- as soon as the variable is seen on the left-hand side of a `Set.
local function mark_const_locals(ast)
  local by_binder = {}     -- binder AST -> { name -> variable table }
  local var_from_id_ast = {}
  local cfg = {id={}}
  function cfg.id.bound(id_ast, binder, parent_ast)
    local name = id_ast[1]
    local in_scope = tdefault(by_binder, binder, {})
    local info = tdefault(in_scope, name, {const = true})
    var_from_id_ast[id_ast] = info
    -- an occurrence in the target list of an assignment
    if parent_ast.tag == 'Set' and tindex(parent_ast[1], id_ast) then
      info.const = false
    end
  end
  walk_id.guess(cfg, ast)
  return var_from_id_ast
end
M.mark_const_locals = mark_const_locals

-- Appends block of code "pre_ast" to preblock of "ast".
-- The "preblock" stores statements that should be inserted
-- before the current statement.
local function addpreblock(ast, pre_ast)
  local pending = ast.preblock
  if not pending then
    pending = {}
    ast.preblock = pending
  end
  tinsertvalues(pending, pre_ast)
end

-- Applies function, using inlining.
-- note: f_ast and args_ast are unchanged
-- Assumes inlining into location supporting minimum "nmin" and
-- maximum "nmax" values.
-- Uses "gensym" function to generate unique symbol.
-- apply_func(f_ast, args_ast, nmin, nmax, gensym, id)
--   f_ast     `Function AST to inline (asserted)
--   args_ast  list of argument expression ASTs
--   nmin,nmax minimum / maximum number of values the call site accepts
--   gensym    function(name) returning a fresh unique identifier
--   id        name of the called variable, if any (currently unused)
-- Returns the AST standing for the call's value: a single expression when
-- exactly one value is produced, otherwise an untagged list of expressions.
-- The statements that must run first are attached as its "preblock".
local function apply_func(f_ast, args_ast, nmin, nmax, gensym, id)
  -- DEBUG(ast_to_string(f_ast), ast_to_string(args_ast))
  assert(f_ast.tag == 'Function')
  -- work on copies so that the caller's ASTs are left alone
  args_ast = table.deep_copy(args_ast)
  local ast = table.deep_copy(f_ast)
  local params_ast = ast[1]
  local stats_ast = ast[2]

  -- local ret_id = gensym(id)

  local replaces = {}  -- original name -> generated unique name

  -- rename bound variables (leave free variables unchanged)
  -- careful: easier to delay modifying AST until after traversal
  local cfg = {id={}}
  local is_var = {}
  function cfg.id.bound(x, binder)
    -- DEBUG('VB', x, binder)
    is_var[x] = true
  end
  -- function cfg.id.free(x)
  --   DEBUG('VF', x)
  -- end
  function cfg.binder(x)
    -- DEBUG('VD', x)
    is_var[x] = true
  end
  walk_id.guess(cfg, ast)
  for var_ast in pairs(is_var) do
    local name = var_ast[1]
    replaces[name] = replaces[name] or gensym(name)
    var_ast[1] = replaces[name]
  end

  -- replace returns (only special case)
  -- local val_ast
  -- local cfg = {stat={},expr={}}
  -- function cfg.stat.down(x)
  --   if x.tag == 'Return' and x ~= stats_ast[#stats_ast] then
  --     toverride(x, `Local{{`Id{ret_id}}, {unpack(x)}})
  --   end
  -- end
  -- function cfg.expr.down(x) -- don't descend into nested
  --   if x.tag == 'Function' then return 'break' end
  -- end
  -- walk.guess(cfg, stats_ast)

  -- the trailing `Return of the body, or an empty stand-in when there is none
  local return_ast = #stats_ast > 0 and stats_ast[#stats_ast].tag == 'Return'
                     and stats_ast[#stats_ast] or `Return{}

  -- Modify return at the end of the block.
  --DEBUG(nmin, nmax,#return_ast)
  -- Declared local here: the only earlier declaration sits inside the
  -- commented-out code above, so the assignments below used to write a
  -- global named val_ast.
  local val_ast
  if #return_ast > nmax then
    -- more values returned than the call site takes: keep the first nmax
    -- and still evaluate the remaining ones in place of the return
    val_ast = nmax == 1 and return_ast[1] or {unpack(return_ast, 1, nmax)}
    local block_ast = {}
    local elist_ast = {}
    for i=nmax+1,#return_ast do
      local rv_ast = return_ast[i]
      if rv_ast.tag == 'Call' or rv_ast.tag == 'Invoke' then
        table.insert(block_ast, rv_ast)  -- usable as statement
      else
        table.insert(elist_ast, rv_ast)
      end
    end
    if #elist_ast > 0 then
      -- other expressions are evaluated by assigning them to a dummy local
      table.insert(block_ast, `Do{`Local{{`Id'_'}, elist_ast}})
    end
    toverride(return_ast, block_ast)
  elseif #return_ast < nmin then
    val_ast = nmin == 1 and `Nil or {unpack(return_ast)}
    toverride(return_ast, {})
  else
    val_ast = #return_ast == 1 and return_ast[1] or {unpack(return_ast)}
    toverride(return_ast, {})
  end
  val_ast.preblock = nil -- preblock for return already evaluated

  -- build preblock
  local pre_ast = {}
  for _,arg_ast in ipairs(args_ast) do
    tinsertvalues(pre_ast, arg_ast.preblock or {})
    arg_ast.preblock = nil
  end
  -- NOTE(review): when the function declares no parameters the argument
  -- expressions are not emitted at all, so any side effects they have are
  -- lost -- confirm whether that is acceptable.
  if #params_ast > 0 then
    table.insert(pre_ast, `Local{params_ast, args_ast})
  end
  tinsertvalues(pre_ast, stats_ast)

  -- finish value
  addpreblock(val_ast, pre_ast)

  -- DEBUG('inline',val_ast, val_ast.preblock)
  -- DEBUG('o', ast_to_string(val_ast.preblock), ast_to_string(val_ast))
  return val_ast
end

-- Gets minimum and maximum number of values that expression "ast"
-- inside parent_ast can expand into.
-- E.g.
--   g(1,f()) is 0,math.huge; g(f(),1) is 1,1
--
-- nret_position(ast, parent_ast) returns two numbers nmin, nmax: the fewest
-- and the most values that expression "ast" may contribute at its position
-- inside "parent_ast".  Only the last entry of an expression list can expand
-- to a variable number of values; every other position takes exactly one.

-- Index of the first slot of parent[...] that belongs to the expression
-- list.  Slot 1 of `Call is the callee, slots 1 and 2 of `Invoke are the
-- object and the method name; those never take a value list.
local first_value_index = { Table = 1, Return = 1, Call = 2, Invoke = 3 }

local function nret_position(ast, parent_ast)
  local tag = parent_ast.tag
  local first = first_value_index[tag]
  if first then
    local n = #parent_ast
    -- without the n >= first test the callee of a call without arguments,
    -- as in f()(), would be mistaken for a trailing argument
    if n >= first and parent_ast[n] == ast then
      return 0,math.huge
    else
      return 1,1
    end
  elseif tag == 'Set' or tag == 'Local' then
    local rhs_ast = parent_ast[2]
    if rhs_ast[#rhs_ast] == ast then
      -- an assignment needs at least one expression on its right-hand side
      return (#rhs_ast == 1 and tag == 'Set' and 1 or 0),math.huge
    else
      return 1,1
    end
  elseif tag == 'Forin' then
    local it_ast = parent_ast[2]
    if it_ast[#it_ast] == ast then
      -- generic for uses three values in total; never report less than
      -- zero when more than four expressions are listed
      return (#it_ast == 1 and 1 or 0), math.max(0, 3-(#it_ast-1))
    else
      return 1,1
    end
  else -- others, including Fornum
    return 1,1
  end
end

-- Flattens expression list expanded from expression
-- into containing AST.
-- The last entry of the expression list of "ast" is expected to be an
-- untagged list of expressions; it is removed and its elements are appended
-- in its place.
local function flatten_ast(ast)
  local op_ast
  if ast.tag == 'Table' or ast.tag == 'Call' or ast.tag == 'Invoke' or
     ast.tag == 'Return'
  then
    op_ast = ast
  elseif ast.tag == 'Set' or ast.tag == 'Local' or ast.tag == 'Forin' then
    op_ast = ast[2]
  else
    assert(false, ast.tag)
  end
  local vals_ast = table.remove(op_ast, #op_ast)
  tinsertvalues(op_ast, vals_ast)
  -- note: expression flattens into an empty expression list, there
  -- may be no expression to attach the preblock to.  Therefore
  -- attach it to the containing expression list.
  if vals_ast.preblock then
    op_ast.preblock_extra = vals_ast.preblock
  end
  -- assignment must have >= 1 element on RHS
  --old if ast.tag == 'Set' and #ast[2] == 0 then
  --  table.insert(ast[2], `Nil)
  --end
end

-- Flattens block AST (removing unnecessary nested {} )
-- Each untagged element is replaced in place by its own elements; the scan
-- restarts after every replacement until none is left.
local function flatten_block(ast)
  assert(ast.tag == nil)
  repeat
    local done = true
    for i,v_ast in ipairs(ast) do
      if v_ast.tag == nil then
        table.remove(ast, i)
        tinsertvalues(ast, i, v_ast)
        done = false
        break
      end
    end
  until done
end

-- Eliminates local function definitions that are not used
-- in ast.
-- eliminate_dead_functions(ast): every `Localrec statement that defines a
-- single `Function and whose binder is never referenced by a bound
-- identifier is emptied in place (turned into an empty untagged block).
local function eliminate_dead_functions(ast)
  -- pass 1: collect the binders that some identifier refers to
  local used = {}
  local id_cfg = {id={}}
  function id_cfg.id.bound(id_ast, binder_ast)
    used[binder_ast] = true
    --DEBUG('used', id_ast[1])
  end
  walk_id.guess(id_cfg, ast)

  -- pass 2: blank out the unreferenced function definitions
  local stat_cfg = {stat={}}
  function stat_cfg.stat.up(stat_ast)
    if stat_ast.tag=='Localrec' and #stat_ast[2]==1 and
       stat_ast[2][1].tag=='Function'
    then
      if not used[stat_ast] then
        --DEBUG('delete', stat_ast[2][1][1][1][1])
        toverride(stat_ast, {})
      end
    end
  end
  walk.guess(stat_cfg, ast)
end
M.eliminate_dead_functions = eliminate_dead_functions

-- Find unique prefix that no variables in AST have.
-- Starts from '__v' and appends 'v' for as long as some identifier
-- (free, bound or binder) begins with the current prefix.
local function get_unique_prefix(ast)
  local prefix = '__v'
  local cfg = {id={}}
  function cfg.id.free(x)
    if x.tag == 'Dots' then return end
    local id = x[1]
    -- plain-text search: the prefix is a literal, not a pattern
    while id:find(prefix, 1, true) == 1 do prefix = prefix .. 'v' end
  end
  cfg.id.bound = cfg.id.free
  cfg.binder = cfg.id.free
  walk_id.guess(cfg, ast)
  return prefix
end

-- Inlines functions in AST (when possible).
-- The AST is modified in place and also returned.
local function inline_functions(ast)
  -- functions that are const local may be safe to inline.  find them.
  local var_from_id_ast = mark_const_locals(ast)

  -- Returns function definition from ast expression
  -- if it is an inlinable function.
  -- A second result is the `Id AST through which the function was reached
  -- (false when it was not reached through an identifier).
  -- note: ignoring parens, lookup IDs
  -- note: possibly should make paren removal a separate tree pass
  local function is_inlinable_function(ast)
    local def_ast, id_ast
    def_ast =
      ast.tag == 'Function' and ast or
      ast.tag == 'Paren' and is_inlinable_function(ast[1]) or
      ast.tag == 'Id' and (ast.declaration or {}).tag=='Localrec'
                      and ast.declaration[2][1].tag == 'Function'
                      and ast.declaration[2][1]
    if def_ast then
      local params_ast = def_ast[1]
      if (params_ast[#params_ast] or {}).tag == 'Dots' then
        return --IMPROVE:NOT IMPL
      end
    end
    id_ast = def_ast and ast.tag == 'Id' and ast
    if not def_ast then return end
    -- a variable that is assigned to again cannot be resolved statically
    if ast.tag == 'Id' and not var_from_id_ast[ast].const then return end

    -- allow only return at end of function
    local first_return
    local cfg = {stat={}, expr={}}
    function cfg.stat.down(ast)
      if ast.tag == 'Return' then
        first_return = first_return or ast
      end
    end
    function cfg.expr.down(ast)
      -- returns of nested functions do not count
      if ast.tag == 'Function' then return 'break' end
    end
    local body_ast = def_ast[2]
    walk.guess(cfg, body_ast)
    if first_return and first_return ~= block_last(body_ast) then
      return
    end

    return def_ast, id_ast
  end

  local prefix = get_unique_prefix(ast)

  -- Create unique symbol.
  -- A leading prefix-plus-number from an earlier gensym is stripped from
  -- "postfix" first so that generated names do not pile up.
  local lastid = 0
  local function gensym(postfix)
    postfix = (postfix or ''):gsub('^' .. prefix .. '%d+', '')
    lastid = lastid + 1
    return prefix .. lastid .. postfix
  end

  local cfg = {stat={},expr={},id={}}
  function cfg.id.bound(id_ast, binder_ast)
    -- remember where each identifier was declared
    id_ast.declaration = binder_ast
  end
  function cfg.expr.up(ast, parent_ast)
    assert(parent_ast)

    -- cleanup
    if ast.inline_flatten then flatten_ast(ast) end
    if ast.tag == 'Function' then
      flatten_block(ast[2])
    end

    if ast.tag == 'Call' then
      local f_ast, id_ast = is_inlinable_function(ast[1])
      if f_ast then
        local nmin, nmax = nret_position(ast, parent_ast)
        local id = id_ast and id_ast[1]
        local res_ast = apply_func(f_ast, {unpack(ast, 2)},
                                   nmin,nmax,gensym, id)
        toverride(ast, res_ast)
        if ast.tag == nil then
          -- the call became a list of values; the parent has to splice it
          parent_ast.inline_flatten = true
        end
      end
    elseif ast.tag == 'Op' then
      -- 'or' / 'and' evaluate their second operand conditionally, so a
      -- pending preblock has to be wrapped in an explicit "if"
      if ast[1] == 'or' then
        if ast.preblock then
          local a_ast, b_ast = ast[2], ast[3]
          local val_id = gensym()
          local val_id_ast = `Id{val_id}
          toverride(ast, val_id_ast)
          local pre_ast = {}
          if a_ast.preblock then
            tinsertvalues(pre_ast, a_ast.preblock)
          end
          tinsertvalues(pre_ast, +{block:
            local -{val_id_ast} = -{a_ast}
            if not -{val_id_ast} then
              -{ {b_ast.preblock} }; -- note: preblock may be nil
              -{ `Set{{val_id_ast}, {b_ast}} }
            end
          })
          ast.preblock = pre_ast
        end
      elseif ast[1] == 'and' then
        if ast.preblock then
          local a_ast, b_ast = ast[2], ast[3]
          local val_id = gensym()
          local val_id_ast = `Id{val_id}
          toverride(ast, val_id_ast)
          local pre_ast = {}
          if a_ast.preblock then
            tinsertvalues(pre_ast, a_ast.preblock)
          end
          tinsertvalues(pre_ast, +{block:
            local -{val_id_ast} = -{a_ast}
            if -{val_id_ast} then
              -{ {b_ast.preblock} }; -- note: preblock may be nil
              -{ `Set{{val_id_ast}, {b_ast}} }
            end
          })
          ast.preblock = pre_ast
        end
      end
    end

    -- hand pending statements up to the enclosing node
    if ast.preblock then
      addpreblock(parent_ast, ast.preblock)
    end
  end
  function cfg.stat.up(ast, parent_ast,debug1)
    assert(parent_ast)

    -- cleanup
    if ast.inline_flatten then flatten_ast(ast) end

    -- inline statement calls
    if ast.tag == 'Call' then
      local f_ast, id_ast = is_inlinable_function(ast[1])
      if f_ast then
        local nmin, nmax = 0,0   -- a call statement uses no values
        local id = id_ast and id_ast[1]
        local res_ast = apply_func(f_ast, {unpack(ast, 2)},
                                   nmin,nmax,gensym,id)
        assert(#res_ast == 0)
        local block_ast = {}
        tinsertvalues(block_ast, res_ast.preblock or {})
        toverride(ast, block_ast)
      end
    end

    -- expand preblocks
    if ast.preblock and not ast.ignore then
      if ast.tag == 'Set' or ast.tag == 'Local' or ast.tag == 'Call' or
         ast.tag == 'Return' -- IMPROVE? tail call?
      then
        -- careful: modifying AST during traversal
        local idx = tindex(parent_ast, ast)
        -- insert preblock statements before current statement
        local new_ast = {}
        tinsertvalues(new_ast, ast.preblock or {})
        new_ast[#new_ast+1] = ast
        parent_ast[idx] = new_ast -- safer than insert
        ast.ignore = true
      elseif ast.tag == 'While' then
        -- while x do y end  -->  while 1 do if not x then break end y end
        local cond_ast = ast[1]
        ast[1] = `Number 1
        local block_ast = ast[2]
        table.insert(block_ast, 1, {
          ast.preblock or {},
          `If{`Op{'not', cond_ast}, {`Break} }
        })
      elseif ast.tag == 'Repeat' then
        local block_ast, cond_ast = ast[1], ast[2]
        table.insert(block_ast, cond_ast.preblock or {})
      elseif ast.tag == 'If' then
        -- Each condition carrying a preblock is rewritten so that the
        -- preblock runs right before that condition is tested:
        --   if c1 then b1 elseif c2 then b2 ... end
        -- becomes
        --   pre1; if c1 then b1 else pre2; if c2 then b2 ... end end
        local cur_ast = ast
        repeat
          local done = true
          if cur_ast[1].preblock then
            toverride(cur_ast,
                      {cur_ast[1].preblock, table.shallow_copy(cur_ast)})
            -- Continue with the copy of the "if" just wrapped.  (This used
            -- to read ast[2], which is only the same node on the outermost
            -- level; on nested levels it abandoned the remaining elseif
            -- conditions and their preblocks.)
            cur_ast = cur_ast[2]
          end
          for i=3,#cur_ast,2 do
            if cur_ast[i].preblock then
              -- turn "elseif c ..." into "else if c ... end"
              toverride(cur_ast[i],
                `If{table.shallow_copy(cur_ast[i]), unpack(cur_ast, i+1)})
              for j=i+1,#cur_ast do cur_ast[j] = nil end
              cur_ast = cur_ast[i]
              done = false
              break
            end
          end
        until done
      elseif ast.tag == 'Fornum' then
        -- the bounds are evaluated once, so their preblocks simply go
        -- before the loop
        local block_ast = {}
        tinsertvalues(block_ast, ast[2].preblock or {})
        tinsertvalues(block_ast, ast[3].preblock or {})
        tinsertvalues(block_ast, #ast == 5 and ast[4].preblock or {})
        if #block_ast > 0 then
          table.insert(block_ast, table.shallow_copy(ast))
          toverride(ast, block_ast)
        end
      elseif ast.tag == 'Forin' then
        local block_ast = {}
        local it_ast = ast[2]
        tinsertvalues(block_ast, it_ast[1].preblock or {})
        tinsertvalues(block_ast, #it_ast >= 2 and it_ast[2].preblock or {})
        tinsertvalues(block_ast, #it_ast == 3 and it_ast[3].preblock or {})
        tinsertvalues(block_ast, it_ast.preblock_extra or {})
        if #block_ast > 0 then
          table.insert(block_ast, table.shallow_copy(ast))
          toverride(ast, block_ast)
        end
      else
        assert(false, ast.tag)
      end
    end
  end

  -- inline functions
  walk_id.guess(cfg, ast)

  -- remove definitions of functions that are fully inlined
  eliminate_dead_functions(ast)

  return ast
end
M.inline_functions = inline_functions

-- Same as inline_functions but takes and returns a code string.
local function inline_functions_code(code)
  local ast = string_to_ast(code)
  -- DEBUG('before:', ast)
  inline_functions(ast)
  -- DEBUG('after:', ast)
  code = ast_to_string(ast)
  return code
end
M.inline_functions_code = inline_functions_code

return M