Class: Kiribi::MultilingualE5Small

Inherits:
Object
  • Object
show all
Defined in:
lib/kiribi/multilingual_e5_small.rb

Constant Summary collapse

# Name of the ONNX model file inside the release archive (see URL).
ONNX_FILE = "model_qint8_avx512_vnni.onnx"
# Files that must be present in the model directory: .download extracts exactly
# these and #initialize refuses to load unless all of them exist.
FILES = [ONNX_FILE, "tokenizer.json"].freeze
# Release tarball fetched by .download; it contains the files listed in FILES.
URL = "https://github.com/matsudai/kiribi-externals/releases/download/intfloat%2Fmultilingual-e5-small%2Fc007d7e/model_qint8_avx512_vnni.tar.gz"

Instance Attribute Summary collapse

Class Method Summary collapse

Instance Method Summary collapse

Constructor Details

#initialize(dest_dir) ⇒ MultilingualE5Small

Returns a new instance of MultilingualE5Small.



40
41
42
43
44
45
46
47
# File 'lib/kiribi/multilingual_e5_small.rb', line 40

# Loads the tokenizer and the ONNX model from +dest_dir+.
#
# @param dest_dir [String] directory previously populated by .download
# @raise [Kiribi::ModelNotDownloaded] naming the first entry of FILES that is
#   absent from +dest_dir+
def initialize(dest_dir)
  absent = FILES.find { |name| !File.exist?(File.join(dest_dir, name)) }
  if absent
    raise Kiribi::ModelNotDownloaded,
          %(multilingual-e5-small: #{absent} missing. Run: Kiribi.download("multilingual-e5-small"))
  end

  @tokenizer = Tokenizers.from_file(File.join(dest_dir, "tokenizer.json"))
  @onnx_model = OnnxRuntime::Model.new(File.join(dest_dir, ONNX_FILE))
end

Instance Attribute Details

#onnx_model ⇒ Object (readonly)

Returns the value of attribute onnx_model.



38
39
40
# File 'lib/kiribi/multilingual_e5_small.rb', line 38

# @return [Object] the OnnxRuntime::Model loaded by #initialize (nil if never set)
def onnx_model = @onnx_model

#tokenizer ⇒ Object (readonly)

Returns the value of attribute tokenizer.



38
39
40
# File 'lib/kiribi/multilingual_e5_small.rb', line 38

# @return [Object] the tokenizer loaded from tokenizer.json by #initialize (nil if never set)
def tokenizer = @tokenizer

Class Method Details

.download(dest_dir, force: false) ⇒ Object



17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
# File 'lib/kiribi/multilingual_e5_small.rb', line 17

# Downloads the release archive at URL and extracts the files listed in FILES
# into +dest_dir+.
#
# The archive is buffered in memory. Each wanted tar entry is matched after
# stripping the archive's leading directory component. It is written to a
# temporary "<name>.part" file and then renamed into place. As a result, an
# interrupted run never leaves a truncated file that a later non-forced call
# would mistake for a complete download.
#
# @param dest_dir [String] directory that receives the model files
# @param force [Boolean] when true, delete +dest_dir+ and download again even
#   if every file already exists
# @return [nil]
# @raise [Kiribi::ModelNotDownloaded] if the archive did not contain every
#   file in FILES
def self.download(dest_dir, force: false)
  return if !force && FILES.all? { |f| File.exist?(File.join(dest_dir, f)) }

  FileUtils.rm_rf(dest_dir) if force
  FileUtils.mkdir_p(dest_dir)

  archive = StringIO.new.binmode
  Kiribi.http_get(URL) { |chunk| archive.write(chunk) }
  archive.rewind

  extracted = []
  # wrap closes the GzipReader (and the buffer) when the block ends;
  # TarReader#close does not close the underlying stream.
  Zlib::GzipReader.wrap(archive) do |gz|
    Gem::Package::TarReader.new(gz) do |tar|
      tar.each do |entry|
        next unless entry.file?

        # Drop the archive's top-level directory; drop(1) (unlike [1..]) never yields nil.
        name = Pathname(entry.full_name).each_filename.drop(1).join("/")
        next unless FILES.include?(name)

        path = File.join(dest_dir, name)
        FileUtils.mkdir_p(File.dirname(path))
        part = "#{path}.part"
        begin
          File.binwrite(part, entry.read)
          File.rename(part, path) # only complete files ever appear under their final name
        ensure
          FileUtils.rm_f(part) # no-op after a successful rename
        end
        extracted << name
      end
    end
  end

  missing = FILES - extracted
  return if missing.empty?

  raise Kiribi::ModelNotDownloaded,
        "multilingual-e5-small: #{missing.join(", ")} missing from downloaded archive #{URL}"
end

Instance Method Details

#embedding(prefix, input) ⇒ Object

Raises:

  • (ArgumentError) — if prefix is not :query or :passage (given as a Symbol or String)


57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# File 'lib/kiribi/multilingual_e5_small.rb', line 57

# Computes an embedding for +input+ by mean-pooling the model's
# "last_hidden_state" over the tokens whose attention-mask entry is 1.
#
# The text is sent to the tokenizer as "query: <input>" or
# "passage: <input>". The resulting vector is not L2-normalized.
#
# @param prefix [Symbol, String] :query or :passage (compared via #to_s)
# @param input [#to_s] text to embed; it is interpolated into the prompt
# @return [Array] one pooled value per hidden dimension — assumes the ONNX
#   output is nested Ruby Arrays of Floats; TODO confirm
# @raise [ArgumentError] if +prefix+ is neither query nor passage
def embedding(prefix, input)
  kind = prefix.to_s
  unless %w[query passage].include?(kind)
    raise ArgumentError, "prefix must be :query or :passage"
  end

  encoding = tokenizer.encode("#{kind}: #{input}")
  ids = encoding.ids
  mask = encoding.attention_mask

  # Passed as an explicit Hash so it lands in predict's positional argument.
  feed = {
    input_ids: [ids],
    attention_mask: [mask],
    # single-segment input: every token is in segment 0
    token_type_ids: [Array.new(ids.length, 0)]
  }
  hidden_states = onnx_model.predict(feed)["last_hidden_state"][0]

  # Keep only rows for positions whose mask entry is 1 (padding rows are dropped).
  kept_rows = hidden_states.zip(mask).filter_map { |row, flag| row if flag == 1 }
  token_count = mask.sum
  kept_rows.transpose.map { |column| column.sum / token_count }
end

#embedding_passage(input) ⇒ Object



53
54
55
# File 'lib/kiribi/multilingual_e5_small.rb', line 53

# Embeds +input+ as a document ("passage: " prefix); see #embedding.
def embedding_passage(input) = embedding(:passage, input)

#embedding_query(input) ⇒ Object



49
50
51
# File 'lib/kiribi/multilingual_e5_small.rb', line 49

# Embeds +input+ as a search query ("query: " prefix); see #embedding.
def embedding_query(input) = embedding(:query, input)