@kengz
Last active August 11, 2019 18:16
SAC loss functions
import torch  # used by the excerpted methods below; policy_util is assumed to be imported from the surrounding agent library

def calc_q(self, state, action, net):
    '''Forward-pass to calculate the predicted state-action-value Q(s, a) from the given Q network (q1_net or q2_net).'''
    q_pred = net(state, action).view(-1)
    return q_pred

def calc_q_targets(self, batch):
    '''Q_tar = r + gamma * (target_Q(s', a') - alpha * log pi(a'|s'))'''
    next_states = batch['next_states']
    with torch.no_grad():
        pdparams = self.calc_pdparam(next_states)
        action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
        next_log_probs, next_actions = self.calc_log_prob_action(action_pd)
        next_actions = self.guard_q_actions(next_actions)  # non-reparam discrete actions need to be converted into one-hot
        next_target_q1_preds = self.calc_q(next_states, next_actions, self.target_q1_net)
        next_target_q2_preds = self.calc_q(next_states, next_actions, self.target_q2_net)
        next_target_q_preds = torch.min(next_target_q1_preds, next_target_q2_preds)
        q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * (next_target_q_preds - self.alpha * next_log_probs)
    return q_targets
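
# guard_q_actions is called above but not defined in this gist; per the inline comment, it one-hot
# encodes non-reparameterized discrete actions so they can be fed to the Q networks alongside the state.
# A minimal sketch of such a helper, assuming the agent body exposes is_discrete and action_dim
# (the actual implementation may differ):
def guard_q_actions(self, actions):
    '''One-hot encode discrete integer actions for Q(s, a) input; pass continuous actions through unchanged.'''
    if not self.body.is_discrete:  # assumed flag on the agent body
        return actions
    return torch.nn.functional.one_hot(actions.long(), self.body.action_dim).float()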

def calc_reg_loss(self, preds, targets):
    '''Calculate the regression loss for V and Q values, using the same loss function from net_spec'''
    assert preds.shape == targets.shape, f'{preds.shape} != {targets.shape}'
    reg_loss = self.net.loss_fn(preds, targets)
    return reg_loss
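
# Note: self.net.loss_fn is the regression criterion configured in net_spec; with the usual MSELoss
# choice, the Q losses are the mean squared error between predicted and target Q values.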

def calc_policy_loss(self, batch, log_probs, reparam_actions):
    '''policy_loss = alpha * log pi(f(a)|s) - min(Q1(s, f(a)), Q2(s, f(a))), where f(a) = reparametrized action'''
    states = batch['states']
    q1_preds = self.calc_q(states, reparam_actions, self.q1_net)
    q2_preds = self.calc_q(states, reparam_actions, self.q2_net)
    q_preds = torch.min(q1_preds, q2_preds)
    policy_loss = (self.alpha * log_probs - q_preds).mean()
    return policy_loss
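
# The log_probs and reparam_actions passed to calc_policy_loss come from sampling the policy with the
# reparameterization trick, so gradients can flow from Q(s, f(a)) back into the policy network.
# A minimal sketch of what the calc_log_prob_action helper used above might look like (it is not part
# of this gist; the agent's actual implementation may differ):
def calc_log_prob_action(self, action_pd, reparam=False):
    '''Sample an action and its log-probability; rsample() keeps the sample differentiable w.r.t. the policy parameters.'''
    actions = action_pd.rsample() if reparam else action_pd.sample()
    log_probs = action_pd.log_prob(actions)
    if log_probs.dim() > 1:  # multi-dimensional continuous actions: sum log-probs over the action dimension
        log_probs = log_probs.sum(-1)
    return log_probs, actions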

def calc_alpha_loss(self, log_probs):
    '''alpha_loss = - log_alpha * (log pi(a|s) + target_entropy), used to auto-tune the entropy temperature alpha'''
    alpha_loss = - (self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()
    return alpha_loss
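
# How these losses typically fit together in one SAC update (a minimal sketch, not part of the gist;
# batch['actions'], the reparam keyword, and the subsequent optimizer/target-network updates are assumptions):
def train_step_sketch(self, batch):
    '''Compute all SAC losses for one gradient update: twin-Q regression, policy loss, and temperature (alpha) loss.'''
    states = batch['states']
    actions = self.guard_q_actions(batch['actions'])  # assumed: stored discrete actions also need the one-hot guard
    # 1. Critic losses: regress both Q networks toward the shared min-target
    q_targets = self.calc_q_targets(batch)
    q1_loss = self.calc_reg_loss(self.calc_q(states, actions, self.q1_net), q_targets)
    q2_loss = self.calc_reg_loss(self.calc_q(states, actions, self.q2_net), q_targets)
    # 2. Actor loss: sample reparameterized actions from the current policy
    pdparams = self.calc_pdparam(states)
    action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
    log_probs, reparam_actions = self.calc_log_prob_action(action_pd, reparam=True)
    policy_loss = self.calc_policy_loss(batch, log_probs, reparam_actions)
    # 3. Temperature loss: tune alpha toward the target entropy
    alpha_loss = self.calc_alpha_loss(log_probs)
    # Each loss would then be backpropagated through its own optimizer, followed by a soft update
    # of the target Q networks and alpha = log_alpha.exp().
    return q1_loss, q2_loss, policy_loss, alpha_loss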