Skip to content

Instantly share code, notes, and snippets.

@yueyericardo
Last active August 16, 2023 16:08
Show Gist options
  • Save yueyericardo/0d89a3a74c874c68a5a8729891a459a8 to your computer and use it in GitHub Desktop.
====================================================================== erf ======================================================================
/opt/conda/lib/python3.8/site-packages/torch/profiler/profiler.py:395: UserWarning: use_cuda is deprecated, use activities argument instead
warn("use_cuda is deprecated, use activities argument instead")
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 137.861ms 37.50% 137.861ms 172.326us 800
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 76.214ms 20.73% 76.214ms 127.023us 600
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 50.280ms 13.68% 50.280ms 125.700us 400
Memcpy DtoD (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 27.599ms 7.51% 27.599ms 137.995us 200
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 25.417ms 6.91% 25.417ms 127.085us 200
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 25.224ms 6.86% 25.224ms 126.120us 200
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 25.060ms 6.82% 25.060ms 125.300us 200
cudaLaunchKernel 50.85% 164.824ms 50.85% 164.824ms 68.677us 0.000us 0.00% 0.000us 0.000us 2400
cudaMemcpyAsync 4.05% 13.143ms 4.05% 13.143ms 65.715us 0.000us 0.00% 0.000us 0.000us 200
cudaDeviceSynchronize 45.09% 146.163ms 45.09% 146.163ms 146.163ms 0.000us 0.00% 0.000us 0.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 324.130ms
Self CUDA time total: 367.655ms
erf : 3.567 ms/step
====================================================================== erf_scripted ======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
CudaCodeGen::kernel3(CudaCodeGen::Tensor<float, 2>, ... 0.00% 0.000us 0.00% 0.000us 0.000us 91.257ms 49.40% 91.257ms 456.285us 200
CudaCodeGen::kernel2(CudaCodeGen::Tensor<float, 2>, ... 0.00% 0.000us 0.00% 0.000us 0.000us 66.889ms 36.21% 66.889ms 334.445us 200
CudaCodeGen::kernel1(CudaCodeGen::Tensor<float, 2>, ... 0.00% 0.000us 0.00% 0.000us 0.000us 26.575ms 14.39% 26.575ms 132.875us 200
cudaDeviceSynchronize 100.00% 124.957ms 100.00% 124.957ms 124.957ms 0.000us 0.00% 0.000us 0.000us 1
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 124.957ms
Self CUDA time total: 184.721ms
erf_scripted : 1.056 ms/step
====================================================================== linear_erf ======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 634.280ms 20.54% 634.280ms 186.553us 3400
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_2... 0.00% 0.000us 0.00% 0.000us 0.000us 623.373ms 20.19% 623.373ms 311.687us 2000
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_1... 0.00% 0.000us 0.00% 0.000us 0.000us 364.082ms 11.79% 364.082ms 260.059us 1400
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_1... 0.00% 0.000us 0.00% 0.000us 0.000us 359.709ms 11.65% 359.709ms 256.935us 1400
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 276.403ms 8.95% 276.403ms 125.638us 2200
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 202.094ms 6.54% 202.094ms 126.309us 1600
void at::native::reduce_kernel<128, 4, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 157.646ms 5.10% 157.646ms 112.604us 1400
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 124.766ms 4.04% 124.766ms 124.766us 1000
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 111.761ms 3.62% 111.761ms 111.761us 1000
Memcpy DtoD (Device -> Device) 0.00% 0.000us 0.00% 0.000us 0.000us 82.061ms 2.66% 82.061ms 136.768us 600
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 74.751ms 2.42% 74.751ms 124.585us 600
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 73.773ms 2.39% 73.773ms 184.433us 400
Memset (Device) 0.00% 0.000us 0.00% 0.000us 0.000us 3.427ms 0.11% 3.427ms 1.008us 3400
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla... 0.04% 1.019ms 0.04% 1.019ms 0.106us 0.000us 0.00% 0.000us 0.000us 9600
cudaFuncSetAttribute 0.02% 558.000us 0.02% 558.000us 0.116us 0.000us 0.00% 0.000us 0.000us 4800
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.763s
Self CUDA time total: 3.088s
linear_erf : 29.099 ms/step
====================================================================== linear_erf_scripted ======================================================================
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Name Self CPU % Self CPU CPU total % CPU total CPU time avg Self CUDA Self CUDA % CUDA total CUDA time avg # of Calls
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 930.431ms 28.38% 930.431ms 186.086us 5000
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_2... 0.00% 0.000us 0.00% 0.000us 0.000us 667.446ms 20.36% 667.446ms 303.385us 2200
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_1... 0.00% 0.000us 0.00% 0.000us 0.000us 363.989ms 11.10% 363.989ms 259.992us 1400
void cutlass::Kernel<cutlass_80_tensorop_s1688gemm_1... 0.00% 0.000us 0.00% 0.000us 0.000us 360.191ms 10.99% 360.191ms 257.279us 1400
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 200.672ms 6.12% 200.672ms 125.420us 1600
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 199.777ms 6.09% 199.777ms 124.861us 1600
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 186.602ms 5.69% 186.602ms 186.602us 1000
void at::native::reduce_kernel<128, 4, at::native::R... 0.00% 0.000us 0.00% 0.000us 0.000us 167.304ms 5.10% 167.304ms 119.503us 1400
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 124.144ms 3.79% 124.144ms 124.144us 1000
CudaCodeGen::kernel1(CudaCodeGen::Tensor<float, 2>, ... 0.00% 0.000us 0.00% 0.000us 0.000us 70.703ms 2.16% 70.703ms 117.838us 600
Memset (Device) 0.00% 0.000us 0.00% 0.000us 0.000us 3.628ms 0.11% 3.628ms 1.008us 3600
CudaCodeGen::kernel4(CudaCodeGen::Tensor<float, 2>, ... 0.00% 0.000us 0.00% 0.000us 0.000us 2.871ms 0.09% 2.871ms 4.785us 600
void at::native::vectorized_elementwise_kernel<4, at... 0.00% 0.000us 0.00% 0.000us 0.000us 600.000us 0.02% 600.000us 3.000us 200
cudaOccupancyMaxActiveBlocksPerMultiprocessorWithFla... 0.06% 1.488ms 0.06% 1.488ms 0.149us 0.000us 0.00% 0.000us 0.000us 10000
cudaFuncSetAttribute 0.02% 435.000us 0.02% 435.000us 0.087us 0.000us 0.00% 0.000us 0.000us 5000
------------------------------------------------------- ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------ ------------
Self CPU time total: 2.685s
Self CUDA time total: 3.278s
linear_erf_scripted : 31.473 ms/step
import argparse
import time
import torch
from torch.profiler import record_function, ProfilerActivity
# Enable nvFuser for TorchScript. NOTE(review): torch._C is a private,
# version-dependent API — these knobs may disappear in newer PyTorch releases.
# Fuse even single nodes (not just multi-op groups) with nvFuser.
torch._C._jit_set_nvfuser_single_node_mode(True)
# Keep autodiff subgraphs as separate graphs so they stay visible/fusable.
torch._C._debug_set_autodiff_subgraph_inlining(False)
# Turn nvFuser on as the TorchScript fusion backend.
torch._C._jit_set_nvfuser_enabled(True)
# -----------------------------------------------------------------------
# benchmark utils
def timeit(func, *args, steps=200, warmup=10, show_profile=False, label=None, label_padding=35):
    """Benchmark ``func(*args)`` and print the average wall time per step.

    Args:
        func: callable to benchmark.
        *args: positional arguments forwarded to ``func`` on every call.
        steps: number of timed iterations.
        warmup: untimed iterations run first (JIT compilation, cudnn autotune, ...).
        show_profile: when True, also print a torch profiler kernel table.
        label: name printed with the result; defaults to ``func``'s name.
        label_padding: column width used to left-align the printed label.
    """
    if label is None:
        # Original code crashed on the default label=None ("..." + None and
        # None.ljust); fall back to the callable's name instead.
        label = getattr(func, "__name__", repr(func))
    # warmup (not timed)
    for _ in range(warmup):
        func(*args)
    # drain any queued GPU work before starting the clock (CUDA is async);
    # guarded so the benchmark also runs on CPU-only machines
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    start = time.time()
    if show_profile:
        # profile the kernels launched during the timed loop
        print("\n" + "=" * 70 + " " + label + " " + "=" * 70)
        # use_cuda=True is deprecated (it emits the UserWarning visible in the
        # captured output above); request CUDA via `activities` instead.
        with torch.profiler.profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
            with record_function("run_total"):
                for _ in range(steps):
                    func(*args)
        print(prof.key_averages().table(sort_by="self_cuda_time_total", max_src_column_width=200, row_limit=15))
    else:
        # otherwise just run the benchmark loop
        for _ in range(steps):
            func(*args)
    # stop timer only after all GPU work has finished
    if torch.cuda.is_available():
        torch.cuda.synchronize()
    time_ms = ((time.time() - start) / steps) * 1000
    print(f"{label.ljust(label_padding)}: {time_ms:.3f} ms/step")
# -----------------------------------------------------------------------
# erf
class Erf(torch.nn.Module):
    """Elementwise error function wrapped as an nn.Module so it can be scripted."""

    def forward(self, x):
        # Tensor-method spelling of torch.erf(x); numerically identical.
        return x.erf()


erf = Erf()
erf_scripted = torch.jit.script(erf)
# -----------------------------------------------------------------------
# linear_erf
# this could be a workaround to make torchscript fuse again
class IdentityViewAs(torch.nn.Module):
    """No-op module: returns a same-shaped view of its input.

    Inserting this after a fuser-breaking op can let TorchScript fuse again.
    """

    def forward(self, x):
        # view_as keeps the same storage; no data is copied.
        same_shape_view = x.view_as(x)
        return same_shape_view
# Four Linear(512, 512) layers with Erf activations between them, on the GPU.
_WIDTH = 512
_layers = []
for _ in range(3):
    _layers.append(torch.nn.Linear(_WIDTH, _WIDTH))
    _layers.append(Erf())
_layers.append(torch.nn.Linear(_WIDTH, _WIDTH))
# _layers.append(IdentityViewAs())  # torchscript could fuse if add an ViewAs
linear_erf = torch.nn.Sequential(*_layers).to("cuda")
linear_erf_scripted = torch.jit.script(linear_erf)
# -----------------------------------------------------------------------
# some inputs
device = "cuda"
batch_size = 50000
# Input batch (50000, 512); requires_grad so derivatives w.r.t. x exist below.
x = torch.rand([batch_size, 512], device=device, requires_grad=True)
# All-ones cotangent for the vector-Jacobian products in run(); same shape as x.
I_N = torch.ones_like(x)
def run(func, *args):
    """One benchmark step: forward pass plus 1st- and 2nd-order gradients.

    NOTE(review): gradients are taken w.r.t. the module-level tensor ``x``
    (contracted with the all-ones cotangent ``I_N``), not w.r.t. ``*args`` —
    every benchmarked ``func`` is expected to consume ``x``; confirm.
    """
    out = func(*args)
    # First derivative d(out)/dx; keep the graph so we can differentiate again.
    (dy_dx,) = torch.autograd.grad(out, [x], I_N, create_graph=True)
    # Second derivative; the value is discarded — only its cost is measured.
    (d2y_dx2,) = torch.autograd.grad(dy_dx, [x], I_N, create_graph=True)
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-p", "--profile", default=False, action="store_true")
    cli_args = parser.parse_args()
    # Benchmark eager vs. TorchScript variants of both models.
    cases = (
        (erf, "erf"),
        (erf_scripted, "erf_scripted"),
        (linear_erf, "linear_erf"),
        (linear_erf_scripted, "linear_erf_scripted"),
    )
    for model, tag in cases:
        timeit(run, model, x, label=tag, show_profile=cli_args.profile)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment