@kengz
Created August 10, 2019 06:50
SAC training loop
def train_alpha(self, alpha_loss):
    '''Custom method to train the alpha variable'''
    self.alpha_lr_scheduler.step(epoch=self.body.env.clock.frame)
    self.alpha_optim.zero_grad()
    alpha_loss.backward()
    self.alpha_optim.step()
    self.alpha = self.log_alpha.detach().exp()

def train(self):
    '''Train actor critic by computing the loss in batch efficiently'''
    if util.in_eval_lab_modes():
        return np.nan
    clock = self.body.env.clock
    if self.to_train == 1:
        for _ in range(self.training_iter):
            batch = self.sample()
            clock.set_batch_size(len(batch))
            states = batch['states']
            actions = self.guard_q_actions(batch['actions'])
            q_targets = self.calc_q_targets(batch)
            # Q-value loss for both Q nets (see sketch after the listing)
            q1_preds = self.calc_q(states, actions, self.q1_net)
            q1_loss = self.calc_reg_loss(q1_preds, q_targets)
            self.q1_net.train_step(q1_loss, self.q1_optim, self.q1_lr_scheduler, clock=clock, global_net=self.global_q1_net)
            q2_preds = self.calc_q(states, actions, self.q2_net)
            q2_loss = self.calc_reg_loss(q2_preds, q_targets)
            self.q2_net.train_step(q2_loss, self.q2_optim, self.q2_lr_scheduler, clock=clock, global_net=self.global_q2_net)
            # policy loss (see sketch after the listing)
            action_pd = policy_util.init_action_pd(self.body.ActionPD, self.calc_pdparam(states))
            log_probs, reparam_actions = self.calc_log_prob_action(action_pd, reparam=True)
            policy_loss = self.calc_policy_loss(batch, log_probs, reparam_actions)
            self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
            # alpha loss (see sketch after the listing)
            alpha_loss = self.calc_alpha_loss(log_probs)
            self.train_alpha(alpha_loss)
            loss = q1_loss + q2_loss + policy_loss + alpha_loss
            # update target networks (see sketch after the listing)
            self.update_nets()
            # update PER priorities if available
            self.try_update_per(torch.min(q1_preds, q2_preds), q_targets)
        # reset
        self.to_train = 0
        logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.env.total_reward}, loss: {loss:g}')
        return loss.item()
    else:
        return np.nan
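
The loop above relies on several helpers that the gist does not include. The sketches that follow show, in plain PyTorch, the quantities that standard SAC (Haarnoja et al., 2018) computes at each of the marked steps; all function and argument names below are illustrative, not SLM-Lab's actual API, and the real calc_q_targets, calc_policy_loss, calc_alpha_loss, and update_nets may differ in detail.

First, the Q-value targets: SAC uses an entropy-regularized Bellman backup with the minimum of two target critics, evaluated at actions sampled from the current policy at the next state.

import torch

def soft_q_targets(rewards, dones, next_q1_targ, next_q2_targ, next_log_probs, alpha, gamma=0.99):
    '''Entropy-regularized Bellman target:
    y = r + gamma * (1 - done) * (min(Q1_targ(s', a'), Q2_targ(s', a')) - alpha * log pi(a'|s'))
    where a' ~ pi(.|s') is sampled from the current policy.'''
    with torch.no_grad():
        next_q = torch.min(next_q1_targ, next_q2_targ) - alpha * next_log_probs
        return rewards + gamma * (1 - dones.float()) * next_q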
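
The policy loss at the '# policy loss' step uses reparameterized actions so gradients flow from the critics into the actor; the standard objective, again with illustrative names:

import torch

def sac_policy_loss(q1_reparam, q2_reparam, log_probs, alpha):
    '''Actor loss: minimize E[alpha * log pi(a|s) - min(Q1, Q2)(s, a)], where a is drawn
    via the reparameterization trick so the Q-values are differentiable in the policy parameters.'''
    return (alpha * log_probs - torch.min(q1_reparam, q2_reparam)).mean()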
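
The alpha_loss consumed by train_alpha is the temperature objective from the learned-entropy variant of SAC. Its standard form, assuming a target_entropy hyperparameter (commonly set to -dim(action_space) for continuous control):

def sac_alpha_loss(log_alpha, log_probs, target_entropy):
    '''Temperature loss: adjust alpha so the policy entropy tracks target_entropy.
    log_probs are detached so only log_alpha receives gradients.'''
    return -(log_alpha * (log_probs + target_entropy).detach()).mean()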
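
Finally, update_nets refreshes the target critics. A common choice is a Polyak (exponential moving average) soft update, sketched below with an assumed coefficient tau; depending on configuration, the target networks may instead be replaced outright at fixed intervals.

import torch

def polyak_update(net, target_net, tau=0.005):
    '''Soft update: target <- tau * online + (1 - tau) * target, parameter-wise.'''
    with torch.no_grad():
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.mul_(1 - tau).add_(tau * param)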