Module: Torch::Hub

Defined in:
lib/torch/hub.rb

Class Method Summary

Class Method Details

.download_url_to_file(url, dst) ⇒ Object



8
9
10
11
12
13
14
15
16
17
18
19
20
# File 'lib/torch/hub.rb', line 8

# Downloads the resource at +url+ to the local file path +dst+.
#
# The response body is first copied to a sibling temporary path
# ("<dst>.<pid>.partial") and only renamed onto +dst+ after the copy has
# finished. This means a failed or interrupted download never leaves a
# truncated file at +dst+. That matters because load_state_dict_from_url
# treats any file already present at its cache path as a complete
# checkpoint.
#
# @param url [String] an http:// or https:// URL
# @param dst [#to_str] destination file path
# @return [nil]
# @raise [RuntimeError] "Invalid URL" when +url+ is not an HTTP(S) URI;
#   errors raised by open-uri or by the file copy propagate unchanged
def download_url_to_file(url, dst)
  require "open-uri"

  uri = URI.parse(url)
  raise "Invalid URL" unless uri.is_a?(URI::HTTP) # includes https

  dst = dst.to_str
  # Same directory as dst, so the final rename does not cross filesystems;
  # the pid keeps concurrent processes from clobbering each other's temp file.
  tmp = "#{dst}.#{Process.pid}.partial"

  puts "Downloading #{url}..."
  begin
    uri.open(max_redirects: 10) do |download|
      IO.copy_stream(download, tmp)
    end
    # Only a fully written file ever appears under the final name.
    File.rename(tmp, dst)
  ensure
    # Remove leftovers when the download or the copy raised.
    File.delete(tmp) if File.exist?(tmp)
  end
  nil
end

.list(github, force_reload: false) ⇒ Object

Raises:

(NotImplementedYet) — always; this method is not implemented yet.



4
5
6
# File 'lib/torch/hub.rb', line 4

# Placeholder for listing the entrypoints of a GitHub hub repository
# (presumably meant to mirror PyTorch's torch.hub.list — TODO confirm).
#
# Not implemented: every call raises NotImplementedYet, whatever the
# arguments.
#
# @param github [String] repository reference (currently ignored)
# @param force_reload [Boolean] currently ignored
# @raise [NotImplementedYet] always
def list(github, force_reload: false)
  raise(NotImplementedYet)
end

.load_state_dict_from_url(url, model_dir: nil) ⇒ Object



22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# File 'lib/torch/hub.rb', line 22

# Loads the serialized object stored at +url+, downloading it into a
# local cache the first time it is requested.
#
# The cache directory defaults to "$TORCH_HOME/checkpoints".
# TORCH_HOME falls back to "$XDG_CACHE_HOME/torch", and then to
# "$HOME/.cache/torch". The cached file is named after the last segment
# of the URL path, so the query string and fragment are ignored. If a
# file with that name already exists, it is loaded without re-downloading.
#
# @param url [String] HTTP(S) URL of the serialized state dict
# @param model_dir [String, nil] cache directory; created if missing
# @return [Object] whatever Torch.load returns for the cached file
# @raise [ArgumentError] if no file name can be derived from the URL path
def load_state_dict_from_url(url, model_dir: nil)
  # Derive the cache file name before touching the filesystem, so a bad
  # URL fails fast without creating directories.
  url_path = URI(url).path.to_s
  filename = File.basename(url_path)
  # These paths do not name a file: an empty path, a trailing "/", or a
  # "."/".." segment. File.basename would return "", "/", a parent
  # segment, or a directory reference. cached_file could then point at a
  # directory, which File.exist? accepts and which would be passed to
  # Torch.load.
  if url_path.empty? || url_path.end_with?("/") || %w[. ..].include?(filename)
    raise ArgumentError, "Cannot determine a file name from URL: #{url}"
  end

  unless model_dir
    torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
    model_dir = File.join(torch_home, "checkpoints")
  end
  FileUtils.mkdir_p(model_dir)

  cached_file = File.join(model_dir, filename)
  # TODO support hash_prefix
  download_url_to_file(url, cached_file) unless File.exist?(cached_file)

  Torch.load(cached_file)
end