Skip to content

Instantly share code, notes, and snippets.

@ResidentMario
Created January 31, 2021 01:08
Show Gist options
  • Save ResidentMario/1a4f6473828048990e26d12d58d7a227 to your computer and use it in GitHub Desktop.
Save ResidentMario/1a4f6473828048990e26d12d58d7a227 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Derived from https://discuss.pytorch.org/t/custom-convolution-layer/45979/5."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class Conv2d(nn.Module):\n",
"    \"\"\"2-D convolution computed explicitly from ``F.unfold`` windows.\n",
"\n",
"    Every (input channel, output channel) pair has its own vector of\n",
"    ``kernel_size * kernel_size`` weights; the per-pair responses are summed\n",
"    into the matching output channel. There is no bias term.\n",
"\n",
"    Args:\n",
"        n_channels: number of channels of the input tensor.\n",
"        out_channels: number of channels of the output tensor.\n",
"        kernel_size: side length of the square kernel.\n",
"        dilation: spacing between kernel elements, used for both axes.\n",
"        padding: padding passed to ``F.unfold`` for both axes.\n",
"        stride: step of the sliding window, used for both axes.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self, n_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1\n",
"    ):\n",
"        super().__init__()\n",
"\n",
"        self.kernel_size = kernel_size\n",
"        self.kernel_size_number = kernel_size * kernel_size\n",
"        self.out_channels = out_channels\n",
"        self.padding = padding\n",
"        self.dilation = dilation\n",
"        self.stride = stride\n",
"        self.n_channels = n_channels\n",
"\n",
"        # torch.Tensor(...) / torch.empty(...) only allocate memory, so the\n",
"        # values must be initialised explicitly or the layer starts from\n",
"        # arbitrary (possibly NaN/inf) weights. a=sqrt(5) follows the scheme\n",
"        # used by torch.nn.Conv2d.\n",
"        weights = torch.empty(\n",
"            self.out_channels, self.n_channels, self.kernel_size_number\n",
"        )\n",
"        nn.init.kaiming_uniform_(weights, a=5 ** 0.5)\n",
"        self.weights = nn.Parameter(weights)\n",
"\n",
"    def __repr__(self):\n",
"        return (\n",
"            f\"Conv2d(n_channels={self.n_channels}, out_channels={self.out_channels}, \"\n",
"            f\"kernel_size={self.kernel_size})\"\n",
"        )\n",
"\n",
"    def forward(self, x):\n",
"        \"\"\"Convolve ``x`` of shape (batch, n_channels, dim2, dim3).\n",
"\n",
"        Returns a float32 tensor of shape\n",
"        (batch, out_channels, new_dim2, new_dim3).\n",
"        \"\"\"\n",
"        width = self.calculate_new_width(x)\n",
"        height = self.calculate_new_height(x)\n",
"        windows = self.calculate_windows(x)\n",
"\n",
"        # Accumulate as (out_channels, batch, ...): each matmul below yields one\n",
"        # output channel for the whole batch. Filling a flat buffer in this\n",
"        # order and viewing it as (batch, out_channels, ...) would interleave\n",
"        # samples and channels for batch sizes above one.\n",
"        result = torch.zeros(\n",
"            [self.out_channels, x.shape[0], width, height],\n",
"            dtype=torch.float32, device=x.device\n",
"        )\n",
"\n",
"        for channel in range(x.shape[1]):\n",
"            for i_conv_n in range(self.out_channels):\n",
"                xx = torch.matmul(windows[channel], self.weights[i_conv_n][channel])\n",
"                result[i_conv_n] += xx.view((-1, width, height))\n",
"\n",
"        # Move the batch axis back to the front.\n",
"        return result.transpose(0, 1).contiguous()\n",
"\n",
"    def calculate_windows(self, x):\n",
"        \"\"\"Return the sliding windows of ``x`` grouped by input channel.\n",
"\n",
"        The result has shape\n",
"        (n_channels, batch * n_positions, kernel_size * kernel_size).\n",
"        \"\"\"\n",
"        windows = F.unfold(\n",
"            x,\n",
"            kernel_size=(self.kernel_size, self.kernel_size),\n",
"            padding=(self.padding, self.padding),\n",
"            dilation=(self.dilation, self.dilation),\n",
"            stride=(self.stride, self.stride)\n",
"        )\n",
"\n",
"        # unfold yields (batch, channels * k * k, positions): bring positions\n",
"        # forward, split out the channel axis, then index by channel first.\n",
"        windows = (windows\n",
"            .transpose(1, 2)\n",
"            .contiguous().view(-1, x.shape[1], self.kernel_size_number)\n",
"            .transpose(0, 1)\n",
"        )\n",
"        return windows\n",
"\n",
"    def calculate_new_width(self, x):\n",
"        \"\"\"Output extent along ``x.shape[2]`` (the first spatial axis).\"\"\"\n",
"        return (\n",
"            (x.shape[2] + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)\n",
"            // self.stride\n",
"        ) + 1\n",
"\n",
"    def calculate_new_height(self, x):\n",
"        \"\"\"Output extent along ``x.shape[3]`` (the second spatial axis).\"\"\"\n",
"        return (\n",
"            (x.shape[3] + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)\n",
"            // self.stride\n",
"        ) + 1"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randint(0, 255, (1, 3, 512, 512), device='cuda') / 255"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Conv2d(n_channels=3, out_channels=16, kernel_size=3)"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conv = Conv2d(3, 16, 3)\n",
"conv.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 32.5 ms, sys: 4.2 ms, total: 36.7 ms\n",
"Wall time: 35.5 ms\n"
]
}
],
"source": [
"%%time\n",
"out = conv(x)\n",
"out.mean().backward()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## JIT version"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.jit as jit\n",
"\n",
"class Conv2d(jit.ScriptModule):\n",
"    \"\"\"TorchScript variant of a 2-D convolution built on ``F.unfold``.\n",
"\n",
"    Every (input channel, output channel) pair has its own vector of\n",
"    ``kernel_size * kernel_size`` weights; the per-pair responses are summed\n",
"    into the matching output channel. There is no bias term.\n",
"\n",
"    Args:\n",
"        n_channels: number of channels of the input tensor.\n",
"        out_channels: number of channels of the output tensor.\n",
"        kernel_size: side length of the square kernel.\n",
"        dilation: spacing between kernel elements, used for both axes.\n",
"        padding: padding passed to ``F.unfold`` for both axes.\n",
"        stride: step of the sliding window, used for both axes.\n",
"    \"\"\"\n",
"\n",
"    def __init__(\n",
"        self, n_channels, out_channels, kernel_size, dilation=1, padding=0, stride=1\n",
"    ):\n",
"        super().__init__()\n",
"\n",
"        self.kernel_size = kernel_size\n",
"        self.kernel_size_number = kernel_size * kernel_size\n",
"        self.out_channels = out_channels\n",
"        self.padding = padding\n",
"        self.dilation = dilation\n",
"        self.stride = stride\n",
"        self.n_channels = n_channels\n",
"\n",
"        # torch.Tensor(...) / torch.empty(...) only allocate memory, so the\n",
"        # values must be initialised explicitly or the layer starts from\n",
"        # arbitrary (possibly NaN/inf) weights. a=sqrt(5) follows the scheme\n",
"        # used by torch.nn.Conv2d.\n",
"        weights = torch.empty(\n",
"            self.out_channels, self.n_channels, self.kernel_size_number\n",
"        )\n",
"        nn.init.kaiming_uniform_(weights, a=5 ** 0.5)\n",
"        self.weights = nn.Parameter(weights)\n",
"\n",
"    def __repr__(self):\n",
"        return (\n",
"            f\"Conv2d(n_channels={self.n_channels}, out_channels={self.out_channels}, \"\n",
"            f\"kernel_size={self.kernel_size})\"\n",
"        )\n",
"\n",
"    @jit.script_method\n",
"    def forward(self, x):\n",
"        \"\"\"Convolve ``x`` of shape (batch, n_channels, dim2, dim3).\n",
"\n",
"        Returns a float32 tensor of shape\n",
"        (batch, out_channels, new_dim2, new_dim3).\n",
"        \"\"\"\n",
"        width = self.calculate_new_width(x)\n",
"        height = self.calculate_new_height(x)\n",
"        windows = self.calculate_windows(x)\n",
"\n",
"        # Accumulate as (out_channels, batch, ...): each matmul below yields one\n",
"        # output channel for the whole batch. Filling a flat buffer in this\n",
"        # order and viewing it as (batch, out_channels, ...) would interleave\n",
"        # samples and channels for batch sizes above one.\n",
"        result = torch.zeros(\n",
"            [self.out_channels, x.shape[0], width, height],\n",
"            dtype=torch.float32, device=x.device\n",
"        )\n",
"\n",
"        for channel in range(x.shape[1]):\n",
"            for i_conv_n in range(self.out_channels):\n",
"                xx = torch.matmul(windows[channel], self.weights[i_conv_n][channel])\n",
"                result[i_conv_n] += xx.view((-1, width, height))\n",
"\n",
"        # Move the batch axis back to the front.\n",
"        return result.transpose(0, 1).contiguous()\n",
"\n",
"    def calculate_windows(self, x):\n",
"        \"\"\"Return the sliding windows of ``x`` grouped by input channel.\n",
"\n",
"        The result has shape\n",
"        (n_channels, batch * n_positions, kernel_size * kernel_size).\n",
"        \"\"\"\n",
"        windows = F.unfold(\n",
"            x,\n",
"            kernel_size=(self.kernel_size, self.kernel_size),\n",
"            padding=(self.padding, self.padding),\n",
"            dilation=(self.dilation, self.dilation),\n",
"            stride=(self.stride, self.stride)\n",
"        )\n",
"\n",
"        # unfold yields (batch, channels * k * k, positions): bring positions\n",
"        # forward, split out the channel axis, then index by channel first.\n",
"        windows = (windows\n",
"            .transpose(1, 2)\n",
"            .contiguous().view(-1, x.shape[1], self.kernel_size_number)\n",
"            .transpose(0, 1)\n",
"        )\n",
"        return windows\n",
"\n",
"    def calculate_new_width(self, x):\n",
"        \"\"\"Output extent along ``x.shape[2]`` (the first spatial axis).\"\"\"\n",
"        return (\n",
"            (x.shape[2] + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)\n",
"            // self.stride\n",
"        ) + 1\n",
"\n",
"    def calculate_new_height(self, x):\n",
"        \"\"\"Output extent along ``x.shape[3]`` (the second spatial axis).\"\"\"\n",
"        return (\n",
"            (x.shape[3] + 2 * self.padding - self.dilation * (self.kernel_size - 1) - 1)\n",
"            // self.stride\n",
"        ) + 1"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"x = torch.randint(0, 255, (1, 3, 512, 512), device='cuda') / 255"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Conv2d(n_channels=3, out_channels=16, kernel_size=3)"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"conv = Conv2d(3, 16, 3)\n",
"conv.cuda()"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 13.8 ms, sys: 4.79 ms, total: 18.6 ms\n",
"Wall time: 17.4 ms\n"
]
}
],
"source": [
"%%time\n",
"out = conv(x)\n",
"out.mean().backward()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.9"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment