Class: Ignis::AI::TextGenerator

Inherits:
Object
  • Object
show all
Defined in:
lib/nnw/ai/inference.rb

Overview

TextGenerator — autoregressive inference with KV cache.

Loads a Transformer model + tokenizer, generates text token-by-token with top-k/top-p sampling and streaming support via EventBus.

Instance Attribute Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(model, tokenizer) ⇒ TextGenerator

Returns a new instance of TextGenerator.

Parameters:



18
19
20
21
22
# File 'lib/nnw/ai/inference.rb', line 18

# Wraps an already-constructed model and tokenizer for text generation.
#
# The model receives +eval!+ right away, before any generation call.
# NOTE(review): presumably this switches the model to inference mode — confirm
# against the model class.
#
# @param model [Transformer::Model] model to run; must respond to +eval!+
# @param tokenizer [Tokenizer] tokenizer paired with the model
def initialize(model, tokenizer)
  @tokenizer = tokenizer
  @model = model
  model.eval!
end

Instance Attribute Details

#model ⇒ Transformer::Model (readonly)

Returns:



11
12
13
# File 'lib/nnw/ai/inference.rb', line 11

# Read-only accessor for the model this generator runs.
#
# @return [Transformer::Model] the value held in @model
def model
  @model
end

#tokenizer ⇒ Tokenizer (readonly)

Returns:



14
15
16
# File 'lib/nnw/ai/inference.rb', line 14

# Read-only accessor for the tokenizer used to encode prompts and decode ids.
#
# @return [Tokenizer] the value held in @tokenizer
def tokenizer
  @tokenizer
end

Instance Method Details

#embed(text) ⇒ Array<Float>

Compute embeddings for text (useful for similarity/search).

Parameters:

  • text (String)

Returns:

  • (Array<Float>)

    embedding vector



83
84
85
86
87
88
89
90
91
92
# File 'lib/nnw/ai/inference.rb', line 83

# Embeds +text+ by passing its token ids through the model's token-embedding
# layer and reducing the result with +mean+. Runs inside +Tape.no_grad+.
#
# @param text [String] text to embed
# @return [Array<Float>] host copy of the pooled embedding
def embed(text)
  Tape.no_grad do
    ids = @tokenizer.encode(text)
    # Single-row batch of int32 ids, placed on device 0.
    batch = Tensor.from_host(ids, shape: [1, ids.length],
                             dtype: :int32, device_id: 0)
    # NOTE(review): reads the model's private @token_embedding ivar directly;
    # a public accessor on the model would be safer if one exists.
    embedding_layer = @model.instance_variable_get(:@token_embedding)
    token_embeddings = embedding_layer.call(batch)
    # NOTE(review): +mean+ is called without an axis — confirm it pools over the
    # sequence dimension (yielding a vector) rather than over every element.
    token_embeddings.mean.to_host
  end
end

#generate(prompt, max_tokens: 128, temperature: 0.7, top_k: 50, top_p: 0.9, stop_tokens: [], stream: false) {|String| ... } ⇒ String

Generate text from a prompt.

Parameters:

  • prompt (String)

    input text

  • max_tokens (Integer) (defaults to: 128)

    maximum new tokens to generate

  • temperature (Float) (defaults to: 0.7)

    sampling temperature (0.0 = greedy)

  • top_k (Integer) (defaults to: 50)

    top-k sampling (0 = disabled)

  • top_p (Float) (defaults to: 0.9)

    nucleus sampling threshold (1.0 = disabled)

  • stop_tokens (Array<String>) (defaults to: [])

    strings that trigger early stop

  • stream (Boolean) (defaults to: false)

    yield tokens as they are generated

Yields:

  • (String)

    each generated token (when stream: true)

Returns:

  • (String)

    full generated text



34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# File 'lib/nnw/ai/inference.rb', line 34

# Generates a continuation of +prompt+ one token at a time using a KV cache.
#
# The prompt ids are first pushed through +decode_step+ one by one (prefill).
# Each following step samples a token from the latest logits, reports it, and
# feeds it back to obtain the next logits. Generation ends when +max_tokens+
# tokens were produced, the cache reports +full?+, a stop string appears in
# the decoded generated text, or a special token id is sampled.
#
# @param prompt [String] input text
# @param max_tokens [Integer] maximum number of new tokens to generate
# @param temperature [Float] sampling temperature, forwarded to +sample_token+
#   (0.0 = greedy)
# @param top_k [Integer] top-k cutoff, forwarded to +sample_token+ (0 = disabled)
# @param top_p [Float] nucleus threshold, forwarded to +sample_token+
#   (1.0 = disabled)
# @param stop_tokens [Array<String>] strings that end generation once they
#   appear in the generated text; empty strings are ignored
# @param stream [Boolean] when true and a block is given, yield each token's text
# @yield [String] decoded text of each newly sampled token
# @return [String] decoding of the prompt ids followed by the generated ids
def generate(prompt, max_tokens: 128, temperature: 0.7,
             top_k: 50, top_p: 0.9, stop_tokens: [], stream: false, &block)
  input_ids = @tokenizer.encode(prompt)
  generated_ids = []
  cache = @model.make_kv_cache

  # "" is a substring of every string, so an empty stop string would end
  # generation right after the first token; drop such entries.
  stop_strings = stop_tokens.reject(&:empty?)
  # Loop-invariant lookups, hoisted out of the decode loop.
  special_ids = @tokenizer.special_token_ids
  final_step = max_tokens - 1

  Tape.no_grad do
    # Prefill: feed the prompt through the cache token by token; the logits
    # returned for the last prompt token are used to pick the first new token.
    # NOTE(review): if the cache fills up before the prompt is consumed, the
    # rest of the prompt is skipped and (as long as full? stays true) nothing
    # is generated — confirm this silent truncation is intended.
    last_logits = nil
    input_ids.each do |tid|
      break if cache.full?
      last_logits = @model.decode_step(tid, cache).to_host
    end

    max_tokens.times do |step|
      # No logits means an empty prompt; a full cache cannot take more tokens.
      break if last_logits.nil? || cache.full?

      next_id = sample_token(last_logits, temperature: temperature,
                             top_k: top_k, top_p: top_p)
      generated_ids << next_id

      token_text = @tokenizer.decode([next_id], skip_special_tokens: true)
      block.call(token_text) if stream && block

      if defined?(Ignis::Shared::EventBus)
        Ignis::Shared::EventBus.publish(:token_generated, {
          text: token_text, token_id: next_id, step: step
        })
      end

      # Stop strings are matched against the decoding of everything generated
      # so far; the decode is skipped entirely when there is nothing to match.
      unless stop_strings.empty?
        full_text = @tokenizer.decode(generated_ids, skip_special_tokens: true)
        break if stop_strings.any? { |s| full_text.include?(s) }
      end

      # Any special token id (not only EOS) ends generation.
      break if special_ids.include?(next_id)

      # The logits of another forward step would never be read once the last
      # permitted token has been sampled, so skip that step.
      break if step == final_step

      # Advance: feed the sampled token back to get the next logits.
      last_logits = @model.decode_step(next_id, cache).to_host
    end
  end

  @tokenizer.decode(input_ids + generated_ids, skip_special_tokens: true)
end