Skip to content

Instantly share code, notes, and snippets.

@kengz
Last active August 11, 2019 18:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kengz/ab75510074d9ba788896d25127dd6c3c to your computer and use it in GitHub Desktop.
Save kengz/ab75510074d9ba788896d25127dd6c3c to your computer and use it in GitHub Desktop.
SAC log probs
def calc_log_prob_action(self, action_pd, reparam=False):
    '''
    Sample actions from a policy distribution and compute their log-probabilities (SAC paper eq. 11).

    @param action_pd: distribution object exposing sample(), rsample() and log_prob()
    @param reparam: if True draw with rsample() (reparametrized), otherwise with sample()
    @returns (log_probs, actions); for continuous actions the actions are tanh-squashed, then passed through self.scale_action
    '''
    samples = action_pd.rsample() if reparam else action_pd.sample()
    if self.body.is_discrete:
        # discrete case: the sample itself is the action, no squashing correction needed
        return action_pd.log_prob(samples), samples

    mus = samples
    squashed = torch.tanh(mus)  # always inside (-1, 1)
    actions = self.scale_action(squashed)
    # paper Appendix C. Enforcing Action Bounds for continuous actions:
    # the change-of-variables term belongs to the tanh squashing, so it must be computed from tanh(mus).
    # Using the rescaled actions here makes (1 - a^2) negative as soon as the action bound exceeds 1, and log() returns NaN.
    # NOTE(review): presumably scale_action is an affine rescaling, whose Jacobian would only add a constant offset (omitted) - confirm
    # NOTE(review): assumes action_pd.log_prob(mus) is already reduced to one value per batch row, matching .sum(1) - confirm
    log_probs = action_pd.log_prob(mus) - torch.log(1 - squashed.pow(2) + 1e-6).sum(1)
    return log_probs, actions
# ... for discrete action, GumbelSoftmax distribution
class GumbelSoftmax(distributions.RelaxedOneHotCategorical):
    '''
    A differentiable Categorical distribution using reparametrization trick with Gumbel-Softmax
    Explanation http://amid.fish/assets/gumbel.html
    NOTE: use this in place of PyTorch's RelaxedOneHotCategorical when a categorical log-probability is wanted:
    log_prob here returns sum(value * log_softmax(logits)), which is <= 0 for one-hot / relaxed values,
    whereas the original author observed the parent's log_prob returning positive values.
    sample() returns hard category indices; rsample() is inherited unchanged from RelaxedOneHotCategorical.
    Papers:
    [1] The Concrete Distribution: A Continuous Relaxation of Discrete Random Variables (Maddison et al, 2017)
    [2] Categorical Reparametrization with Gumbel-Softmax (Jang et al, 2017)
    '''

    def sample(self, sample_shape=torch.Size()):
        '''
        Gumbel-max sampling of hard category indices. Note rsample is inherited from RelaxedOneHotCategorical

        @param sample_shape: extra leading dimensions of independent draws (default: one draw per batch element)
        @returns integer tensor of category indices with shape sample_shape + logits.shape[:-1]
        '''
        logits = self.logits
        # honour sample_shape by drawing independent noise for every requested sample
        noise_shape = torch.Size(sample_shape) + logits.shape
        u = torch.empty(noise_shape, device=logits.device, dtype=logits.dtype).uniform_(0, 1)
        # keep u strictly inside (0, 1) so that log(-log(u)) stays finite
        eps = torch.finfo(u.dtype).eps
        u = u.clamp(min=eps, max=1 - eps)
        noisy_logits = logits - torch.log(-torch.log(u))
        return torch.argmax(noisy_logits, dim=-1)

    def log_prob(self, value):
        '''
        Log-probability of value under the categorical defined by the logits.

        @param value: one-hot or relaxed tensor whose trailing dims equal logits.shape, or a tensor of category indices
            (anything whose trailing dims do not match logits.shape is treated as indices and one-hot encoded).
            Leading sample dimensions are allowed and broadcast against the logits.
        @returns tensor with the last (category) dimension summed out
        @raises ValueError: if value cannot be aligned with the logits
        '''
        logits = self.logits
        n_dims = logits.dim()
        if value.shape[-n_dims:] != logits.shape:
            # category indices -> one-hot over the last dimension
            value = F.one_hot(value.long(), logits.shape[-1]).float()
        if value.shape[-n_dims:] != logits.shape:
            raise ValueError(f'value of shape {tuple(value.shape)} is incompatible with logits of shape {tuple(logits.shape)}')
        return torch.sum(value * F.log_softmax(logits, -1), -1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment