
@dfm
Created March 2, 2021 23:37
@peterroelants

This is interesting, thank you for sharing.

I'm wondering about two things:

  1. What are the implications of converting from Aesara to JAX and back again? Would jnp.asarray and np.asarray incur any memory overhead?
  2. Why call jax.jit on the function passed to JaxOp? For some reason I was assuming Aesara would compile down to Jax (in Jax Mode) and would take care of this. What compilation does Aesara provide?

@dfm
Author

dfm commented Mar 3, 2021

@peterroelants: these questions are both moot if you're only using the jaxified version of the function. perform is only called when the op is evaluated by Aesara itself rather than through the JAX backend. This means you could use this op as a deterministic in the original PyMC3 or with the JAX backend, and on the JAX backend it would reduce directly to just the JAX function.
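A rough illustration of that split, written as a hypothetical pure-Python mock rather than real Aesara code: ToyJaxOp and evaluate are invented names, and the "python"/"jax" switch is a simplification of how Aesara chooses a backend. The op carries both a perform method (in Aesara's style of writing results into output_storage) and a jax_fn, and only the non-JAX path ever goes through perform.

```python
class ToyJaxOp:
    """Hypothetical op carrying two routes to the same function."""

    def __init__(self, jax_fn):
        self.jax_fn = jax_fn      # used directly by the toy "jax" backend
        self.perform_calls = 0    # counts how often perform is hit

    def perform(self, node, inputs, output_storage):
        # Aesara-style perform: write each output into output_storage[i][0].
        # (In the real op this is where np/jnp.asarray conversions would live.)
        self.perform_calls += 1
        output_storage[0][0] = self.jax_fn(*inputs)


def evaluate(op, inputs, backend):
    """Hypothetical evaluator: 'jax' calls jax_fn directly, anything else uses perform."""
    if backend == "jax":
        return op.jax_fn(*inputs)
    output_storage = [[None]]
    op.perform(None, inputs, output_storage)
    return output_storage[0][0]


op = ToyJaxOp(lambda a, b: a * b + 1)
print(evaluate(op, [2, 3], backend="jax"), op.perform_calls)     # 7 0
print(evaluate(op, [2, 3], backend="python"), op.perform_calls)  # 7 1
```

On the "jax" route, perform is never touched, which is why the asarray and jit questions only matter for the non-JAX route.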

But to answer them directly:

  1. I don't think the asarray calls are strictly necessary. I also don't think they introduce extra overhead, since that conversion would happen behind the scenes anyway, but I could well be wrong.
  2. Again, the jit only matters if you also want to incorporate this into an Aesara model that doesn't otherwise use JAX. If you're using the JAX backend, I don't think it would hurt to do this (?), but it's definitely not necessary in that case.
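On the NumPy side of the overhead question, a small check of when np.asarray copies: if the input is already an ndarray of the requested dtype it is returned as-is, and a copy only happens when a conversion is actually needed. This sketch covers NumPy arrays only; whether jnp.asarray or np.asarray on a JAX array copies depends on the JAX device/backend and is not shown here.

```python
import numpy as np

x = np.arange(6, dtype=np.float64)

# Already an ndarray with a compatible dtype: the very same object comes back, no copy.
y = np.asarray(x)
print(y is x)  # True

# Asking for a different dtype forces a new buffer.
z = np.asarray(x, dtype=np.float32)
print(np.shares_memory(x, z))  # False

# A plain Python list always has to be converted into a fresh array.
w = np.asarray([0.0, 1.0, 2.0])
print(w.dtype)  # float64
```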

@bmorris3

bmorris3 commented Jun 25, 2021

Thanks for this! One note from experimenting with this on aesara 2.0.12: jax_funcify_JaxOp seems to need to accept extra keyword arguments (node and storage_map), so this tweak makes the code above work for me:

@jax_funcify.register(JaxOp)
def jax_funcify_JaxOp(op, *args, **kwargs):
    # *args/**kwargs absorb the extra arguments (e.g. node, storage_map) and ignore them
    func = op.jax_fn
    return func

I hope that's sensible.
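A stdlib-only sketch of why the catch-all signature helps, assuming (not verified against aesara 2.0.12) that jax_funcify behaves like a functools.singledispatch dispatcher whose callers may pass extra context such as node and storage_map. ToyJaxOp, toy_funcify and strict_handler are hypothetical stand-ins, not Aesara APIs.

```python
from functools import singledispatch


class ToyJaxOp:
    """Hypothetical stand-in for the gist's JaxOp: it only stores a callable."""

    def __init__(self, jax_fn):
        self.jax_fn = jax_fn


@singledispatch
def toy_funcify(op, node=None, storage_map=None, **kwargs):
    """Hypothetical jax_funcify-like dispatcher: picks a handler by type(op)."""
    raise NotImplementedError(f"no conversion registered for {type(op).__name__}")


@toy_funcify.register(ToyJaxOp)
def toy_funcify_jaxop(op, *args, **kwargs):
    # Accept and ignore any extra context (node, storage_map, ...) the caller passes.
    return op.jax_fn


def strict_handler(op):
    """The same handler without *args/**kwargs, for comparison."""
    return op.jax_fn


op = ToyJaxOp(lambda x: 2 * x)

# Works whether or not the caller supplies the extra keyword arguments.
print(toy_funcify(op)(3))                             # 6
print(toy_funcify(op, node=None, storage_map={})(3))  # 6

# A handler with a fixed signature rejects the extra keywords.
try:
    strict_handler(op, node=None, storage_map={})
except TypeError as err:
    print("strict handler failed:", err)
```

The catch-all signature keeps the handler working across dispatcher versions that pass more or fewer arguments, at the cost of silently ignoring anything it doesn't use.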

@dfm
Author

dfm commented Jun 25, 2021

@bmorris3: Yeah, this interface has been a moving target and I haven't been following it too closely, so I'm not sure I know enough to comment, but it seems sensible enough :D
