Skip to content

Instantly share code, notes, and snippets.

@iacolippo
Created October 19, 2018 12:44
Show Gist options
  • Save iacolippo/c278680717fb4622f958e37caa721fe3 to your computer and use it in GitHub Desktop.
Save iacolippo/c278680717fb4622f958e37caa721fe3 to your computer and use it in GitHub Desktop.
RuntimeError: variable impl does not have is_contiguous Pytorch C++ extension
#include <torch/extension.h>
#include <cmath>
#include <iostream>
#include <vector>
at::Tensor ex_forward(
at::Tensor input
) {
auto n_samples = input.size(0);
auto n_features = input.size(1);
auto G = n_features / 2;
auto M = 2;
at::Tensor temp = at::zeros({n_samples, G, 2});
at::Tensor slice1 = input.slice(1, 0, n_features, 2) + input.slice(1, 1, n_features, 2);
at::Tensor slice2 = input.slice(1, 0, n_features, 2) - input.slice(1, 1, n_features, 2);
temp = at::stack({slice1, slice2}, 2);
auto res = temp;
for (auto dumb_idx = 0; dumb_idx < std::log2(n_features) + 1; dumb_idx++) {
temp = at::zeros({n_samples, G / 2, M * 2});
slice1 = res.slice(2, 0, M, 2).slice(1, 0, G, 2);
slice2 = res.slice(2, 0, M, 2).slice(1, 1, G, 2);
auto mesh1 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(0, 2 * M, 4))});
auto mesh2 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(1, 2 * M, 4))});
temp.index_put_(mesh1, slice1 + slice2);
temp.index_put_(mesh2, slice1 - slice2);
slice1 = res.slice(2, 1, M, 2).slice(1, 0, G, 2);
slice2 = res.slice(2, 1, M, 2).slice(1, 1, G, 2);
mesh1 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(2, 2 * M, 4))});
mesh2 = at::meshgrid({at::_cast_Long(at::arange(0, n_samples, 1)), at::_cast_Long(at::arange(0, G/2, 1)), at::_cast_Long(at::arange(3, 2 * M, 4))});
temp.index_put_(mesh1, slice1 - slice2);
temp.index_put_(mesh2, slice1 + slice2);
res = temp;
G = G / 2;
M = M * 2;
}
at::Tensor output = temp.select(1, 0); // select index 0 along dim 1
return output * pow(std::sqrt(n_features), -1);
}
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
m.def("forward", &ex_forward, "EX forward");
}
from setuptools import setup
from torch.utils.cpp_extension import CppExtension, BuildExtension
setup(name='ex',
ext_modules=[CppExtension('ex', ['ex.cpp'], extra_compile_args=['-g', '-O0'])],
cmdclass={'build_ext': BuildExtension}
)
import torch
import ex
class EXATen(torch.autograd.Function):
@staticmethod
def forward(ctx, input):
outputs = ex.forward(input)
return outputs
x = torch.randn(3, 8)
y1 = EXATen.apply(x)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment