Class: Kiribi::RuriV3Small

Inherits:
Object
Defined in:
lib/kiribi/ruri_v3_small.rb

Constant Summary

# Artifacts that must exist in the destination directory before the model can be loaded.
FILES = %w[model.onnx tokenizer.json].freeze
# Gzip-compressed tar release archive; .download expects FILES inside it below one leading directory.
URL = "https://github.com/matsudai/kiribi-externals/releases/download/sirasagi62%2Fruri-v3-30m-ONNX%2Fcdf9391/model.tar.gz"

Instance Attribute Summary

Class Method Summary

Instance Method Summary

Constructor Details

#initialize(dest_dir) ⇒ RuriV3Small

Returns a new instance of RuriV3Small.



Source lines 39–46
# File 'lib/kiribi/ruri_v3_small.rb', line 39

# Checks that every artifact listed in FILES exists under +dest_dir+, then
# loads the tokenizer and the ONNX model from that directory.
#
# @param dest_dir [String] directory holding the model files (the error
#   message points users at Kiribi.download to populate it)
# @raise [Kiribi::ModelNotDownloaded] naming the first entry of FILES that is absent
def initialize(dest_dir)
  path_for = ->(name) { File.join(dest_dir, name) }

  absent = FILES.find { |name| !File.exist?(path_for.(name)) }
  if absent
    raise Kiribi::ModelNotDownloaded, %(ruri-v3-30m: #{absent} missing. Run: Kiribi.download("ruri-v3-30m"))
  end

  @tokenizer = Tokenizers.from_file(path_for.("tokenizer.json"))
  @onnx_model = OnnxRuntime::Model.new(path_for.("model.onnx"))
end

Instance Attribute Details

#onnx_model ⇒ Object (readonly)

Returns the value of attribute onnx_model.



Source lines 37–39
# File 'lib/kiribi/ruri_v3_small.rb', line 37

# @return [Object] the model object that #initialize builds from model.onnx
def onnx_model = @onnx_model

#tokenizer ⇒ Object (readonly)

Returns the value of attribute tokenizer.



Source lines 37–39
# File 'lib/kiribi/ruri_v3_small.rb', line 37

# @return [Object] the tokenizer that #initialize loads from tokenizer.json
def tokenizer = @tokenizer

Class Method Details

.download(dest_dir, force: false) ⇒ Object



Source lines 16–35
# File 'lib/kiribi/ruri_v3_small.rb', line 16

# Downloads the release archive at URL and extracts the entries named in FILES
# into +dest_dir+, stripping the archive's leading directory component.
#
# Skips the download when every file in FILES already exists, unless +force+ is set.
# Each file is written to "<name>.part" and then renamed into place. An
# interrupted run therefore never leaves a truncated file under the final name,
# which the existence check above would otherwise accept forever.
#
# @param dest_dir [String] target directory (created if needed)
# @param force [Boolean] when true, recursively deletes +dest_dir+ first and re-downloads
# @return [nil]
# @raise [RuntimeError] if the archive does not contain every file in FILES
def self.download(dest_dir, force: false)
  return if !force && FILES.all? { |f| File.exist?(File.join(dest_dir, f)) }
  FileUtils.rm_rf(dest_dir) if force
  FileUtils.mkdir_p(dest_dir)

  # The whole archive is buffered in memory; a binary buffer keeps the bytes untouched.
  io = StringIO.new(String.new(encoding: Encoding::BINARY))
  Kiribi.http_get(URL) { |chunk| io.write(chunk) }
  io.rewind

  extracted = []
  # GzipReader.wrap closes the gzip stream when the block finishes.
  Zlib::GzipReader.wrap(io) do |gz|
    Gem::Package::TarReader.new(gz) do |tar|
      tar.each do |entry|
        next unless entry.file?
        # Drop the archive's top-level directory; drop(1) is safe even for an empty path.
        name = Pathname(entry.full_name).each_filename.drop(1).join("/")
        next unless FILES.include?(name)

        path = File.join(dest_dir, name)
        FileUtils.mkdir_p(File.dirname(path))
        tmp_path = "#{path}.part"
        File.binwrite(tmp_path, entry.read)
        File.rename(tmp_path, path)
        extracted << name
      end
    end
  end

  missing = FILES - extracted
  raise "ruri-v3-30m: downloaded archive is missing #{missing.join(", ")}" unless missing.empty?
end

Instance Method Details

#embedding(text) ⇒ Object



Source lines 48–56
# File 'lib/kiribi/ruri_v3_small.rb', line 48

# Embeds +text+ with the ONNX model as a batch of one and returns that single
# row of the "sentence_embedding" output.
#
# @param text [String] input text to embed
# @return [Object] first element of the model's "sentence_embedding" output
#   (presumably an Array of Floats; verify against the exported model)
def embedding(text)
  enc = tokenizer.encode(text)
  result = onnx_model.predict({ input_ids: [enc.ids], attention_mask: [enc.attention_mask] })
  result["sentence_embedding"].first
end

#embedding_normalized(text) ⇒ Object



Source lines 58–62
# File 'lib/kiribi/ruri_v3_small.rb', line 58

# Returns the embedding of +text+ scaled to unit L2 norm, so the dot product of
# two normalized vectors equals their cosine similarity.
#
# @param text [String] input text to embed
# @return [Array] the unit-length vector. An all-zero or empty embedding is
#   returned unchanged, because dividing by a zero norm would fill it with NaN.
def embedding_normalized(text)
  vec = embedding(text)
  norm = Math.sqrt(vec.sum { |x| x * x })
  # A zero vector has no direction; 0.0 / 0.0 would yield NaN for every element.
  return vec if norm.zero?

  vec.map { |x| x / norm }
end