Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add conv_transpose_1d_gemm #940

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions include/ggml.h
Original file line number Diff line number Diff line change
Expand Up @@ -493,6 +493,7 @@ extern "C" {
GGML_OP_CONV_TRANSPOSE_1D,
GGML_OP_IM2COL,
GGML_OP_IM2COL_BACK,
GGML_OP_COL2IM,
GGML_OP_CONV_TRANSPOSE_2D,
GGML_OP_POOL_1D,
GGML_OP_POOL_2D,
Expand Down Expand Up @@ -1614,6 +1615,18 @@ extern "C" {
int d1, // dilation dimension 1
bool is_2D);

GGML_API struct ggml_tensor * ggml_col2im(
struct ggml_context * ctx,
struct ggml_tensor * a,
int s0,
int s1,
int p0,
int p1,
int d0,
int d1,
int64_t KH,
int64_t IH);

GGML_API struct ggml_tensor * ggml_conv_depthwise_2d(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
Expand Down Expand Up @@ -1650,6 +1663,14 @@ extern "C" {
int p0, // padding
int d0); // dilation

GGML_API struct ggml_tensor * ggml_conv_transpose_1d_gemm(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
struct ggml_tensor * b, // data
int s0, // stride
int p0, // padding
int d0); // dilation

GGML_API struct ggml_tensor * ggml_conv_2d(
struct ggml_context * ctx,
struct ggml_tensor * a, // convolution kernel
Expand Down
5 changes: 5 additions & 0 deletions src/ggml-cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
#include "ggml-cuda/argsort.cuh"
#include "ggml-cuda/binbcast.cuh"
#include "ggml-cuda/clamp.cuh"
#include "ggml-cuda/col2im.cuh"
#include "ggml-cuda/concat.cuh"
#include "ggml-cuda/conv-transpose-1d.cuh"
#include "ggml-cuda/convert.cuh"
Expand Down Expand Up @@ -2306,6 +2307,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg
case GGML_OP_CONV_TRANSPOSE_1D:
ggml_cuda_op_conv_transpose_1d(ctx,dst);
break;
case GGML_OP_COL2IM:
ggml_cuda_op_col2im(ctx, dst);
break;
case GGML_OP_POOL_2D:
ggml_cuda_op_pool2d(ctx, dst);
break;
Expand Down Expand Up @@ -2869,6 +2873,7 @@ GGML_CALL static bool ggml_backend_cuda_supports_op(ggml_backend_t backend, cons
}
return false;
} break;
case GGML_OP_COL2IM:
case GGML_OP_NONE:
case GGML_OP_RESHAPE:
case GGML_OP_VIEW:
Expand Down
75 changes: 75 additions & 0 deletions src/ggml-cuda/col2im.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
#include "col2im.cuh"

static __global__ void col2im_kernel(
const float * src, float* dst,
const int64_t IW, const int64_t KW, const int64_t OC, const int64_t N, int64_t OW,
const int64_t ioc_offs, const int64_t ikw_offs, const int64_t in_offs,
const int32_t s0, const int32_t p0, const int32_t d0) {
const auto batch_size = OC * OW;
for (int64_t i = blockIdx.x * blockDim.x + threadIdx.x;
i < batch_size * N;
i += blockDim.x * gridDim.x) {
const auto in = i / batch_size;
const auto iow = i % batch_size; // = ioc * OW + ikw*d0 - p0 + iiws*s0
const auto kwiw = iow % OW + p0; // = ikw*d0 + iiw*s0
const auto ioc = iow / OW;
const auto max_kernel = (KW - 1) * d0;
// iter iiw only over values that have
// a chance of being valid
// i.e. values that will satisfy:
// 0 < ikw*d0 - p0 + iiw*s0 < OW
const auto iiws = ::max(0L, (kwiw - max_kernel + s0 - 1) / s0);
const auto iiwe = ::min(IW, kwiw / s0 + 1);

const float *const input = src + in * in_offs;
float val = 0;
for (auto iiw = iiws; iiw < iiwe; ++iiw) {
const auto ikw_d = (kwiw - iiw * s0);
if (ikw_d % d0 == 0) {
const auto ikw = ikw_d / d0;
const auto input_index = ioc * ioc_offs + ikw * ikw_offs + iiw;
val += input[input_index];
}
}
dst[in * OC * OW + iow] = val;
}
}

void ggml_cuda_op_col2im(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
const ggml_tensor * src0 = dst->src[0];
const auto * src0_d = (const float *)src0->data;
auto * dst_d = (float *)dst->data;
auto stream = ctx.stream();

GGML_ASSERT(src0->type == GGML_TYPE_F32);
assert(ggml_is_contiguous(dst));
assert(dst->type == GGML_TYPE_F32);

GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));

const auto s0 = dst->op_params[0];
const auto p0 = dst->op_params[2];
const auto d0 = dst->op_params[4];

assert(s0 >= 1);
assert(p0 >= 0);
assert(d0 >= 1);

const auto IW = src0->ne[0];
const auto KW = src0->ne[1];
const auto OC = src0->ne[2];
const auto N = src0->ne[3];
const auto OW = dst->ne[0];

const auto ioc_offs = src0->nb[2] / src0->nb[0];
const auto ikw_offs = src0->nb[1] / src0->nb[0];
const auto in_offs = src0->nb[3] / src0->nb[0];

const int parallel_elements = N * OC * OW;
const int num_blocks = (parallel_elements + CUDA_COL2IM_BLOCK_SIZE - 1) / CUDA_COL2IM_BLOCK_SIZE;

col2im_kernel<<<num_blocks, CUDA_COL2IM_BLOCK_SIZE, 0, stream>>>(src0_d, dst_d,
IW, KW, OC, N, OW,
ioc_offs, ikw_offs, in_offs,
s0, p0, d0);
}
5 changes: 5 additions & 0 deletions src/ggml-cuda/col2im.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
#include "common.cuh"

