ggml_flash_attn_ext method

Pointer<ggml_tensor> ggml_flash_attn_ext(
  Pointer<ggml_context> ctx,
  Pointer<ggml_tensor> q,
  Pointer<ggml_tensor> k,
  Pointer<ggml_tensor> v,
  Pointer<ggml_tensor> mask,
  double scale,
  double max_bias,
  double logit_softcap,
)

Expected tensor shapes:

- q:    `n_embd, n_batch,     n_head,    1`
- k:    `n_embd, n_kv,        n_head_kv, 1`
- v:    `n_embd, n_kv,        n_head_kv, 1` — not transposed
- mask: `n_kv,   n_batch_pad, 1,         1` — where `n_batch_pad = GGML_PAD(n_batch, GGML_KQ_MASK_PAD)`
- res:  `n_embd, n_head,      n_batch,   1` — permuted

Implementation

/// Dart FFI wrapper that forwards all arguments unchanged to the native
/// `ggml_flash_attn_ext` binding and returns the resulting tensor pointer.
///
/// Per the accompanying ggml documentation, the tensors are expected to have
/// shapes q: (n_embd, n_batch, n_head, 1), k/v: (n_embd, n_kv, n_head_kv, 1)
/// (v not transposed), and mask: (n_kv, n_batch_pad, 1, 1); the result is
/// (n_embd, n_head, n_batch, 1), permuted. — TODO confirm against the native
/// header, as this wrapper itself performs no validation.
ffi.Pointer<ggml_tensor> ggml_flash_attn_ext(
  ffi.Pointer<ggml_context> ctx,
  ffi.Pointer<ggml_tensor> q,
  ffi.Pointer<ggml_tensor> k,
  ffi.Pointer<ggml_tensor> v,
  ffi.Pointer<ggml_tensor> mask,
  double scale,
  double max_bias,
  double logit_softcap,
) {
  // Pure pass-through: no argument is inspected or transformed here.
  final ffi.Pointer<ggml_tensor> result = _ggml_flash_attn_ext(
      ctx, q, k, v, mask, scale, max_bias, logit_softcap);
  return result;
}