@peterroelants: these questions are both moot if you're only using the jaxified version of the function. `perform` is only called when evaluating the op with Aesara directly (i.e. not via the Jax backend). This means you could use this op as a deterministic with original PyMC3 or with the Jax backend, and on the Jax backend it would reduce directly to just the jax function.
But to answer them directly:
- I don't think the `asarray` calls are strictly necessary, and I don't think they introduce any overhead, since that conversion would happen behind the scenes anyway, but I could well be wrong.
- Again, the jit only matters if you also want to incorporate this into an Aesara model that doesn't otherwise use Jax. If you're using the Jax backend, I don't think it would hurt to do this (?) but it's definitely not necessary in that case (see the sketch below).
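
To make that split concrete, here is a minimal sketch of how such an op could be structured. It is not the gist's exact code; the names (`JaxOp`, `jax_fn`) are just taken from this discussion, and the sketch assumes a single input whose type matches the output. The point is that `perform` is only hit on the default backend, while the `jax_funcify` registration is what the Jax backend uses:

```python
# Minimal sketch (assumed names, not the gist's exact implementation) of an
# Op wrapping a single-input JAX function whose output has the same type.
import numpy as np
import aesara.tensor as at
from aesara.graph.basic import Apply
from aesara.graph.op import Op
from aesara.link.jax.dispatch import jax_funcify


class JaxOp(Op):
    def __init__(self, jax_fn):
        # Optionally jax.jit(jax_fn) here; see the discussion above.
        self.jax_fn = jax_fn

    def make_node(self, x):
        x = at.as_tensor_variable(x)
        return Apply(self, [x], [x.type()])

    def perform(self, node, inputs, outputs):
        # Only reached on the default (C/Python) backend, not in Jax mode.
        (x,) = inputs
        outputs[0][0] = np.asarray(self.jax_fn(x))


@jax_funcify.register(JaxOp)
def jax_funcify_JaxOp(op, **kwargs):
    # In Jax mode the op reduces to the wrapped function itself,
    # so perform (and any explicit jit) is bypassed entirely.
    return op.jax_fn
```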
Thanks for this! One note I found while experimenting with this on aesara 2.0.12: `jax_funcify_JaxOp` seems to require the extra keyword arguments `node` and `storage_map`, so this tweak makes the code above work for me:
```python
@jax_funcify.register(JaxOp)
def jax_funcify_JaxOp(op, *args, **kwargs):
    func = op.jax_fn
    return func
```
I hope that's sensible.
@bmorris3: Yeah - this interface has been a moving target and I haven't been following it too closely, so I'm not sure I know enough to comment, but it seems sensible enough :D
This is interesting, thank you for sharing.
I'm wondering about two things:
- Do the `jnp.asarray` and `np.asarray` calls implicate any memory overhead?
- Should I use `jax.jit` on the function passed to `JaxOp`? For some reason I was assuming Aesara would compile down to Jax (in Jax Mode) and would take care of this. What compilation does Aesara provide?
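
For reference, the two options behind the second question could look like the following sketch. `my_jax_function` is made up purely for illustration, and `JaxOp` refers to the wrapper discussed above:

```python
import jax
import jax.numpy as jnp


def my_jax_function(x):
    # Stand-in for whatever JAX computation gets wrapped.
    return jnp.sum(jnp.sin(x))


# Option 1: pass the raw function. Per the answer above, the Jax backend
# takes care of compilation, so this is usually enough there.
op_plain = JaxOp(my_jax_function)

# Option 2: pre-jit the function. This mainly matters when the op is also
# evaluated through perform() on the default backend; it shouldn't hurt
# in Jax mode, but it isn't required there.
op_jitted = JaxOp(jax.jit(my_jax_function))
```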