Class: Kiribi::Gemma4E2B::Text

Inherits: Base (which inherits from Object)

Defined in: lib/kiribi/gemma4_e2b/text.rb

Constant Summary

# Model artifacts that must be present in the destination directory before
# the class can be instantiated: the tokenizer plus the multi-part ONNX
# weight files for the embedding and decoder models.
FILES = [
  "tokenizer.json",
  "embed_tokens.onnx",
  "embed_tokens.onnx_data",
  "embed_tokens.onnx_data_1",
  "decoder_model_merged.onnx",
  "decoder_model_merged.onnx_data",
  "decoder_model_merged.onnx_data_1",
  "decoder_model_merged.onnx_data_2",
  "decoder_model_merged.onnx_data_3",
  "decoder_model_merged.onnx_data_4"
].freeze

# Token names that terminate greedy decoding when produced by the model.
# The "<turn|>" / "<|tool_response>" forms match the open/close tag style
# used in the chat prompt template elsewhere in this class.
EOS_TOKEN_NAMES = ["<eos>", "<turn|>", "<|tool_response>"].freeze

# Placeholder vocabulary ids whose embeddings get overwritten with
# precomputed image/audio feature vectors before the first decoder step.
IMAGE_TOKEN_ID = 258_880
AUDIO_TOKEN_ID = 258_881

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Methods inherited from Base

download

Constructor Details

#initialize(dest_dir) ⇒ Text

Returns a new instance of Text.



32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
# File 'lib/kiribi/gemma4_e2b/text.rb', line 32

