Module: Ignis::Solver::LU
- Defined in:
- lib/nvruby/solver/lu.rb
Overview
LU decomposition operations using cuSOLVER. Computes the P*A = L*U factorization used for solving linear systems.
Class Method Summary collapse
-
.getrf(matrix, overwrite: false) ⇒ Hash
Compute LU factorization of a matrix.
-
.getrs(a, b, pivot: nil, trans: :none) ⇒ NvArray
Solve linear system Ax = b using LU factorization.
-
.solve(a, b) ⇒ NvArray
Solve linear system Ax = b (convenience method).
Class Method Details
.getrf(matrix, overwrite: false) ⇒ Hash
Compute LU factorization of a matrix
16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
# File 'lib/nvruby/solver/lu.rb', line 16
# Compute the LU factorization (P*A = L*U) of a matrix via cuSOLVER.
#
# The scratch workspace and the device-side info word are always released
# before returning. The pivot buffer is handed to the caller on success and
# released here only when the factorization fails, so no device memory is
# leaked on error paths.
#
# @param matrix [#shape, #dtype, #dup, #device_ffi_ptr] m x n matrix to factorize
# @param overwrite [Boolean] when true the factors are written into +matrix+
#   itself; otherwise a copy made with +dup+ is factorized
# @return [Hash] +:lu+ (factorized matrix), +:pivot+ (CUDA::Memory holding
#   min(m, n) int32 pivot indices; the caller is responsible for calling
#   +free!+ on it), +:pivot_size+ (min(m, n)) and +:info+ (status integer
#   read back from the device)
# @raise [CuSolverError] if the device reports a negative info value
#   (an argument had an illegal value)
def getrf(matrix, overwrite: false)
  CuSolverBindings.ensure_loaded!
  validate_matrix!(matrix)

  m, n = matrix.shape
  lda = m
  pivot_size = [m, n].min

  # Buffers owned by this call; declared up front so the ensure clause can
  # release whatever was actually allocated.
  workspace = nil
  pivot = nil
  info = nil
  pivot_handed_off = false

  # Copy matrix if not overwriting
  work_matrix = overwrite ? matrix : matrix.dup

  # Get workspace size
  lwork_ptr = FFI::MemoryPointer.new(:int)
  get_buffer_size(matrix.dtype, m, n, work_matrix.device_ffi_ptr, lda, lwork_ptr)
  lwork = lwork_ptr.read_int

  # Allocate workspace, pivot array and info word
  workspace = CUDA::Memory.new(lwork * dtype_size(matrix.dtype))
  pivot = CUDA::Memory.new(pivot_size * 4) # int32 pivot indices
  info = CUDA::Memory.new(4) # int32 info

  # Perform LU factorization (FFI cuSOLVER needs FFI pointers, not Fiddle)
  perform_getrf(matrix.dtype, m, n, work_matrix.device_ffi_ptr, lda,
                workspace.ffi_ptr, pivot.ffi_ptr, info.ffi_ptr)

  # Read info to check for errors
  info_value = read_device_int(info)
  if info_value < 0
    raise CuSolverError.new("LU factorization: parameter #{-info_value} had an illegal value",
                            cusolver_code: CuSolverBindings::CUSOLVER_STATUS_INVALID_VALUE)
  elsif info_value > 0
    # A zero pivot is reported but not fatal: the factors are still returned.
    Ignis.logger.warn("LU factorization: U(#{info_value},#{info_value}) is exactly zero. " \
                      "The factorization has been completed, but U is singular.")
  end

  # Synchronize to ensure completion
  CUDA::RuntimeAPI.cudaDeviceSynchronize

  # From here on the caller owns the pivot buffer.
  pivot_handed_off = true
  { lu: work_matrix, pivot: pivot, pivot_size: pivot_size, info: info_value }
ensure
  # Locals that were never assigned are nil here, so &. makes each release a
  # no-op when the failure happened before the allocation.
  workspace&.free!
  info&.free!
  pivot&.free! unless pivot_handed_off
