Class: Ignis::MathDx::FftKernel

Inherits:
Object
Defined in:
lib/nvruby/mathdx/fft_kernel.rb

Overview

Device-side FFT kernel using cuFFTDx patterns. Generates and compiles CUDA C++ code for thread-block FFT operations.

cuFFTDx lets FFT operations be embedded inside CUDA kernels, so they can be fused with other operations to reduce memory traffic.

Examples:

Basic FFT kernel

kernel = FftKernel.new(size: 64, dtype: :complex64, direction: :forward)
kernel.compile!
kernel.execute(input, output: output)

Constant Summary

SUPPORTED_SIZES = [16, 32, 64, 128, 256, 512, 1024].freeze

Supported FFT sizes (powers of 2, from 16 to 1024)

Constructor Details

#initialize(size:, dtype: :complex64, direction: :forward, elements_per_thread: 8) ⇒ FftKernel

Initialize FFT kernel configuration

Parameters:

  • size (Integer)

    FFT size (a power of 2 from 16 to 1024; see SUPPORTED_SIZES)

  • dtype (Symbol) (defaults to: :complex64)

    Data type (:complex64, :complex128)

  • direction (Symbol) (defaults to: :forward)

    Direction (:forward, :inverse)

  • elements_per_thread (Integer) (defaults to: 8)

    Number of FFT elements processed by each thread; the block size is size / elements_per_thread



# File 'lib/nvruby/mathdx/fft_kernel.rb', line 39

def initialize(size:, dtype: :complex64, direction: :forward, elements_per_thread: 8)
  validate_size!(size)
  validate_dtype!(dtype)
  validate_direction!(direction)

  @size = size
  @dtype = dtype
  @direction = direction
  @elements_per_thread = elements_per_thread
  @compiled = false
  @kernel = nil
end
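
The validation helpers called by the initializer are not shown on this page. As a minimal sketch, assuming validate_size! only checks membership in SUPPORTED_SIZES and raises ArgumentError (both the body and the error class are assumptions, not confirmed by the listing), it could look like this:

```ruby
# Hypothetical stand-in for the private validate_size! helper.
# Assumption: the real helper checks SUPPORTED_SIZES; the error class is a guess.
SUPPORTED_SIZES = [16, 32, 64, 128, 256, 512, 1024].freeze

def validate_size!(size)
  # Return the size unchanged when it is one of the supported values.
  return size if SUPPORTED_SIZES.include?(size)

  raise ArgumentError,
        "Unsupported FFT size #{size.inspect}; expected one of #{SUPPORTED_SIZES.join(', ')}"
end

validate_size!(64) # accepted

begin
  validate_size!(48) # rejected: not in SUPPORTED_SIZES
rescue ArgumentError => e
  puts e.message
end
```

Checking against an explicit list rather than testing for a power of 2 keeps the accepted range (16 to 1024) and the error message in one place.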

Instance Attribute Details

#compiled ⇒ Boolean (readonly)

Returns whether the kernel is compiled.

Returns:

  • (Boolean)

    Whether the kernel is compiled



# File 'lib/nvruby/mathdx/fft_kernel.rb', line 29

def compiled
  @compiled
end

#direction ⇒ Symbol (readonly)

Returns the FFT direction.

Returns:

  • (Symbol)

    FFT direction



# File 'lib/nvruby/mathdx/fft_kernel.rb', line 23

def direction
  @direction
end

#dtype ⇒ Symbol (readonly)

Returns the data type.

Returns:

  • (Symbol)

    Data type



# File 'lib/nvruby/mathdx/fft_kernel.rb', line 20

def dtype
  @dtype
end

#elements_per_thread ⇒ Integer (readonly)

Returns the number of elements per thread.

Returns:

  • (Integer)

    Elements per thread



# File 'lib/nvruby/mathdx/fft_kernel.rb', line 26

def elements_per_thread
  @elements_per_thread
end

#size ⇒ Integer (readonly)

Returns the FFT size.

Returns:

  • (Integer)

    FFT size



# File 'lib/nvruby/mathdx/fft_kernel.rb', line 17

def size
  @size
end

Instance Method Details

#compile!(device_id: 0) ⇒ self

Generate the kernel source and JIT-compile it for the target device

Parameters:

  • device_id (Integer) (defaults to: 0)

    Target GPU device

Returns:

  • (self)


# File 'lib/nvruby/mathdx/fft_kernel.rb', line 55

def compile!(device_id: 0)
  source = generate_source
  @kernel = Ignis::JIT::Compiler.compile(
    source,
    "cufftdx_fft",
    device_id: device_id,
    options: nvrtc_options
  )
  @compiled = true
  self
end

#destroy! ⇒ void

This method returns an undefined value.

Release kernel resources. The kernel is marked as not compiled, so compile! must be called again before execute.



# File 'lib/nvruby/mathdx/fft_kernel.rb', line 111

def destroy!
  @kernel = nil
  @compiled = false
end
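
The compile!, execute and destroy! methods together form a small state machine around the compiled flag. The sketch below uses a stand-in class (FakeFftKernel is hypothetical and never touches a GPU) to illustrate that lifecycle; StateError is redefined locally as an assumption about the library's error class.

```ruby
# Stand-in that mimics only the compiled / not-compiled state handling.
# FakeFftKernel and this local StateError are illustrative assumptions.
class StateError < StandardError; end

class FakeFftKernel
  attr_reader :compiled

  def initialize
    @compiled = false
    @kernel = nil
  end

  def compile!
    @kernel = :compiled_kernel_handle # placeholder for the JIT result
    @compiled = true
    self # returning self allows call chaining
  end

  def execute(input)
    raise StateError, "Kernel not compiled. Call compile! first." unless @compiled

    input # the real method launches the kernel and returns the output array
  end

  def destroy!
    @kernel = nil
    @compiled = false
  end
end

kernel = FakeFftKernel.new.compile! # chained, because compile! returns self
kernel.execute([1, 2, 3])
kernel.destroy!

begin
  kernel.execute([1, 2, 3]) # the kernel was released, so this raises
rescue StateError => e
  puts e.message
end
```

Because destroy! resets the flag rather than invalidating the object, the same instance can be compiled again later, for example against a different device.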

#execute(input, output: nil, batch: 1, stream: nil) ⇒ NvArray

Execute the FFT kernel

Parameters:

  • input (NvArray)

    Input array (complex)

  • output (NvArray, nil) (defaults to: nil)

    Output array (created if nil)

  • batch (Integer) (defaults to: 1)

    Number of batched FFTs

  • stream (CUDA::Stream, nil) (defaults to: nil)

    CUDA stream

Returns:

  • (NvArray)

    Output array on the device

Raises:

  • (StateError)

    If the kernel has not been compiled



# File 'lib/nvruby/mathdx/fft_kernel.rb', line 73

def execute(input, output: nil, batch: 1, stream: nil)
  raise StateError, "Kernel not compiled. Call compile! first." unless @compiled

  validate_execution_input!(input)

  # Ensure input is on device
  input_dev = input.on_device? ? input : input.to_device

  # Create output if needed
  output_dev = if output
                 output.on_device? ? output : output.to_device
               else
                 NvArray.zeros(input.shape, dtype: @dtype, device: input_dev.device_index).to_device
               end

  # Calculate grid dimensions
  threads_per_fft = @size / @elements_per_thread
  blocks = batch

  # Launch kernel
  @kernel.launch(
    grid: [blocks],
    block: [threads_per_fft],
    shared_memory: shared_memory_size,
    args: [
      input_dev.device_ptr,
      output_dev.device_ptr,
      @size,
      batch
    ],
    stream: stream
  )

  output_dev
end
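
The launch configuration in the listing above uses one block per batched FFT and size / elements_per_thread threads per block. A small sketch of that arithmetic follows; launch_geometry is a hypothetical helper, and its divisibility check is an added assumption, since the listing uses plain integer division.

```ruby
# Mirrors the grid/block arithmetic used by #execute.
# launch_geometry is a hypothetical helper, not part of the library.
def launch_geometry(size:, elements_per_thread: 8, batch: 1)
  # Assumption: reject sizes that integer division would silently truncate.
  unless (size % elements_per_thread).zero?
    raise ArgumentError, "size must be a multiple of elements_per_thread"
  end

  { grid: [batch], block: [size / elements_per_thread] }
end

small = launch_geometry(size: 64)
puts "blocks=#{small[:grid][0]} threads_per_block=#{small[:block][0]}"

large = launch_geometry(size: 1024, elements_per_thread: 8, batch: 32)
puts "blocks=#{large[:grid][0]} threads_per_block=#{large[:block][0]}"
```

With the default of 8 elements per thread, the largest supported size (1024) needs 128 threads per block and the smallest (16) needs only 2.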