Module: Torch::Hub
- Defined in:
- lib/torch/hub.rb
Class Method Summary
- .download_url_to_file(url, dst) ⇒ Object
- .list(github, force_reload: false) ⇒ Object
- .load_state_dict_from_url(url, model_dir: nil) ⇒ Object
Class Method Details
.download_url_to_file(url, dst) ⇒ Object
Source: lib/torch/hub.rb, lines 8–20
# File 'lib/torch/hub.rb', line 8
# Downloads +url+ to the local path +dst+, printing a progress line first.
#
# The response body is streamed into a temporary file next to +dst+ and only
# renamed into place once the copy has finished. An interrupted or failed
# download therefore never leaves a truncated file at +dst+ that a later
# File.exist?(dst) check would mistake for a complete download.
#
# @param url [String] an http:// or https:// URL
# @param dst [String, #to_path] destination file path
# @return [nil]
# @raise [RuntimeError] "Invalid URL" when +url+ is not HTTP(S)
def download_url_to_file(url, dst)
  require "open-uri"

  uri = URI.parse(url)
  raise "Invalid URL" unless uri.is_a?(URI::HTTP) # includes https

  # File.path accepts Strings, Pathnames (#to_path) and #to_str objects alike.
  dst = File.path(dst)
  # Same directory as dst so File.rename stays on one filesystem; the pid keeps
  # concurrent processes downloading the same file from clobbering each other.
  tmp = "#{dst}.#{Process.pid}.partial"

  puts "Downloading #{url}..."
  begin
    uri.open(max_redirects: 10) do |download|
      IO.copy_stream(download, tmp)
    end
    File.rename(tmp, dst)
  ensure
    # Only still present if the download or rename failed.
    File.delete(tmp) if File.exist?(tmp)
  end
  nil
end
.list(github, force_reload: false) ⇒ Object
Source: lib/torch/hub.rb, lines 4–6
# File 'lib/torch/hub.rb', line 4
# Not implemented in this library yet: every call raises NotImplementedYet,
# regardless of the arguments given.
#
# @param github [Object] currently ignored
# @param force_reload [Boolean] currently ignored
# @raise [NotImplementedYet] always
def list(github, force_reload: false)
  raise(NotImplementedYet)
end
.load_state_dict_from_url(url, model_dir: nil) ⇒ Object
Source: lib/torch/hub.rb, lines 22–39
# File 'lib/torch/hub.rb', line 22
# Loads the file at +url+ with Torch.load, downloading it into a local cache
# directory the first time it is requested.
#
# When +model_dir+ is not given, the cache directory is
# $TORCH_HOME/checkpoints, falling back to $XDG_CACHE_HOME/torch/checkpoints
# and then $HOME/.cache/torch/checkpoints. The cached file is named after the
# last component of the URL path; if that file already exists it is loaded
# without downloading again.
#
# @param url [String] URL whose path ends in a file name
# @param model_dir [String, nil] cache directory; created if missing
# @return [Object] whatever Torch.load returns for the cached file
# @raise [ArgumentError] if the URL path has no file name component
def load_state_dict_from_url(url, model_dir: nil)
  # Validate before touching the filesystem. A URL such as "https://host" or
  # "https://host/" has no file component: File.basename would give "" or "/",
  # making cached_file the cache directory itself. File.exist? accepts that,
  # so the directory would be handed straight to Torch.load.
  path = URI(url).path.to_s # opaque URIs (e.g. "mailto:") have a nil path
  filename = File.basename(path)
  if path.empty? || path.end_with?("/") || %w[. ..].include?(filename)
    raise ArgumentError, "URL has no file name: #{url}"
  end

  unless model_dir
    torch_home = ENV["TORCH_HOME"] || "#{ENV["XDG_CACHE_HOME"] || "#{ENV["HOME"]}/.cache"}/torch"
    model_dir = File.join(torch_home, "checkpoints")
  end
  FileUtils.mkdir_p(model_dir)

  cached_file = File.join(model_dir, filename)
  unless File.exist?(cached_file)
    # TODO support hash_prefix
    download_url_to_file(url, cached_file)
  end
  Torch.load(cached_file)
end