Class: Ignis::AI::TextGenerator
- Inherits: Object
  - Object
    - Ignis::AI::TextGenerator
- Defined in: lib/nnw/ai/inference.rb
Overview
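TextGenerator: autoregressive inference with a KV cache.
Wraps a Transformer model and its tokenizer and generates text token by token with top-k/top-p sampling. Each decoded token can be streamed to a caller-supplied block, and is published on the EventBus as :token_generated when Ignis::Shared::EventBus is defined.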
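The prefill/decode control flow described above can be illustrated without Ignis at all. The following is a minimal sketch under stated assumptions: ToyCache, ToyModel and toy_generate are hypothetical stand-ins invented for illustration, the toy model simply predicts (last token + 1) modulo the vocabulary size, and the pick is greedy rather than top-k/top-p.

```ruby
# Hypothetical stand-ins, not Ignis classes. The cache only records the
# tokens it has seen; a real KV cache would store per-layer keys/values.
class ToyCache
  attr_reader :tokens

  def initialize(capacity)
    @capacity = capacity
    @tokens = []
  end

  def full?
    @tokens.length >= @capacity
  end
end

class ToyModel
  VOCAB = 5

  # Appends the token to the cache and returns logits for the next token.
  def decode_step(token_id, cache)
    cache.tokens << token_id
    logits = Array.new(VOCAB, 0.0)
    logits[(token_id + 1) % VOCAB] = 1.0
    logits
  end
end

def toy_generate(model, prompt_ids, max_tokens:, capacity: 16)
  cache = ToyCache.new(capacity)
  last_logits = nil

  # Prefill: one decode_step per prompt token.
  prompt_ids.each do |tid|
    break if cache.full?
    last_logits = model.decode_step(tid, cache)
  end

  generated = []
  max_tokens.times do
    break if last_logits.nil? || cache.full?
    next_id = last_logits.each_with_index.max_by { |l, _| l }[1] # greedy pick
    generated << next_id
    yield next_id if block_given? # streaming hook
    last_logits = model.decode_step(next_id, cache)
  end
  generated
end

toy_generate(ToyModel.new, [0, 1], max_tokens: 4) # => [2, 3, 4, 0]
```

Each generated token costs one decode_step against the cache instead of a fresh forward pass over the whole sequence, and generation stops early once the cache reports that it is full.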
Instance Attribute Summary
- #model ⇒ Transformer::Model readonly
- #tokenizer ⇒ Tokenizer readonly
Instance Method Summary
- #embed(text) ⇒ Array<Float>
  Compute embeddings for text (useful for similarity/search).
- #generate(prompt, max_tokens: 128, temperature: 0.7, top_k: 50, top_p: 0.9, stop_tokens: [], stream: false) {|String| ... } ⇒ String
  Generate text from a prompt.
- #initialize(model, tokenizer) ⇒ TextGenerator (constructor)
  A new instance of TextGenerator.
Constructor Details
#initialize(model, tokenizer) ⇒ TextGenerator
Returns a new instance of TextGenerator.
    # File 'lib/nnw/ai/inference.rb', line 18

    def initialize(model, tokenizer)
      @model = model
      @tokenizer = tokenizer
      @model.eval!
    end
Instance Attribute Details
#model ⇒ Transformer::Model (readonly)
    # File 'lib/nnw/ai/inference.rb', line 11

    def model
      @model
    end
#tokenizer ⇒ Tokenizer (readonly)
    # File 'lib/nnw/ai/inference.rb', line 14

    def tokenizer
      @tokenizer
    end
Instance Method Details
#embed(text) ⇒ Array<Float>
Compute embeddings for text (useful for similarity/search).
    # File 'lib/nnw/ai/inference.rb', line 83

    def embed(text)
      Tape.no_grad do
        input_ids = @tokenizer.encode(text)
        input_tensor = Tensor.from_host(input_ids, shape: [1, input_ids.length], dtype: :int32, device_id: 0)
        # Get token embeddings, mean pool
        tok_emb = @model.instance_variable_get(:@token_embedding).call(input_tensor)
        tok_emb.mean.to_host
      end
    end
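Because #embed is documented as returning an Array<Float>, two embeddings can be compared with cosine similarity for similarity or search use cases. A minimal sketch in plain Ruby follows; cosine_similarity is a hypothetical helper that is not part of Ignis, and the vectors are made-up stand-ins for the results of generator.embed(text).

```ruby
# Hypothetical helper, not part of Ignis: cosine similarity between two
# equal-length vectors such as those returned by TextGenerator#embed.
def cosine_similarity(a, b)
  raise ArgumentError, "vectors must have the same length" unless a.length == b.length

  dot = a.zip(b).sum { |x, y| x * y }
  norm_a = Math.sqrt(a.sum { |x| x * x })
  norm_b = Math.sqrt(b.sum { |x| x * x })
  # A zero vector has no direction; report no similarity instead of dividing by zero.
  return 0.0 if norm_a.zero? || norm_b.zero?

  dot / (norm_a * norm_b)
end

# Made-up vectors standing in for generator.embed("...") results.
query = [1.0, 0.0, 1.0]
doc_a = [2.0, 0.0, 2.0] # same direction as query
doc_b = [0.0, 3.0, 0.0] # orthogonal to query

ranked = { "doc_a" => doc_a, "doc_b" => doc_b }
         .sort_by { |_, vec| -cosine_similarity(query, vec) }
         .map(&:first)
```

Ranking candidates by descending similarity, as in the last lines, is the usual basis for a small in-memory search over precomputed embeddings.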
#generate(prompt, max_tokens: 128, temperature: 0.7, top_k: 50, top_p: 0.9, stop_tokens: [], stream: false) {|String| ... } ⇒ String
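Generate text from a prompt. Returns the decoded prompt followed by the generated continuation. When stream is true and a block is given, each decoded token is also yielded to the block as it is produced. Generation stops after max_tokens, when the KV cache is full, when the generated text contains any string in stop_tokens, or when a special token is sampled.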
    # File 'lib/nnw/ai/inference.rb', line 34

    def generate(prompt, max_tokens: 128, temperature: 0.7, top_k: 50, top_p: 0.9,
                 stop_tokens: [], stream: false, &block)
      input_ids = @tokenizer.encode(prompt)
      generated_ids = []
      cache = @model.make_kv_cache

      Tape.no_grad do
        # Prefill: stream the prompt through the cache one token at a time. Each
        # decode_step appends that token's K/V; the last token's logits predict
        # the first generated token. O(prefix) total, like one full forward.
        last_logits = nil
        input_ids.each do |tid|
          break if cache.full?
          last_logits = @model.decode_step(tid, cache).to_host
        end

        # Decode: sample from the running logits, feed the sampled token back to
        # extend the cache by one (O(prefix) per step, not O(prefix²) re-forward).
        max_tokens.times do |step|
          break if last_logits.nil? || cache.full?
          next_id = sample_token(last_logits, temperature: temperature, top_k: top_k, top_p: top_p)
          generated_ids << next_id

          token_text = @tokenizer.decode([next_id], skip_special_tokens: true)
          block.call(token_text) if stream && block

          if defined?(Ignis::Shared::EventBus)
            Ignis::Shared::EventBus.publish(:token_generated, { text: token_text, token_id: next_id, step: step })
          end

          full_text = @tokenizer.decode(generated_ids, skip_special_tokens: true)
          break if stop_tokens.any? { |s| full_text.include?(s) }
          break if @tokenizer.special_token_ids.include?(next_id) # EOS

          # Advance: logits conditioned on prompt + everything generated so far.
          last_logits = @model.decode_step(next_id, cache).to_host
        end
      end

      @tokenizer.decode(input_ids + generated_ids, skip_special_tokens: true)
    end
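The listing calls a sample_token helper whose source is not shown on this page. As a rough illustration of what temperature, top_k and top_p conventionally mean, here is a sketch in plain Ruby. It is an assumption about the general technique, not the Ignis implementation: sample_token_sketch, its rng parameter, and the choice to treat a non-positive temperature as greedy decoding are all hypothetical.

```ruby
# Hypothetical stand-in for the sample_token helper; the real implementation
# is not shown here and may differ. logits is an Array of Floats indexed by
# token id.
def sample_token_sketch(logits, temperature: 0.7, top_k: 50, top_p: 0.9, rng: Random.new)
  # Assumption: a non-positive temperature means greedy decoding (argmax).
  return logits.each_with_index.max_by { |l, _| l }[1] if temperature <= 0

  # Temperature-scaled softmax, shifted by the max logit for numerical stability.
  scaled = logits.map { |l| l / temperature }
  max = scaled.max
  exps = scaled.map { |l| Math.exp(l - max) }
  total = exps.sum
  probs = exps.each_with_index.map { |e, id| [id, e / total] }

  # Top-k: keep the k most probable tokens (at least one).
  candidates = probs.sort_by { |_, prob| -prob }.first([top_k, 1].max)

  # Top-p: keep the smallest prefix whose cumulative probability reaches top_p.
  kept = []
  cumulative = 0.0
  candidates.each do |id, prob|
    kept << [id, prob]
    cumulative += prob
    break if cumulative >= top_p
  end

  # Draw from the survivors in proportion to their (unnormalised) mass.
  threshold = rng.rand * kept.sum { |_, prob| prob }
  kept.each do |id, prob|
    threshold -= prob
    return id if threshold <= 0
  end
  kept.last[0] # guard against floating-point rounding
end
```

Lower temperature values sharpen the distribution, top_k caps how many candidates survive, and top_p trims the low-probability tail of whatever top_k left. Passing a seeded Random as rng makes the draw reproducible in tests.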