Class: Kumi::Core::Analyzer::Passes::SNASTPass
- Defined in:
- lib/kumi/core/analyzer/passes/snast_pass.rb
Constant Summary
Constants inherited from PassBase
Instance Method Summary collapse
- #axes_of(n) ⇒ Object
- #dtype_of(n) ⇒ Object
- #lookup_input(fqn) ⇒ Object
- #lub_by_prefix(list) ⇒ Object
Least upper bound by prefix.
- #meta_for(node) ⇒ Object
- #node_key(n) ⇒ Object
- #prefix?(pre, full) ⇒ Boolean
- #run(errors) ⇒ Object
- #stamp!(node, axes, dtype) ⇒ Object
---------- Helpers ----------
- #visit_call(n) ⇒ Object
- #visit_const(n) ⇒ Object
---------- Leaves ----------
- #visit_declaration(d) ⇒ Object
- #visit_hash(n) ⇒ Object
- #visit_import_call(n) ⇒ Object
- #visit_index_ref(n) ⇒ Object
- #visit_input_ref(n) ⇒ Object
- #visit_module(mod) ⇒ Object
---------- Visitor entry points ----------
- #visit_pair(n) ⇒ Object
- #visit_ref(n) ⇒ Object
- #visit_tuple(n) ⇒ Object
Methods inherited from PassBase
contract_declared?, #debug, #debug_enabled?, declared_optional_reads, declared_reads, declared_writes, #initialize, optional_reads, reads, writes
Methods included from ErrorReporting
#inferred_location, #raise_localized_error, #raise_syntax_error, #raise_type_error, #report_enhanced_error, #report_error, #report_semantic_error, #report_syntax_error, #report_type_error
Constructor Details
This class inherits a constructor from Kumi::Core::Analyzer::Passes::PassBase
Instance Method Details
#axes_of(n) ⇒ Object
189 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 189 def axes_of(n) = Array(n.meta[:stamp]&.dig(:axes)) |
#dtype_of(n) ⇒ Object
190 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 190 def dtype_of(n) = n.meta[:stamp]&.dig(:dtype) |
#lookup_input(fqn) ⇒ Object
207 208 209 210 211 212 213 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 207 def lookup_input(fqn) if @input_table.respond_to?(:find) @input_table.find { |x| x[:path_fqn] == fqn } || raise("Input not found for #{fqn}") else @input_table.fetch(fqn) { raise("Input not found for #{fqn}") } end end |
#lub_by_prefix(list) ⇒ Object
Least upper bound by prefix. All entries must be a prefix of the longest.
193 194 195 196 197 198 199 200 201 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 193 def lub_by_prefix(list) return [] if list.empty? cand = list.max_by(&:length) || [] list.each do |ax| raise Kumi::Core::Errors::CompilerBug, "axis prefix mismatch: #{ax.inspect} vs #{cand.inspect}" unless prefix?(ax, cand) end cand end |
#meta_for(node) ⇒ Object
188 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 188 def meta_for(node) = @metadata_table.fetch(node_key(node)) |
#node_key(n) ⇒ Object
215 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 215 def node_key(n) = "#{n.class}_#{n.id}" |
#prefix?(pre, full) ⇒ Boolean
203 204 205 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 203 def prefix?(pre, full) pre.each_with_index.all? { |tok, i| full[i] == tok } end |
#run(errors) ⇒ Object
10 11 12 13 14 15 16 17 18 19 20 21 22 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 10 def run(errors) @nast_module = get_state(:nast_module, required: true) @metadata_table = get_state(:metadata_table, required: true) @declaration_table = get_state(:declaration_table, required: true) @input_table = get_state(:input_table, required: true) @index_table = get_state(:index_table, required: true) @registry = get_state(:registry, required: true) @errors = errors debug "Building SNAST from #{@nast_module.decls.size} declarations" snast_module = @nast_module.accept(self) state.with(:snast_module, snast_module.freeze) end |
#stamp!(node, axes, dtype) ⇒ Object
---------- Helpers ----------
183 184 185 186 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 183 def stamp!(node, axes, dtype) node.meta[:stamp] = { axes: Array(axes), dtype: dtype }.freeze node end |
#visit_call(n) ⇒ Object
85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 85 def visit_call(n) if @registry.select?(n.fn) c = n.args[0].accept(self) t = n.args[1].accept(self) f = n.args[2].accept(self) target_axes = lub_by_prefix([axes_of(t), axes_of(f)]) target_axes = axes_of(c) if target_axes.empty? unless prefix?(axes_of(c), target_axes) halt_pass!(@errors, "select mask axes #{axes_of(c).inspect} must prefix #{target_axes.inspect}", location: n.loc) end out = NAST::Select.new(id: n.id, cond: c, on_true: t, on_false: f, loc: n.loc, meta: n.meta.dup) return stamp!(out, target_axes, dtype_of(t)) end if @registry.reduce?(n.fn) # Reduce arity is fixed upstream; anything other than exactly 1 arg here means the IR is malformed. raise Kumi::Core::Errors::CompilerBug, "reduce #{n.fn} has #{n.args.size} args, expected 1" if n.args.size != 1 arg_node = n.args.first visited_arg = arg_node.accept(self) meta = visited_arg[:meta] arg_type = meta[:stamp][:dtype] if Kumi::Core::Types.collection?(arg_type) # --- Path for FOLD (Scalar or Vectorized) --- # The argument is semantically a tuple. Create a Fold node. # The child was already visited above; look up this call's own metadata (fold function, result scope/type). meta = meta_for(n) fold_node = NAST::Fold.new( id: n.id, fn: meta.fetch(:function).to_sym, arg: visited_arg, # The arg is the tuple/reference to the tuple loc: n.loc, meta: n.meta.dup ) # The output type is the reduced scalar type (e.g., :integer for max). # The axes are PRESERVED because a fold is an element-wise operation # on the container of tuples. return stamp!(fold_node, meta[:result_scope], meta[:result_type]) else # --- Path for REDUCE (Vectorized Arrays) --- in_axes = axes_of(visited_arg) halt_pass!(@errors, "reduce function called on a non-collection scalar: #{arg_type}", location: n.loc) if in_axes.empty?
meta = meta_for(n) out_axes = Array(meta[:result_scope]) unless prefix?(out_axes, in_axes) halt_pass!(@errors, "reduce: out axes #{out_axes.inspect} must prefix arg axes #{in_axes.inspect}", location: n.loc) end over_axes = in_axes.drop(out_axes.length) reduce_node = NAST::Reduce.new( id: n.id, fn: meta.fetch(:function).to_sym, over: over_axes, arg: visited_arg, loc: n.loc, meta: n.meta.dup ) return stamp!(reduce_node, out_axes, meta[:result_type]) end end # regular elementwise: the function id was resolved with type # awareness in NASTDimensionalAnalyzerPass and stored in metadata. args = n.args.map { _1.accept(self) } m = meta_for(n) out = n.class.new(id: n.id, fn: m.fetch(:function).to_sym, args:, opts: n.opts, loc: n.loc) stamp!(out, m[:result_scope], m[:result_type]) end |
#visit_const(n) ⇒ Object
---------- Leaves ----------
40 41 42 43 44 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 40 def visit_const(n) meta = meta_for(n) out = n.class.new(id: n.id, value: n.value, loc: n.loc) stamp!(out, [], meta[:type]) end |
#visit_declaration(d) ⇒ Object
31 32 33 34 35 36 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 31 def visit_declaration(d) meta = @declaration_table.fetch(d.name) body = d.body.accept(self) out = d.class.new(id: d.id, name: d.name, body:, loc: d.loc, meta: { kind: d.kind }) stamp!(out, meta[:result_scope], meta[:result_type]) end |
#visit_hash(n) ⇒ Object
71 72 73 74 75 76 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 71 def visit_hash(n) pairs = n.pairs.map { _1.accept(self) } m = meta_for(n) out = n.class.new(id: n.id, pairs:, loc: n.loc) stamp!(out, m[:scope], m[:type]) end |
#visit_import_call(n) ⇒ Object
166 167 168 169 170 171 172 173 174 175 176 177 178 179 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 166 def visit_import_call(n) args = n.args.map { _1.accept(self) } m = meta_for(n) out = n.class.new( id: n.id, fn_name: n.fn_name, args: args, input_mapping_keys: n.input_mapping_keys, source_module: n.source_module, loc: n.loc, meta: n.meta.dup ) stamp!(out, m[:result_scope], m[:result_type]) end |
#visit_index_ref(n) ⇒ Object
52 53 54 55 56 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 52 def visit_index_ref(n) m = meta_for(n) out = n.class.new(id: n.id, name: n.name, input_fqn: n.input_fqn, loc: n.loc) stamp!(out, m[:scope], m[:type]) end |
#visit_input_ref(n) ⇒ Object
46 47 48 49 50 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 46 def visit_input_ref(n) ent = lookup_input(n.path_fqn) out = n.class.new(id: n.id, path: n.path, loc: n.loc) stamp!(out, ent[:axes], ent[:dtype]) end |
#visit_module(mod) ⇒ Object
---------- Visitor entry points ----------
26 27 28 29 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 26 def visit_module(mod) # decls is expected to be a Hash[name => Declaration] mod.class.new(decls: mod.decls.transform_values { |d| d.accept(self) }) end |
#visit_pair(n) ⇒ Object
78 79 80 81 82 83 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 78 def visit_pair(n) value = n.value.accept(self) m = meta_for(n) out = n.class.new(id: n.id, key: n.key, value:) stamp!(out, m[:scope], m[:type]) end |
#visit_ref(n) ⇒ Object
58 59 60 61 62 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 58 def visit_ref(n) m = meta_for(n) out = n.class.new(id: n.id, name: n.name, loc: n.loc) stamp!(out, m[:result_scope], m[:result_type]) end |
#visit_tuple(n) ⇒ Object
64 65 66 67 68 69 |
# File 'lib/kumi/core/analyzer/passes/snast_pass.rb', line 64 def visit_tuple(n) args = n.args.map { _1.accept(self) } m = meta_for(n) out = n.class.new(id: n.id, args:, loc: n.loc) stamp!(out, m[:result_scope], m[:result_type]) end |