end
.getrs(a, b, pivot: nil, trans: :none) ⇒ NvArray
Solve linear system Ax = b using LU factorization
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
# File 'lib/nvruby/solver/lu.rb', line 71
# Solve the linear system Ax = b using an LU factorization.
#
# When +pivot+ is nil, +a+ is factorized first with getrf; the pivot buffer
# produced by that internal factorization is owned by this call and is
# released before returning. When +pivot+ is supplied, +a+ is used as the
# already-factorized LU matrix and the caller keeps ownership of +pivot+.
#
# @param a [#shape, #dtype, #device_ffi_ptr] square n x n coefficient matrix,
#   or its LU factors when +pivot+ is given
# @param b [#shape, #dup] right-hand side with n rows; it is copied, not modified
# @param pivot [#ffi_ptr, Object, nil] pivot indices from a previous getrf;
#   objects without +ffi_ptr+ are passed through as the pointer itself
# @param trans [Symbol] +:none+, +:transpose+ or +:conjugate+; any other
#   value falls back to +:none+
# @return [NvArray] copy of +b+ that the solve overwrote with the solution x
# @raise [ArgumentError] if +a+ is not square or +b+ has a different row count
# @raise [CuSolverError] if the device reports a negative info value
def getrs(a, b, pivot: nil, trans: :none)
  CuSolverBindings.ensure_loaded!
  validate_matrix!(a)
  validate_matrix!(b)

  n = a.shape[0]
  raise ArgumentError, "Matrix A must be square" unless a.shape[0] == a.shape[1]
  raise ArgumentError, "Matrix dimensions mismatch" unless b.shape[0] == n

  # A one-dimensional b is treated as a single right-hand side.
  nrhs = b.shape.length > 1 ? b.shape[1] : 1
  lda = n
  ldb = n

  # Set only when this call allocated the pivot itself, so the ensure clause
  # never frees a buffer that belongs to the caller.
  owned_pivot = nil
  info = nil

  # If no pivot provided, compute LU factorization first
  if pivot.nil?
    result = getrf(a)
    lu_matrix = result[:lu]
    owned_pivot = result[:pivot]
    pivot = owned_pivot
    # NOTE(review): result[:lu] is a temporary copy of a; its release is left
    # to whatever lifetime management the array type provides - confirm.
  else
    lu_matrix = a
  end

  # Copy b to output
  x = b.dup

  # Allocate info
  info = CUDA::Memory.new(4)

  # Map transpose option
  trans_op = case trans
             when :none then CuSolverBindings::CUBLAS_OP_N
             when :transpose then CuSolverBindings::CUBLAS_OP_T
             when :conjugate then CuSolverBindings::CUBLAS_OP_C
             else CuSolverBindings::CUBLAS_OP_N
             end

  # Perform solve (FFI cuSOLVER needs FFI pointers, not Fiddle)
  pivot_ptr = pivot.respond_to?(:ffi_ptr) ? pivot.ffi_ptr : pivot
  perform_getrs(a.dtype, trans_op, n, nrhs, lu_matrix.device_ffi_ptr, lda,
                pivot_ptr, x.device_ffi_ptr, ldb, info.ffi_ptr)

  # Check info
  info_value = read_device_int(info)
  if info_value < 0
    raise CuSolverError.new("LU solve: parameter #{-info_value} had an illegal value",
                            cusolver_code: CuSolverBindings::CUSOLVER_STATUS_INVALID_VALUE)
  end

  CUDA::RuntimeAPI.cudaDeviceSynchronize

  x
ensure
  info&.free!
  owned_pivot&.free!
end
.solve(a, b) ⇒ NvArray
Solve linear system Ax = b (convenience method)
131 132 133 |
# File 'lib/nvruby/solver/lu.rb', line 131
# One-call convenience wrapper for solving Ax = b.
#
# Forwards to getrs with no precomputed pivot and no transpose, which are the
# same values getrs uses by default.
#
# @param a coefficient matrix, passed to getrs unchanged
# @param b right-hand side, passed to getrs unchanged
# @return [NvArray] whatever getrs returns for these arguments
def solve(a, b)
  getrs(a, b, pivot: nil, trans: :none)
end