BATCHED – KokkosKernels batched host-level interfaces

BatchedGemm

template<typename ArgTransA, typename ArgTransB, typename ArgBatchSzDim, typename BatchedGemmHandleType, typename ScalarType, typename AViewType, typename BViewType, typename CViewType>
inline int KokkosBatched::BatchedGemm(BatchedGemmHandleType *const handle, const ScalarType alpha, const AViewType &A, const BViewType &B, const ScalarType beta, const CViewType &C)

Non-blocking general matrix multiply (GEMM) on a batch of uniformly sized matrices.

Note: If a TPL is selected, this interface inherits the blocking or non-blocking behavior of the TPL vendor’s API.

Note: To leverage SIMD instructions, 4-rank views must be selected via the template parameters documented below.

   C = alpha * op(A) * op(B) + beta * C
Usage Example: BatchedGemm<ArgTransA, ArgTransB, ArgBatchSzDim>(handle, alpha, A, B, beta, C);
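To make the formula concrete, here is a minimal serial reference sketch of what one call computes when both ops are Trans::NoTranspose and the batch index is outermost. The helper names `reference_gemm_nn` and `reference_batched_gemm_nn`, and the row-major std::vector storage, are hypothetical and are not part of KokkosKernels; the sketch only spells out the arithmetic.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical reference helper (not a KokkosKernels API): computes
// C = alpha * A * B + beta * C for a single matrix with op = NoTranspose.
// A is M x K, B is K x N, C is M x N, all stored row-major.
void reference_gemm_nn(std::size_t M, std::size_t N, std::size_t K,
                       double alpha, const double* A, const double* B,
                       double beta, double* C) {
  for (std::size_t i = 0; i < M; ++i) {
    for (std::size_t j = 0; j < N; ++j) {
      double acc = 0.0;  // row i of A dotted with column j of B
      for (std::size_t k = 0; k < K; ++k) acc += A[i * K + k] * B[k * N + j];
      C[i * N + j] = alpha * acc + beta * C[i * N + j];
    }
  }
}

// Applies reference_gemm_nn to each matrix of a batch stored back to back,
// with the batch index outermost (the logical shape of BatchLayout::Left).
void reference_batched_gemm_nn(std::size_t batch, std::size_t M,
                               std::size_t N, std::size_t K, double alpha,
                               const std::vector<double>& A,
                               const std::vector<double>& B, double beta,
                               std::vector<double>& C) {
  assert(A.size() == batch * M * K);
  assert(B.size() == batch * K * N);
  assert(C.size() == batch * M * N);
  for (std::size_t b = 0; b < batch; ++b) {
    reference_gemm_nn(M, N, K, alpha, A.data() + b * M * K,
                      B.data() + b * K * N, beta, C.data() + b * M * N);
  }
}
```

Each batch entry is computed independently of the others, which is what allows the KK_* algorithms to map batch entries onto separate RangePolicy or TeamPolicy iterations.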

Template Parameters:
  • ArgTransA – Specifies what op does to A:

                       Trans::NoTranspose   for non-transpose
                       Trans::Transpose     for transpose
                       Trans::ConjTranspose for conjugate transpose

  • ArgTransB – Specifies what op does to B:

                       Trans::NoTranspose   for non-transpose
                       Trans::Transpose     for transpose
                       Trans::ConjTranspose for conjugate transpose

  • ArgBatchSzDim – Specifies where the batch dimension is allocated in AViewType, BViewType, and CViewType:

                       BatchLayout::Left  Batch dimension is leftmost
                       BatchLayout::Right Batch dimension is rightmost

  • ScalarType – Specifies the scalar type of alpha and beta

  • AViewType – Input matrix, as either a 3-rank Kokkos::View or a 4-rank Kokkos::View for SIMD operations.

  • BViewType – Input matrix, as either a 3-rank Kokkos::View or a 4-rank Kokkos::View for SIMD operations.

  • CViewType – Input (RHS) / output (LHS) matrix, as either a 3-rank Kokkos::View or a 4-rank Kokkos::View for SIMD operations.
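The three op choices for ArgTransA and ArgTransB differ only in how element (i, j) of op(X) is read from the stored matrix. The sketch below shows this with a runtime enum; in the real interface these are compile-time tag types, so `TransMode` and `op_element` are hypothetical names used only for illustration, and row-major storage is an assumption.

```cpp
#include <cassert>
#include <complex>
#include <cstddef>
#include <vector>

// Hypothetical runtime mirror of the Trans::NoTranspose / Trans::Transpose /
// Trans::ConjTranspose tags, for illustration only.
enum class TransMode { NoTranspose, Transpose, ConjTranspose };

using cplx = std::complex<double>;

// Returns element (i, j) of op(X), where X is stored row-major with `cols`
// columns. Transpose and ConjTranspose read X(j, i); ConjTranspose also
// conjugates the value.
cplx op_element(TransMode mode, const std::vector<cplx>& X, std::size_t cols,
                std::size_t i, std::size_t j) {
  switch (mode) {
    case TransMode::NoTranspose:
      return X[i * cols + j];
    case TransMode::Transpose:
      return X[j * cols + i];
    case TransMode::ConjTranspose:
      return std::conj(X[j * cols + i]);
  }
  assert(false && "unknown TransMode");
  return {};
}
```

For real scalar types, Transpose and ConjTranspose give the same result, since conjugation leaves real values unchanged.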

Parameters:
  • handle – [in] A handle which specifies how to invoke the batched gemm. See struct BatchedGemmHandle for details.

  • alpha – [in] Input coefficient used for multiplication with A

  • A – [in] Input matrix, as a 3-rank Kokkos::View

                       If ArgBatchSzDim == "BatchLayout::Right", matrix A is MxKxB
                       If ArgBatchSzDim == "BatchLayout::Left",  matrix A is BxMxK

  • B – [in] Input matrix, as a 3-rank Kokkos::View

                       If ArgBatchSzDim == "BatchLayout::Right", matrix B is KxNxB
                       If ArgBatchSzDim == "BatchLayout::Left",  matrix B is BxKxN

  • beta – [in] Input coefficient used for multiplication with C

  • C – [in/out] Input/Output matrix, as a 3-rank Kokkos::View

                       If ArgBatchSzDim == "BatchLayout::Right", matrix C is MxNxB
                       If ArgBatchSzDim == "BatchLayout::Left",  matrix C is BxMxN
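The only difference between the two ArgBatchSzDim choices is which view index carries the batch number. The sketch below reads the same logical entry, row `row` and column `col` of matrix `b`, from a BxMxK array (BatchLayout::Left) and from an MxKxB array (BatchLayout::Right). `Array3`, `BatchDim`, and `batch_entry` are hypothetical helpers. The row-major storage is an assumption: a Kokkos::View's physical storage order is set by its own layout template parameter, which is independent of BatchLayout.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical enum standing in for BatchLayout::Left / BatchLayout::Right.
enum class BatchDim { Left, Right };

// A 3-rank array stored row-major (last index fastest), used only to show
// where the batch index sits for each BatchLayout choice.
struct Array3 {
  std::size_t d0, d1, d2;
  std::vector<double> data;
  Array3(std::size_t a, std::size_t b, std::size_t c)
      : d0(a), d1(b), d2(c), data(a * b * c, 0.0) {}
  double& operator()(std::size_t i, std::size_t j, std::size_t k) {
    assert(i < d0 && j < d1 && k < d2);
    return data[(i * d1 + j) * d2 + k];
  }
};

// Returns entry (row, col) of matrix number `b` in the batch.
// Left:  the array is batch x rows x cols, so the batch index comes first.
// Right: the array is rows x cols x batch, so the batch index comes last.
double& batch_entry(Array3& X, BatchDim dim, std::size_t b, std::size_t row,
                    std::size_t col) {
  return dim == BatchDim::Left ? X(b, row, col) : X(row, col, b);
}
```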
    

Returns:

0 upon success, non-zero otherwise

class BatchedGemmHandle : public KokkosBatched::BatchedKernelHandle

Handle for selecting runtime behavior of the BatchedGemm interface.

Param kernelAlgoType:

Specifies which algorithm to use for invocation (default, SQUARE).

                   Specifies whether to select optimal invocations based on inputs and
                   heuristics:
                     SQUARE select invocations based on square matrix heuristics where M=N
                     TALL   select invocations based on tall   matrix heuristics where M>N
                     WIDE   select invocations based on wide   matrix heuristics where M<N

                   Specifies which cmake-enabled TPL algorithm to invoke:
                     ARMPL    Invoke the ArmPL TPL interface  (Currently UNSUPPORTED)
                     MKL      Invoke the MKL TPL interface    (Currently UNSUPPORTED)
                     CUBLAS   Invoke the CuBLAS TPL interface (Currently UNSUPPORTED)
                     MAGMA    Invoke the Magma TPL interface  (Currently UNSUPPORTED)
                   Note: Requires that input views for A, B, and C reside on either host
                         or device depending on the TPL selected.
                   Note: If the user selects a TPL, an error will be thrown if:
                           1. The TPL is not enabled via cmake
                           2. The input views do not reside on the host/device as needed

                   Specifies which kokkos-kernels (KK) algorithm to invoke:
                     KK_SERIAL       Invoke SerialGemm     via RangePolicy(BatchSz)
                     KK_TEAM         Invoke TeamGemm       via TeamPolicy(BatchSz)
                     KK_TEAMVECTOR   Invoke TeamVectorGemm via TeamPolicy(BatchSz)
                     KK_SERIALSIMD   Invoke SerialGemm     via TeamPolicy(BatchSz)
                     KK_TEAMSIMD     Invoke TeamGemm       via TeamPolicy(BatchSz)
                     KK_SERIAL_RANK0 Invoke SerialGemm     via RangePolicy(BatchSz*N*M)
                                     Each thread computes one element of C.
                     KK_SERIAL_SHMEM Invoke SerialGemm     via TeamPolicy(BatchSz)
                                     Copies A and B to shared memory before GEMM.
                                     Each vector lane solves one element of C via SerialGemm.
                     KK_DBLBUF       Solve GEMM            via TeamPolicy(BatchSz*TILES)
                                     Uses a tuned functor with tiling and double buffering
                                     via shared memory and register buffers.
                                     KK_DBLBUF generally performs better on GPUs when M, N >= 24.
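The launch shapes listed above can be made concrete with a little index arithmetic. The sketch below shows how a flat RangePolicy(BatchSz*N*M) index, as used by KK_SERIAL_RANK0, could be split into (batch, row, column) so that each thread produces one element of C, and how a league size of the form BatchSz*TILES could be computed for a tiled launch like KK_DBLBUF. All names are hypothetical. The index ordering (batch outermost, column fastest), the row-major batch-outermost storage, and the tile sizes are assumptions for illustration, not the library's actual choices.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical (batch, row, col) coordinates of one element of C.
struct ElementIndex {
  std::size_t batch, row, col;
};

// Splits a flat index in [0, batch_size * M * N) into (batch, row, col),
// assuming batch outermost and column fastest.
ElementIndex decompose_rank0(std::size_t flat, std::size_t batch_size,
                             std::size_t M, std::size_t N) {
  assert(flat < batch_size * M * N);
  ElementIndex idx;
  idx.col = flat % N;
  idx.row = (flat / N) % M;
  idx.batch = flat / (M * N);
  return idx;
}

// One thread's work: the dot product that yields C_b(row, col) before the
// alpha/beta update. A_b is M x K and B_b is K x N, stored row-major with
// the batch index outermost.
double rank0_element(const double* A, const double* B, ElementIndex idx,
                     std::size_t M, std::size_t N, std::size_t K) {
  const double* Ab = A + idx.batch * M * K;
  const double* Bb = B + idx.batch * K * N;
  double acc = 0.0;
  for (std::size_t k = 0; k < K; ++k)
    acc += Ab[idx.row * K + k] * Bb[k * N + idx.col];
  return acc;
}

// League size for a tiled launch of the form TeamPolicy(BatchSz * TILES),
// assuming C is covered by tile_m x tile_n tiles (partial tiles included).
std::size_t tiled_league_size(std::size_t batch_size, std::size_t M,
                              std::size_t N, std::size_t tile_m,
                              std::size_t tile_n) {
  assert(tile_m > 0 && tile_n > 0);
  const std::size_t tiles =
      ((M + tile_m - 1) / tile_m) * ((N + tile_n - 1) / tile_n);
  return batch_size * tiles;
}
```

With one thread per element, KK_SERIAL_RANK0 exposes BatchSz*M*N units of parallelism rather than BatchSz. This can help when the batch is small relative to the hardware.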

Param teamSz:

Specifies the team size used by any KK algorithm that launches with TeamPolicy (default, Kokkos::AUTO). Note: Only applied if kernelAlgoType is one of the KK_* algorithms.

Param vecLen:

Specifies the vector length used by any KK algorithm that launches with TeamPolicy and Kokkos::ThreadVectorRange or Kokkos::TeamVectorRange (default, Kokkos::AUTO). Note: Only applied if kernelAlgoType is one of the KK_* algorithms.