# Loads the tokenizer and the embedding/decoder ONNX models from dest_dir.
#
# @param dest_dir [String] directory that must contain every entry of FILES
# @raise [Kiribi::ModelNotDownloaded] if any required file is missing
def initialize(dest_dir)
  FILES.each do |f|
    path = File.join(dest_dir, f)
    raise Kiribi::ModelNotDownloaded, %(gemma4-e2b/text: #{f} missing. Run: Kiribi.download("gemma4-e2b/text")) unless File.exist?(path)
  end
  @tokenizer = Tokenizers.from_file(File.join(dest_dir, "tokenizer.json"))
  # compact: token_to_id returns nil for names absent from the vocabulary;
  # a nil in this list would never match an Integer token id anyway, but
  # dropping it keeps the stop-set honest.
  @eos_token_ids = EOS_TOKEN_NAMES.map { @tokenizer.token_to_id(it) }.compact
  @embed_model = OnnxRuntime::Model.new(File.join(dest_dir, "embed_tokens.onnx"))

  decoder_path = File.join(dest_dir, "decoder_model_merged.onnx")
  @decoder_model = OnnxRuntime::Model.new(decoder_path)

  # Read KV-cache input metadata from the already-loaded model instead of
  # building a second InferenceSession, which would load the multi-part
  # decoder weights into memory a second time.
  @head_dims = @decoder_model.inputs
    .select { it[:name].match?(/\Apast_key_values\.\d+\.key\z/) }
    .sort_by { it[:name][/\d+/].to_i }
    .map { it[:shape].last }
  @num_layers = @head_dims.length

  # Scalar int64 tensor holding 1: ask the decoder to materialize logits
  # for only the final position of each step.
  @num_logits_to_keep_1 = OnnxRuntime::OrtValue.from_shape_and_type([], :int64)
  @num_logits_to_keep_1.data_ptr.write_int64(1)
end

Instance Attribute Details

#tokenizer ⇒ Object (readonly)

Returns the value of attribute tokenizer.



30
31
32
# File 'lib/kiribi/gemma4_e2b/text.rb', line 30

# Read-only access to the Tokenizers instance built in #initialize.
def tokenizer = @tokenizer

Class Method Details

.url_for(filename) ⇒ Object



26
27
28
# File 'lib/kiribi/gemma4_e2b/text.rb', line 26

# Resolves the download URL for a model artifact. The tokenizer lives at a
# different location than the ONNX weights, which Base resolves via super.
#
# NOTE(review): the original interpolation here was garbled in extraction
# ("#(unknown)", which is not valid Ruby interpolation); reconstructed as
# #{filename} — verify against the actual download host layout.
def self.url_for(filename)
  (filename == "tokenizer.json") ? "#{BASE_URL}/#{filename}" : super
end

Instance Method Details

#chat(messages, max_new_tokens: 256) ⇒ Object



93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# File 'lib/kiribi/gemma4_e2b/text.rb', line 93

# Runs multi-turn chat generation with greedy decoding over the decoder model.
#
# @param messages [Array<Hash>] each {role:, content:}; content is either a
#   String or an Array of parts ({type: "text", text:} /
#   {type: "image"|"audio", features:}) where :features carries one
#   precomputed embedding vector per placeholder token
# @param max_new_tokens [Integer] upper bound on generated tokens
# @return [String] the decoded model reply
def chat(messages, max_new_tokens: 256)
  prompt_parts = ["<bos>"]
  encoded_media = []

  # Render the prompt in the <|turn>role ... <turn|> template. Media parts
  # expand to one <|image|>/<|audio|> placeholder per feature vector so the
  # placeholders can be swapped for real embeddings after tokenization.
  messages.each do |msg|
    role = msg[:role]
    content = msg[:content]
    prompt_parts << "<|turn>#{role}\n"

    if content.is_a?(String)
      prompt_parts << content
    elsif content.is_a?(Array)
      content.each do |part|
        case part[:type]
        when "text"
          prompt_parts << part[:text]
        when "image"
          features = part[:features]
          prompt_parts << "<|image>" + "<|image|>" * features.length + "<image|>\n"
          encoded_media << {token_id: IMAGE_TOKEN_ID, features:}
        when "audio"
          features = part[:features]
          prompt_parts << "<|audio>" + "<|audio|>" * features.length + "<audio|>\n"
          encoded_media << {token_id: AUDIO_TOKEN_ID, features:}
        end
      end
    end

    prompt_parts << "<turn|>\n"
  end
  # Open the model's turn so generation continues from here.
  prompt_parts << "<|turn>model\n"

  input_ids = tokenizer.encode(prompt_parts.join).ids

  # Map each media feature vector to the sequence position of one placeholder
  # token. Media of the same type share a token id, so positions already
  # claimed by earlier media are rejected before assigning the next batch.
  embeds = []
  encoded_media.each do |media|
    positions = input_ids.each_with_index
      .select { |t, _| t == media[:token_id] }
      .map(&:last)
      .reject { |pos| embeds.any? { it[:pos] == pos } }
    media[:features].each_with_index do |feat, idx|
      break if idx >= positions.length
      embeds << {pos: positions[idx], feat:}
    end
  end

  past_kv = nil
  generated = []

  # Greedy decode loop: step 0 feeds the full prompt; every later step feeds
  # only the previously generated token and relies on the KV cache.
  max_new_tokens.times do |step|
    cur_ids = (step == 0) ? input_ids : [generated.last]
    seq_len = cur_ids.length
    # total_len = tokens already in the cache + the tokens fed this step.
    total_len = input_ids.length + generated.length

    embed_out = embed(cur_ids)
    inputs_embeds = embed_out["inputs_embeds"]
    per_layer_inputs = embed_out["per_layer_inputs"]

    # Overwrite placeholder embeddings with the media feature vectors; only
    # needed on the prompt pass, since placeholders live in the prompt.
    if step == 0
      embeds.each { inputs_embeds[0][it[:pos]] = it[:feat] }
    end

    result = forward(
      inputs_embeds:,
      per_layer_inputs:,
      attention_mask: [Array.new(total_len, 1)],
      # position_ids covers only the positions fed this step.
      position_ids: [(total_len - seq_len...total_len).to_a],
      past_key_values: past_kv
    )
    past_kv = result[:past_key_values]

    # num_logits_to_keep is 1, so [-1] is the single retained logits row.
    next_token = result[:logits][0][-1].each_with_index.max_by { |v, _| v }[1]
    break if @eos_token_ids.include?(next_token)
    generated << next_token
  end

  tokenizer.decode(generated)
end

#embed(input_ids) ⇒ Object



55
56
57
# File 'lib/kiribi/gemma4_e2b/text.rb', line 55

# Runs the embedding model over a single batch containing one id sequence.
# Returns the model's raw output hash (inputs_embeds / per_layer_inputs).
def embed(input_ids)
  batch = [input_ids]
  @embed_model.predict("input_ids" => batch)
end

#forward(inputs_embeds:, per_layer_inputs:, attention_mask:, position_ids:, past_key_values: nil) ⇒ Object



59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# File 'lib/kiribi/gemma4_e2b/text.rb', line 59

# Executes one decoder step and repackages the "present.*" outputs as the
# "past_key_values.*" inputs expected by the next step.
#
# @param past_key_values [Hash, nil] previous KV cache; nil starts from an
#   empty cache via #init_kv_cache
# @return [Hash] {logits:, past_key_values:}
def forward(inputs_embeds:, per_layer_inputs:, attention_mask:, position_ids:, past_key_values: nil)
  cache = past_key_values || init_kv_cache
  feed = {
    "inputs_embeds" => inputs_embeds,
    "attention_mask" => attention_mask,
    "position_ids" => position_ids,
    "num_logits_to_keep" => @num_logits_to_keep_1,
    "per_layer_inputs" => per_layer_inputs
  }.merge(cache)

  out = @decoder_model.predict(feed)

  next_cache = (0...@num_layers).each_with_object({}) do |i, kv|
    kv["past_key_values.#{i}.key"] = out["present.#{i}.key"]
    kv["past_key_values.#{i}.value"] = out["present.#{i}.value"]
  end

  {logits: out["logits"], past_key_values: next_cache}
end

#generate(prompt, max_new_tokens: 256) ⇒ Object



89
90
91
# File 'lib/kiribi/gemma4_e2b/text.rb', line 89

# Convenience wrapper: runs a single-turn user prompt through #chat.
def generate(prompt, max_new_tokens: 256)
  messages = [{role: "user", content: prompt}]
  chat(messages, max_new_tokens: max_new_tokens)
end

#init_kv_cache ⇒ Object



80
81
82
83
84
85
86
87
# File 'lib/kiribi/gemma4_e2b/text.rb', line 80

# Builds an empty KV cache: one zero-length-sequence key/value OrtValue pair
# per decoder layer, shaped [1, 1, 0, head_dim] (head_dim varies per layer,
# as read from the decoder's input metadata).
def init_kv_cache
  (0...@num_layers).each_with_object({}) do |layer, cache|
    %w[key value].each do |slot|
      cache["past_key_values.#{layer}.#{slot}"] =
        OnnxRuntime::OrtValue.from_shape_and_type([1, 1, 0, @head_dims[layer]], :float)
    end
  end
end