@fdeheeger
Last active November 12, 2015 08:57
################################################################################
# Interpolation class
# TODO: use a nicer n-dim method (like multilinear interpolation)
import numpy as np
from scipy.interpolate import RectBivariateSpline
from scipy.interpolate import UnivariateSpline as _UnivariateSpline
from dolointerpolation.multilinear_cython import multilinear_interpolation

class UnivariateSpline(_UnivariateSpline):
    '''extended UnivariateSpline class,
    where spline evaluation accepts an input array of any shape
    and returns an output with the same shape.
    '''
    #@profile
    def __call__(self, x):
        # flatten the input after saving its shape:
        x = np.asarray(x)
        shape = x.shape
        # Evaluate the spline and reconstruct the dimension:
        z = super(UnivariateSpline, self).__call__(x.ravel())
        return z.reshape(shape)
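The overridden `__call__` flattens its input, evaluates the parent spline on the flat array and restores the original shape. A minimal self-contained sketch of the same idea follows, assuming only NumPy and SciPy; the class name `ShapedSpline` is hypothetical (chosen so as not to shadow SciPy's `UnivariateSpline`), and the piecewise-linear interpolating fit (`k=1, s=0`) is an illustrative choice:

```python
import numpy as np
from scipy.interpolate import UnivariateSpline


class ShapedSpline(UnivariateSpline):
    """Hypothetical UnivariateSpline subclass whose __call__ accepts an
    array of any shape and returns values with that same shape."""

    def __call__(self, x, nu=0):
        x = np.asarray(x, dtype=float)
        shape = x.shape
        # Evaluate the parent spline on a flat 1D array, then restore the shape.
        z = super().__call__(x.ravel(), nu=nu)
        return np.asarray(z).reshape(shape)


# Data sampled on a 1D grid: y = 2x + 1 (exactly linear).
grid = np.linspace(0.0, 4.0, 5)
values = 2.0 * grid + 1.0

# k=1, s=0: piecewise-linear spline passing through every data point;
# ext=3: return the boundary value outside the grid.
spl = ShapedSpline(grid, values, k=1, s=0, ext=3)

query = np.array([[0.5, 1.5, 2.5],
                  [3.5, 4.0, 10.0]])
result = spl(query)
print(result.shape)  # (2, 3)
```

Since the fitted data are exactly linear, the values inside the grid follow y = 2x + 1, and the query point 10.0, outside the grid, gets the boundary value because of `ext=3`.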
#-----
#-----
def interp_on_state(self, A):
    '''returns an interpolating function of matrix A, assuming that A
    is expressed on the state grid `self.state_grid`.
    The shape of A should be tuple(len(g) for g in self.state_grid).
    '''
    # Method of the enclosing problem class (not shown), which provides
    # `state_grid` and `_state_grid_shape`. `MlinInterpolator` is not
    # defined in this snippet either.
    # Check the dimension of A:
    expect_shape = self._state_grid_shape
    if A.shape != expect_shape:
        raise ValueError('array `A` should be of shape {:s}, not {:s}'.format(
                         str(expect_shape), str(A.shape)))
    if len(expect_shape) == 1:
        # 1D: spline through the grid values (s=0),
        # held constant outside the grid (ext=3)
        A_interp = UnivariateSpline(self.state_grid[0], A, s=0, ext=3)
        return A_interp
    elif len(expect_shape) <= 5:
        # up to 5D: multilinear interpolation
        A_interp = MlinInterpolator(*self.state_grid)
        A_interp.set_values(A)
        return A_interp
        # if len(expect_shape) == 2:
        #     x1_grid = self.state_grid[0]
        #     x2_grid = self.state_grid[1]
        #     A_interp = RectBivariateSplineBc(x1_grid, x2_grid, A, kx=1, ky=1)
        #     return A_interp
    else:
        raise NotImplementedError('interpolation for state dimension >5'
                                  ' is not implemented.')
# end interp_on_state()
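The dispatch in `interp_on_state` can be tried without the problem class or the compiled `multilinear_cython` module. The sketch below is a hypothetical standalone variant, `interp_on_grid(state_grid, A)`. It swaps in SciPy's `RegularGridInterpolator(method='linear')`, which also performs multilinear interpolation on a rectilinear grid, for the `MlinInterpolator` class used above, whose API it does not reproduce. Clamping query points to the grid box (constant extrapolation in every dimension, like `ext=3` in 1D) and using a linear spline (`k=1`) in the 1D branch are assumptions made here for consistency:

```python
import numpy as np
from scipy.interpolate import RegularGridInterpolator, UnivariateSpline


def interp_on_grid(state_grid, A):
    """Hypothetical standalone variant of interp_on_state: return a callable
    interpolating A, which is sampled on the rectilinear grid `state_grid`
    (a sequence of increasing 1D arrays)."""
    A = np.asarray(A, dtype=float)
    expect_shape = tuple(len(g) for g in state_grid)
    if A.shape != expect_shape:
        raise ValueError('array `A` should be of shape {}, not {}'.format(
            expect_shape, A.shape))
    ndim = len(expect_shape)
    if ndim == 1:
        # Piecewise-linear spline through the data (k=1, s=0),
        # constant outside the grid (ext=3).
        spl = UnivariateSpline(state_grid[0], A, k=1, s=0, ext=3)

        def f(x):
            x = np.asarray(x, dtype=float)
            return spl(x.ravel()).reshape(x.shape)
        return f
    elif ndim <= 5:
        # Multilinear interpolation (stand-in for MlinInterpolator).
        rgi = RegularGridInterpolator(state_grid, A, method='linear')
        lo = np.array([g[0] for g in state_grid], dtype=float)
        hi = np.array([g[-1] for g in state_grid], dtype=float)

        def f(*x):
            # Broadcast the coordinate arrays against each other,
            # then stack them into an (npoints, ndim) array of points.
            xs = np.broadcast_arrays(*[np.asarray(xi, dtype=float) for xi in x])
            pts = np.stack([xi.ravel() for xi in xs], axis=-1)
            # Clamp to the grid box so that extrapolation is constant.
            pts = np.clip(pts, lo, hi)
            return rgi(pts).reshape(xs[0].shape)
        return f
    else:
        raise NotImplementedError('interpolation for state dimension >5'
                                  ' is not implemented.')


# Usage: A(x1, x2) = x1 + 10*x2 is linear, so it is reproduced exactly.
x1 = np.linspace(0.0, 1.0, 3)
x2 = np.linspace(0.0, 2.0, 5)
A = x1[:, None] + 10.0 * x2[None, :]
f = interp_on_grid([x1, x2], A)
print(f(np.array([0.5, 2.0]), 0.5).shape)  # (2,)
```

Like the original, the returned callable takes one coordinate argument per state dimension, so a caller can write `f(*x)` whatever the dimension.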
@pierre-haessig

As for adding the UnivariateSpline interpolator: it is a very useful addition, since it removes the need to compile the multilinear_cython module, at least for the 1D case. However, it would be better to have a way to choose which interpolator to use. What do you think?
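One way to make the interpolator selectable, sketched here for the 1D case only, is a `method` argument. The helper `interp_1d` and its method names are hypothetical; `'linear'` relies on `np.interp`, which needs no compiled extension, and `'spline'` on SciPy's `UnivariateSpline` with `s=0`, so that the spline passes through the grid values:

```python
import numpy as np
from scipy.interpolate import UnivariateSpline


def interp_1d(grid, A, method='linear'):
    """Hypothetical helper: build a 1D interpolating function of A on `grid`,
    letting the caller choose the interpolator.

    method='linear': piecewise-linear (np.interp, pure NumPy)
    method='spline': cubic interpolating spline (UnivariateSpline, s=0)
    Both hold the boundary value constant outside the grid.
    """
    grid = np.asarray(grid, dtype=float)
    A = np.asarray(A, dtype=float)
    if method == 'linear':
        def f(x):
            # np.interp accepts x of any shape and returns the same shape;
            # outside the grid it returns A[0] or A[-1].
            return np.interp(x, grid, A)
    elif method == 'spline':
        spl = UnivariateSpline(grid, A, k=3, s=0, ext=3)

        def f(x):
            x = np.asarray(x, dtype=float)
            return spl(x.ravel()).reshape(x.shape)
    else:
        raise ValueError('unknown interpolation method {!r}'.format(method))
    return f


# Usage: the same data, two interpolators.
grid = np.linspace(0.0, 3.0, 7)
A = grid ** 2
f_lin = interp_1d(grid, A, method='linear')
f_spl = interp_1d(grid, A, method='spline')
x = np.array([[0.25, 1.25], [2.25, 5.0]])
print(f_lin(x).shape, f_spl(x).shape)  # (2, 2) (2, 2)
```

Defaulting to `'linear'` would keep 1D results consistent with the multilinear interpolation used in higher dimensions, while `'spline'` remains available when a smooth interpolant is preferred.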
