Module: Ignis::Collective::NvArrayAdapter
- Defined in: lib/nvruby/collective/nvarray_adapter.rb
Overview
Adapter that integrates NvArray with collective operations. Handles dtype detection, shape validation, and buffer extraction.
Constant Summary
- DTYPE_SIZES =
Supported dtypes and their byte sizes
{ float32: 4, float64: 8, float16: 2, bfloat16: 2, int32: 4, int64: 8, int16: 2, int8: 1, uint8: 1, uint32: 4, uint64: 8 }.freeze
- DTYPE_CUDA_CODES =
CUDA dtype codes for kernel dispatch
{ float32: 0, float64: 1, float16: 2, bfloat16: 3, int32: 4, int64: 5, int16: 6, int8: 7, uint8: 8, uint32: 9, uint64: 10 }.freeze
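Both constant tables cover the same eleven dtypes, so a dtype's byte size and its kernel code can be looked up together. The standalone sketch below copies the two tables (it does not load the real module), checks that their keys agree, and computes the buffer size of 1,024 bfloat16 elements as element count times bytes per element.

```ruby
# Standalone copies of the constants listed above; the real module is not loaded.
DTYPE_SIZES = {
  float32: 4, float64: 8, float16: 2, bfloat16: 2,
  int32: 4, int64: 8, int16: 2, int8: 1,
  uint8: 1, uint32: 4, uint64: 8
}.freeze

DTYPE_CUDA_CODES = {
  float32: 0, float64: 1, float16: 2, bfloat16: 3,
  int32: 4, int64: 5, int16: 6, int8: 7,
  uint8: 8, uint32: 9, uint64: 10
}.freeze

# Both tables should describe exactly the same set of dtypes.
same_keys = DTYPE_SIZES.keys.sort == DTYPE_CUDA_CODES.keys.sort

# Buffer size = element count * bytes per element.
elements = 1024
bytes = elements * DTYPE_SIZES.fetch(:bfloat16)
code  = DTYPE_CUDA_CODES.fetch(:bfloat16)

puts "same keys: #{same_keys}"
puts "bfloat16 x #{elements}: #{bytes} bytes, CUDA code #{code}"
```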
Class Method Summary
- .broadcast_compatible?(src_tensor, dst_tensors) ⇒ Boolean
Validate shape compatibility for broadcast operations.
- .buffer_info(tensors) ⇒ Hash
Create buffer info for collective ops.
- .common_dtype(tensors) ⇒ Symbol
Get common dtype from tensor array.
- .common_element_count(tensors) ⇒ Integer
Get element count from tensors (must match).
- .dtype_cuda_code(dtype) ⇒ Integer
Get CUDA type code for kernel dispatch.
- .dtype_size(dtype) ⇒ Integer
Get byte size for a dtype.
- .extract_pointers(tensors) ⇒ Array<FFI::Pointer>
Extract device pointers from tensors for raw operations.
- .normalize(inputs, expected_count: nil) ⇒ Array<NvArray>
Normalize input to array of NvArrays with validated properties.
- .total_byte_size(tensors) ⇒ Integer
Calculate total byte size for tensors.
Class Method Details
.broadcast_compatible?(src_tensor, dst_tensors) ⇒ Boolean
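Validate shape compatibility for broadcast operations. Returns true only if extract_shape(t) == extract_shape(src_tensor) for every t in dst_tensors; an empty dst_tensors list returns true.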
# File 'lib/nvruby/collective/nvarray_adapter.rb', line 149
def broadcast_compatible?(src_tensor, dst_tensors)
  src_shape = extract_shape(src_tensor)
  dst_tensors.all? { |t| extract_shape(t) == src_shape }
end
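`broadcast_compatible?` compares shapes with `==`, so it demands an exact match rather than NumPy-style broadcasting: a `[4, 1]` destination is rejected for a `[4, 8]` source. The standalone sketch below reproduces the method's logic without loading the real module. `FakeTensor` is a hypothetical stand-in for NvArray, and `extract_shape` (private, not shown on this page) is assumed to return the tensor's shape array.

```ruby
# Hypothetical stand-in for NvArray: only the attribute this sketch needs.
FakeTensor = Struct.new(:shape)

module BroadcastSketch
  module_function

  # Assumption: the real private extract_shape is not shown on this page;
  # here it simply reads the tensor's shape array.
  def extract_shape(tensor)
    tensor.shape
  end

  # Same logic as the documented method: every destination must have
  # exactly the source's shape.
  def broadcast_compatible?(src_tensor, dst_tensors)
    src_shape = extract_shape(src_tensor)
    dst_tensors.all? { |t| extract_shape(t) == src_shape }
  end
end

src   = FakeTensor.new([4, 8])
ok    = BroadcastSketch.broadcast_compatible?(src, [FakeTensor.new([4, 8]), FakeTensor.new([4, 8])])
bad   = BroadcastSketch.broadcast_compatible?(src, [FakeTensor.new([4, 1])])
empty = BroadcastSketch.broadcast_compatible?(src, []) # vacuously true via Enumerable#all?

puts ok, bad, empty
```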
.buffer_info(tensors) ⇒ Hash
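Create buffer info for collective ops: a Hash with :pointers, :dtype, :dtype_code, :element_count, :byte_size (bytes in one tensor, not the total across tensors), and :tensor_count. Raises ArgumentError if the tensors' dtypes or element counts differ.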
# File 'lib/nvruby/collective/nvarray_adapter.rb', line 158
def buffer_info(tensors)
  {
    pointers: extract_pointers(tensors),
    dtype: common_dtype(tensors),
    dtype_code: dtype_cuda_code(common_dtype(tensors)),
    element_count: common_element_count(tensors),
    byte_size: common_element_count(tensors) * dtype_size(common_dtype(tensors)),
    tensor_count: tensors.size
  }
end
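In the hash returned by `buffer_info`, `:byte_size` is `element_count * dtype_size(dtype)`, the size of a single tensor's buffer; multiply by `:tensor_count` for the size of all buffers together. The standalone sketch below mirrors the method's logic without loading the real module. `FakeTensor`, the integer "pointers", and the element-count rule (product of the shape) are assumptions, since the private helpers are not shown on this page. One stylistic difference: the documented method calls `common_dtype` three times and `common_element_count` twice, whereas the sketch computes each once and builds a hash with the same keys.

```ruby
# Hypothetical stand-in for NvArray; data_ptr is a plain Integer, not a device pointer.
FakeTensor = Struct.new(:dtype, :shape, :data_ptr)

module BufferInfoSketch
  # Subsets of the real tables, enough for this example.
  DTYPE_SIZES = { float32: 4, float16: 2, int64: 8 }.freeze
  DTYPE_CUDA_CODES = { float32: 0, float16: 2, int64: 5 }.freeze

  module_function

  # Assumption: the element count is the product of the shape dimensions.
  def element_count(tensor)
    tensor.shape.reduce(1, :*)
  end

  def buffer_info(tensors)
    dtypes = tensors.map(&:dtype).uniq
    raise ArgumentError, "mixed dtypes: #{dtypes.join(', ')}" if dtypes.size > 1

    counts = tensors.map { |t| element_count(t) }.uniq
    raise ArgumentError, "mixed element counts: #{counts.join(', ')}" if counts.size > 1

    dtype = dtypes.first || :float32 # empty-input default, as in common_dtype
    count = counts.first || 0        # empty-input default, as in common_element_count
    {
      pointers: tensors.map(&:data_ptr),
      dtype: dtype,
      dtype_code: DTYPE_CUDA_CODES.fetch(dtype, 0),
      element_count: count,
      byte_size: count * DTYPE_SIZES.fetch(dtype), # bytes in ONE tensor
      tensor_count: tensors.size
    }
  end
end

# Four float16 tensors of 1,024 elements each.
tensors = (0...4).map { |i| FakeTensor.new(:float16, [1024], 0x1000 * (i + 1)) }
info = BufferInfoSketch.buffer_info(tensors)

puts info[:byte_size]                       # bytes per tensor
puts info[:byte_size] * info[:tensor_count] # bytes across all tensors
```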
.common_dtype(tensors) ⇒ Symbol
Return the dtype shared by all tensors. Returns :float32 for an empty array and raises ArgumentError if the dtypes differ.
# File 'lib/nvruby/collective/nvarray_adapter.rb', line 89
def common_dtype(tensors)
  return :float32 if tensors.empty?

  dtypes = tensors.map { |t| extract_dtype(t) }.uniq
  if dtypes.size > 1
    raise ArgumentError,
          "All tensors must have same dtype, got: #{dtypes.join(', ')}"
  end

  dtypes.first
end
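`common_dtype([])` returns `:float32` rather than raising, so a caller that must reject empty input has to check `tensors.empty?` itself. The standalone sketch below copies the method's logic without loading the real module; `FakeTensor` is a hypothetical stand-in, and `extract_dtype` (private, not shown on this page) is assumed to return a Symbol such as `:float32`.

```ruby
# Hypothetical stand-in for NvArray.
FakeTensor = Struct.new(:dtype)

module DtypeSketch
  module_function

  # Assumption: the real private extract_dtype returns a Symbol like :float32.
  def extract_dtype(tensor)
    tensor.dtype
  end

  def common_dtype(tensors)
    return :float32 if tensors.empty? # default used by the documented method

    dtypes = tensors.map { |t| extract_dtype(t) }.uniq
    if dtypes.size > 1
      raise ArgumentError, "All tensors must have same dtype, got: #{dtypes.join(', ')}"
    end

    dtypes.first
  end
end

same  = DtypeSketch.common_dtype([FakeTensor.new(:int32), FakeTensor.new(:int32)])
empty = DtypeSketch.common_dtype([])
message =
  begin
    DtypeSketch.common_dtype([FakeTensor.new(:float32), FakeTensor.new(:float16)])
  rescue ArgumentError => e
    e.message
  end

puts same, empty, message
```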
.common_element_count(tensors) ⇒ Integer
Return the element count shared by all tensors. Returns 0 for an empty array and raises ArgumentError if the counts differ.
# File 'lib/nvruby/collective/nvarray_adapter.rb', line 106
def common_element_count(tensors)
  return 0 if tensors.empty?

  counts = tensors.map { |t| extract_element_count(t) }.uniq
  if counts.size > 1
    raise ArgumentError,
          "All tensors must have same element count, got: #{counts.join(', ')}"
  end

  counts.first
end
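`common_element_count` compares element counts only, not shapes. Assuming the private `extract_element_count` multiplies the shape dimensions (it is not shown on this page), tensors shaped `[2, 8]` and `[4, 4]` are accepted together, so a reshaped tensor passes as long as its total size matches. The standalone sketch below uses a hypothetical `FakeTensor` and does not load the real module.

```ruby
# Hypothetical stand-in for NvArray.
FakeTensor = Struct.new(:shape)

module CountSketch
  module_function

  # Assumption: the real private extract_element_count multiplies the shape dimensions.
  def extract_element_count(tensor)
    tensor.shape.reduce(1, :*)
  end

  def common_element_count(tensors)
    return 0 if tensors.empty?

    counts = tensors.map { |t| extract_element_count(t) }.uniq
    if counts.size > 1
      raise ArgumentError, "All tensors must have same element count, got: #{counts.join(', ')}"
    end

    counts.first
  end
end

# Different shapes, same number of elements: accepted.
reshaped = CountSketch.common_element_count([FakeTensor.new([2, 8]), FakeTensor.new([4, 4])])

mismatch =
  begin
    CountSketch.common_element_count([FakeTensor.new([3]), FakeTensor.new([5])])
  rescue ArgumentError => e
    e.message
  end

puts reshaped, mismatch
```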
.dtype_cuda_code(dtype) ⇒ Integer
Get CUDA type code for kernel dispatch. Dtypes not in DTYPE_CUDA_CODES fall back to 0, the code for :float32.
# File 'lib/nvruby/collective/nvarray_adapter.rb', line 130
def dtype_cuda_code(dtype)
  DTYPE_CUDA_CODES[dtype] || 0
end
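Because `dtype_cuda_code` falls back to `0`, and `0` is also the code for `:float32`, a misspelled dtype such as `:flaot32`, or a String such as `"float64"` (the table keys are Symbols), silently maps to the float32 code. The sketch below contrasts the documented behavior with `strict_cuda_code`, a hypothetical variant that is not part of the library and raises for unknown dtypes the way `dtype_size` does.

```ruby
DTYPE_CUDA_CODES = {
  float32: 0, float64: 1, float16: 2, bfloat16: 3,
  int32: 4, int64: 5, int16: 6, int8: 7,
  uint8: 8, uint32: 9, uint64: 10
}.freeze

# Documented behavior: unknown dtypes fall back to 0.
def lenient_cuda_code(dtype)
  DTYPE_CUDA_CODES[dtype] || 0
end

# Hypothetical stricter variant (not part of the library).
def strict_cuda_code(dtype)
  DTYPE_CUDA_CODES.fetch(dtype) { raise ArgumentError, "Unknown dtype: #{dtype}" }
end

typo_code   = lenient_cuda_code(:flaot32)  # same code as float32
string_code = lenient_cuda_code("float64") # String key misses the Symbol table

strict_error =
  begin
    strict_cuda_code(:flaot32)
  rescue ArgumentError => e
    e.message
  end

puts typo_code, string_code, strict_error
```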
.dtype_size(dtype) ⇒ Integer
Get byte size for a dtype. Raises ArgumentError for any dtype not in DTYPE_SIZES.
# File 'lib/nvruby/collective/nvarray_adapter.rb', line 122
def dtype_size(dtype)
  DTYPE_SIZES[dtype] || raise(ArgumentError, "Unknown dtype: #{dtype}")
end
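`dtype_size` is the strict lookup: any key missing from `DTYPE_SIZES`, including a String spelling of a valid dtype, raises `ArgumentError`. The standalone sketch below copies the table and method, then applies the size to a 3x224x224 float16 tensor; the shape is an arbitrary illustration.

```ruby
DTYPE_SIZES = {
  float32: 4, float64: 8, float16: 2, bfloat16: 2,
  int32: 4, int64: 8, int16: 2, int8: 1,
  uint8: 1, uint32: 4, uint64: 8
}.freeze

def dtype_size(dtype)
  DTYPE_SIZES[dtype] || raise(ArgumentError, "Unknown dtype: #{dtype}")
end

# Bytes for one 3x224x224 float16 tensor: 150_528 elements * 2 bytes.
shape = [3, 224, 224]
bytes = shape.reduce(1, :*) * dtype_size(:float16)

# Keys are Symbols, so a String dtype name is rejected rather than converted.
string_error =
  begin
    dtype_size("float16")
  rescue ArgumentError => e
    e.message
  end

puts bytes, string_error
```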
.extract_pointers(tensors) ⇒ Array<FFI::Pointer>
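Extract device pointers from tensors for raw operations. Each tensor is tried, in order, via data_ptr, device_ptr, and pointer; Fiddle::Pointer and FFI::Pointer objects are passed through unchanged; anything else raises ArgumentError.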
# File 'lib/nvruby/collective/nvarray_adapter.rb', line 66
def extract_pointers(tensors)
  tensors.map do |t|
    if t.respond_to?(:data_ptr)
      t.data_ptr
    elsif t.respond_to?(:device_ptr)
      t.device_ptr
    elsif t.respond_to?(:pointer)
      t.pointer
    elsif t.is_a?(Fiddle::Pointer)
      t
    elsif defined?(FFI::Pointer) && t.is_a?(FFI::Pointer)
      t
    else
      raise ArgumentError, "Cannot extract pointer from #{t.class}"
    end
  end
end
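The branches in `extract_pointers` are tried in order, so an object that responds to both `data_ptr` and `device_ptr` yields its `data_ptr`. The standalone sketch below keeps only the duck-typed branches (the Fiddle::Pointer and FFI::Pointer checks are left out so it needs no extra libraries) and uses hypothetical Struct classes in place of real tensors; the integer "pointers" are placeholders, not device addresses.

```ruby
# Hypothetical tensor-like objects exposing different pointer accessors.
WithDataPtr   = Struct.new(:data_ptr)
WithDevicePtr = Struct.new(:device_ptr)
WithBoth      = Struct.new(:data_ptr, :device_ptr)

# Duck-typed branches of the documented method, in the same order.
# The Fiddle::Pointer / FFI::Pointer branches are omitted in this sketch.
def extract_pointers(tensors)
  tensors.map do |t|
    if t.respond_to?(:data_ptr)
      t.data_ptr
    elsif t.respond_to?(:device_ptr)
      t.device_ptr
    elsif t.respond_to?(:pointer)
      t.pointer
    else
      raise ArgumentError, "Cannot extract pointer from #{t.class}"
    end
  end
end

# WithBoth yields 0x30 because data_ptr is checked before device_ptr.
ptrs = extract_pointers([WithDataPtr.new(0x10), WithDevicePtr.new(0x20), WithBoth.new(0x30, 0x40)])

unsupported =
  begin
    extract_pointers([42])
  rescue ArgumentError => e
    e.message
  end

puts ptrs.inspect, unsupported
```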
.normalize(inputs, expected_count: nil) ⇒ Array<NvArray>
Normalize the input to an Array of NvArrays and validate them. Raises ArgumentError if expected_count is given and does not equal the number of tensors.
# File 'lib/nvruby/collective/nvarray_adapter.rb', line 50
def normalize(inputs, expected_count: nil)
  tensors = wrap_array(inputs)

  if expected_count && tensors.size != expected_count
    raise ArgumentError,
          "Expected #{expected_count} tensors, got #{tensors.size}"
  end

  validate_tensors!(tensors)
  tensors
end
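`normalize` accepts either a single tensor or an Array and can enforce a tensor count. `wrap_array` and `validate_tensors!` are private and not shown on this page, so the standalone sketch below substitutes hypothetical versions: `wrap_array` wraps a non-Array input in a one-element Array, and `validate_tensors!` only checks for a single dtype. The real validation may check more. `FakeTensor` is likewise a hypothetical stand-in.

```ruby
# Hypothetical stand-in for NvArray.
FakeTensor = Struct.new(:dtype, :shape)

module NormalizeSketch
  module_function

  # Hypothetical wrap_array. Kernel#Array is avoided on purpose:
  # for a Struct it would call Struct#to_a and splat the tensor's fields.
  def wrap_array(inputs)
    inputs.is_a?(Array) ? inputs : [inputs]
  end

  # Hypothetical validate_tensors!: this stand-in only enforces a single dtype.
  def validate_tensors!(tensors)
    dtypes = tensors.map(&:dtype).uniq
    raise ArgumentError, "mixed dtypes: #{dtypes.join(', ')}" if dtypes.size > 1
  end

  # Same control flow as the documented method.
  def normalize(inputs, expected_count: nil)
    tensors = wrap_array(inputs)
    if expected_count && tensors.size != expected_count
      raise ArgumentError, "Expected #{expected_count} tensors, got #{tensors.size}"
    end

    validate_tensors!(tensors)
    tensors
  end
end

single = NormalizeSketch.normalize(FakeTensor.new(:float32, [4]))

count_error =
  begin
    NormalizeSketch.normalize([FakeTensor.new(:float32, [4])], expected_count: 2)
  rescue ArgumentError => e
    e.message
  end

puts single.size, count_error
```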
.total_byte_size(tensors) ⇒ Integer
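Calculate the total byte size of all tensors by summing element count times dtype size for each tensor. Unlike buffer_info, the tensors may have different dtypes and element counts.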
# File 'lib/nvruby/collective/nvarray_adapter.rb', line 138
def total_byte_size(tensors)
  tensors.sum do |t|
    extract_element_count(t) * dtype_size(extract_dtype(t))
  end
end
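Unlike `buffer_info`, `total_byte_size` does not require a common dtype or element count; it sums `element_count * dtype_size` tensor by tensor. The standalone sketch below assumes the element count is the product of the shape (the private helper is not shown on this page) and uses a hypothetical `FakeTensor`.

```ruby
# Hypothetical stand-in for NvArray.
FakeTensor = Struct.new(:dtype, :shape)

DTYPE_SIZES = {
  float32: 4, float64: 8, float16: 2, bfloat16: 2,
  int32: 4, int64: 8, int16: 2, int8: 1,
  uint8: 1, uint32: 4, uint64: 8
}.freeze

def dtype_size(dtype)
  DTYPE_SIZES[dtype] || raise(ArgumentError, "Unknown dtype: #{dtype}")
end

# Assumption: element count = product of the shape dimensions.
def total_byte_size(tensors)
  tensors.sum { |t| t.shape.reduce(1, :*) * dtype_size(t.dtype) }
end

tensors = [
  FakeTensor.new(:float32, [4]),    # 4 * 4 = 16 bytes
  FakeTensor.new(:float16, [2, 4]), # 8 * 2 = 16 bytes
  FakeTensor.new(:int64, [3])       # 3 * 8 = 24 bytes
]
total       = total_byte_size(tensors) # mixed dtypes are allowed here
empty_total = total_byte_size([])

puts total, empty_total
```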