#define CUDA_COL2IM_BLOCK_SIZE 256

void ggml_cuda_op_col2im(ggml_backend_cuda_context & ctx, ggml_tensor * dst);
11 changes: 11 additions & 0 deletions src/ggml-sycl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3845,6 +3845,13 @@ static void ggml_sycl_im2col(ggml_backend_sycl_context & ctx, const ggml_tensor
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_im2col);
}

static void ggml_sycl_col2im(ggml_backend_sycl_context &ctx,
const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst) {
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_col2im);
}

static void ggml_sycl_sum_rows(ggml_backend_sycl_context & ctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst) {
GGML_ASSERT(ggml_is_contiguous(src0));
ggml_sycl_op_flatten(ctx, src0, src1, dst, ggml_sycl_op_sum_rows);
Expand Down Expand Up @@ -4010,6 +4017,9 @@ bool ggml_sycl_compute_forward(ggml_backend_sycl_context & ctx, struct ggml_tens
case GGML_OP_IM2COL:
func = ggml_sycl_im2col;
break;
case GGML_OP_COL2IM:
func = ggml_sycl_col2im;
break;
case GGML_OP_POOL_2D:
func = ggml_sycl_pool2d;
break;
Expand Down Expand Up @@ -5131,6 +5141,7 @@ GGML_CALL static bool ggml_backend_sycl_supports_op(ggml_backend_t backend, cons
case GGML_OP_ROPE:
return ggml_is_contiguous(op->src[0]);
case GGML_OP_IM2COL:
case GGML_OP_COL2IM:
case GGML_OP_POOL_2D:
case GGML_OP_SUM_ROWS:
case GGML_OP_ARGSORT:
Expand Down
1 change: 1 addition & 0 deletions src/ggml-sycl/backend.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,6 @@
#include "softmax.hpp"
#include "tsembd.hpp"
#include "im2col.hpp"
#include "col2im.hpp"

#endif // GGML_SYCL_BACKEND_HPP
112 changes: 112 additions & 0 deletions src/ggml-sycl/col2im.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
#include "col2im.hpp"

#include "common.hpp"

static void col2im_kernel(const float *src,
float *dst,
const int64_t IW,
const int64_t KW,
const int64_t OC,
const int64_t N,
int64_t OW,
const int64_t ioc_offs,
const int64_t ikw_offs,
const int64_t in_offs,
const int32_t s0,
const int32_t p0,
const int32_t d0,
const sycl::nd_item<1> &item) {
const int64_t global_id = item.get_global_linear_id();
const int64_t batch_size = OC * OW;
const int64_t osize = N * batch_size;

for (int64_t i = global_id; i < osize; i += batch_size) {
const auto in = i / batch_size;
const auto iow = i % batch_size; // = ioc * OW + ikw*d0 - p0 + iiws*s0
const auto kwiw = iow % OW + p0; // = ikw*d0 + iiw*s0
const auto ioc = iow / OW;
const auto max_kernel = (KW - 1) * d0;
// iter iiw only over values that have
// a chance of being valid
// i.e. values that will satisfy:
// 0 < ikw*d0 - p0 + iiw*s0 < OW
const auto iiws = std::max(0L, (kwiw - max_kernel + s0 - 1) / s0);
const auto iiwe = std::min(IW, kwiw / s0 + 1);

const float *const input = src + in * in_offs;
float val = 0;
for (auto iiw = iiws; iiw < iiwe; ++iiw) {
const auto ikw_d = (kwiw - iiw * s0);
if (ikw_d % d0 == 0) {
const auto ikw = ikw_d / d0;
const auto input_index = ioc * ioc_offs + ikw * ikw_offs + iiw;
val += input[input_index];
}
}
dst[in * OC * OW + iow] = val;
}
}

static void col2im_sycl(const float *src,
float *dst,
const int64_t IW,
const int64_t KW,
const int64_t OC,
const int64_t N,
int64_t OW,
const int64_t ioc_offs,
const int64_t ikw_offs,
const int64_t in_offs,
const int32_t s0,
const int32_t p0,
const int32_t d0,
queue_ptr stream) {
const int64_t batch_size = OC * OW;
const size_t num_blocks =
(batch_size + SYCL_COL2IM_BLOCK_SIZE - 1) / SYCL_COL2IM_BLOCK_SIZE;
stream->parallel_for(
sycl::nd_range<1>(num_blocks * SYCL_COL2IM_BLOCK_SIZE, SYCL_COL2IM_BLOCK_SIZE),
[=](sycl::nd_item<1> item) {
col2im_kernel(
src, dst, IW, KW, OC, N, OW, ioc_offs, ikw_offs, in_offs, s0, p0, d0, item);
});
}

void ggml_sycl_op_col2im(ggml_backend_sycl_context &,
const ggml_tensor *src0,
const ggml_tensor *,
ggml_tensor *dst,
const float *src0_dd,
const float *,
float *dst_dd,
const queue_ptr &main_stream) {
GGML_ASSERT(src0->type == GGML_TYPE_F32);
assert(ggml_is_contiguous(dst));
assert(dst->type == GGML_TYPE_F32);

GGML_ASSERT(src0->nb[0] == ggml_type_size(src0->type));

const auto s0 = dst->op_params[0];
const auto p0 = dst->op_params[2];
const auto d0 = dst->op_params[4];

assert(s0 >= 1);
assert(p0 >= 0);
assert(d0 >= 1);

const auto IW = src0->ne[0];
const auto KW = src0->ne[1];
const auto OC = src0->ne[2];
const auto N = src0->ne[3];
const auto OW = dst->ne[0];

const auto ioc_offs = src0->nb[2] / src0->nb[0];
const auto ikw_offs = src0->nb[1] / src0->nb[0];
const auto in_offs = src0->nb[3] / src0->nb[0];

col2im_sycl(src0_dd, dst_dd,
IW, KW, OC, N, OW,
ioc_offs, ikw_offs, in_offs,
s0, p0, d0,
main_stream);
}
15 changes: 15 additions & 0 deletions src/ggml-sycl/col2im.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#ifndef GGML_SYCL_COl2IM_HPP
#define GGML_SYCL_COl2IM_HPP

#include "common.hpp"

void ggml_sycl_op_col2im(ggml_backend_sycl_context &ctx,
const ggml_tensor *src0,
const ggml_tensor *src1,
ggml_tensor *dst,
const float *src0_dd,
const float *src1_dd,
float *dst_dd,
const queue_ptr &main_stream);

#endif // GGML_SYCL_COl2IM_HPP
1 change: 1 addition & 0 deletions src/ggml-sycl/presets.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#define SYCL_PAD_BLOCK_SIZE 256
#define SYCL_ACC_BLOCK_SIZE 256
#define SYCL_IM2COL_BLOCK_SIZE 256
#define SYCL_COL2IM_BLOCK_SIZE 256
#define SYCL_POOL2D_BLOCK_SIZE 256
#define SYCL_CONV_TRANPOSE_1D_BLOCK_SIZE 256
#define SYCL_TIMESTEP_EMBEDDING_BLOCK_SIZE 256
Expand Down
Loading
Loading