Module: ToyVit

Defined in:
lib/toy/models/toy_vit.rb

Class Method Summary

Class Method Details

.patch_embed(sess, kernel, image, patch_size) ⇒ Object

Returns the patch-embedding tensor handle. The caller is responsible for calling set_param on the kernel (when training) and for building the downstream graph.

kernel: tnn 4D F32 persistent, ne=[KW, KH, IC, d_model]
image: tnn 3D F32 (or 4D if N>1), ne=[W, H, IC, (N)]
patch: kernel size = stride (no overlap, no padding).
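As a rough illustration of the shape contract above, the following pure-Ruby sketch computes the extents one would expect from the result for N=1. Both conv_out_extent and patch_embed_ne are hypothetical helpers written for this note, not part of TinyNN; the formula is the standard convolution output-size formula, here used with stride = kernel size, zero padding, and dilation 1.

```ruby
# Hypothetical helper (not part of TinyNN): standard conv output extent
# along one axis, using integer (floor) division.
def conv_out_extent(input, kernel, stride, padding = 0, dilation = 1)
  (input + 2 * padding - dilation * (kernel - 1) - 1) / stride + 1
end

# Hypothetical helper: expected ne of the patch-embedding result for N=1,
# given image ne=[W, H, IC] and kernel ne=[KW, KH, IC, d_model].
def patch_embed_ne(image_ne, kernel_ne, patch_size)
  w, h, ic = image_ne
  kw, kh, kic, d_model = kernel_ne
  raise ArgumentError, "channel mismatch" unless ic == kic
  unless kw == patch_size && kh == patch_size
    raise ArgumentError, "kernel must be patch_size x patch_size"
  end

  ow = conv_out_extent(w, kw, patch_size)
  oh = conv_out_extent(h, kh, patch_size)
  [d_model, ow * oh] # [OC, number of patches]
end

p patch_embed_ne([224, 224, 3], [16, 16, 3, 192], 16) # prints [192, 196]
```

With a 224x224 RGB image and 16x16 patches this gives 14 x 14 = 196 patches, each embedded into d_model = 192 values.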

Currently assumes N=1 (a single image). For batch > 1, the final cont_2d would need to become cont_3d(OC, OW*OH, N); this is left as a follow-up.
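To make the N=1 limitation concrete, here is a small pure-Ruby sketch of the target shapes only. flattened_ne is a hypothetical helper, not a TinyNN call; it models the shape arithmetic behind the cont_2d / cont_3d distinction mentioned in the note and does not touch any tensors.

```ruby
# Hypothetical helper (not part of TinyNN): target ne after flattening the
# spatial axes of a permuted tensor with ne=[OC, OW, OH, N].
def flattened_ne(perm_ne)
  oc, ow, oh, n = perm_ne
  n ||= 1 # a 3D tensor is treated as a single image
  n == 1 ? [oc, ow * oh] : [oc, ow * oh, n]
end

p flattened_ne([192, 14, 14, 1]) # prints [192, 196]
p flattened_ne([192, 14, 14, 8]) # prints [192, 196, 8]

# A 2D target of [OC, OW*OH] cannot hold a batch of 8: element counts differ.
puts [192, 14, 14, 8].reduce(:*) == [192, 196].reduce(:*) # prints false
```

The last line shows why the 2D form only works for a single image: the batched tensor has N times as many elements as the [OC, OW*OH] target.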



# File 'lib/toy/models/toy_vit.rb', line 63

def self.patch_embed(sess, kernel, image, patch_size)
  conv = TinyNN.tnn_conv_2d(sess, kernel, image,
                            patch_size, patch_size, # stride
                            0, 0,                   # padding
                            1, 1)                   # dilation
  # conv has ne=[OW, OH, OC, N]. ggml_permute moves source axis i to
  # result axis axis_i. To get result ne=[OC, OW, OH, N] (OC at 0,
  # OW at 1, OH at 2, N at 3) we map:
  #   source axis 0 (OW) → result axis 1
  #   source axis 1 (OH) → result axis 2
  #   source axis 2 (OC) → result axis 0
  #   source axis 3 (N)  → result axis 3
  # → permute(1, 2, 0, 3).
  perm = TinyNN.tnn_permute(sess, conv, 1, 2, 0, 3)
  # Flatten the OW and OH spatial axes into a single "patch" axis.
  ow = TinyNN.tnn_tensor_ne0(conv)
  oh = TinyNN.tnn_tensor_ne1(conv)
  oc = TinyNN.tnn_tensor_ne2(conv)
  TinyNN.tnn_cont_2d(sess, perm, oc, ow * oh)
end
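
The axis convention described in the comments above is easy to get backwards, so here is a minimal pure-Ruby model of it. permuted_ne is a hypothetical helper that only rearranges an ne array in the way the comments describe (source axis i lands at result axis axes[i]); it is not a TinyNN or ggml function.

```ruby
# Hypothetical helper: models the permute convention from the comments,
# where source axis i is moved to result axis axes[i].
def permuted_ne(src_ne, axes)
  result = Array.new(src_ne.size)
  src_ne.each_with_index { |extent, i| result[axes[i]] = extent }
  result
end

conv_ne = [14, 7, 192, 1]               # [OW, OH, OC, N]
p permuted_ne(conv_ne, [1, 2, 0, 3])    # prints [192, 14, 7, 1], i.e. [OC, OW, OH, N]
```

Note that this is the inverse of the NumPy-style transpose convention, where axes[i] names the source axis that feeds result axis i; under that convention the same result would be written as (2, 0, 1, 3).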