SAC networks
# init_nets is a method of a SAC agent class; it assumes torch, numpy (np), and the
# library's net and net_util modules are available at module level.
def init_nets(self, global_nets=None):
    '''
    Networks: net (actor/policy), q1_net, target_q1_net, q2_net, target_q2_net
    All networks are separate and have the same hidden-layer architecture and optim specs, so tuning is minimal
    '''
    self.shared = False  # SAC does not share networks
    NetClass = getattr(net, self.net_spec['type'])
    # main actor network
    self.net = NetClass(self.net_spec, self.body.state_dim, net_util.get_out_dim(self.body))
    self.net_names = ['net']
    # two critic Q-networks to mitigate positive bias in q_loss and speed up training; uses the 'Q'-prefixed class from q_net.py
    QNetClass = getattr(net, 'Q' + self.net_spec['type'])
    q_in_dim = [self.body.state_dim, self.body.action_dim]
    self.q1_net = QNetClass(self.net_spec, q_in_dim, 1)
    self.target_q1_net = QNetClass(self.net_spec, q_in_dim, 1)
    self.q2_net = QNetClass(self.net_spec, q_in_dim, 1)
    self.target_q2_net = QNetClass(self.net_spec, q_in_dim, 1)
    self.net_names += ['q1_net', 'target_q1_net', 'q2_net', 'target_q2_net']
    # hard-copy the online critics into their targets at initialization
    net_util.copy(self.q1_net, self.target_q1_net)
    net_util.copy(self.q2_net, self.target_q2_net)
    # temperature alpha is learned in log-space; target entropy uses the SAC heuristic -dim(action_space)
    self.log_alpha = torch.zeros(1, requires_grad=True, device=self.net.device)
    self.alpha = self.log_alpha.detach().exp()
    self.target_entropy = -np.product(self.body.action_space.shape)
    # init net optimizers and their lr schedulers
    self.optim = net_util.get_optim(self.net, self.net.optim_spec)
    self.lr_scheduler = net_util.get_lr_scheduler(self.optim, self.net.lr_scheduler_spec)
    self.q1_optim = net_util.get_optim(self.q1_net, self.q1_net.optim_spec)
    self.q1_lr_scheduler = net_util.get_lr_scheduler(self.q1_optim, self.q1_net.lr_scheduler_spec)
    self.q2_optim = net_util.get_optim(self.q2_net, self.q2_net.optim_spec)
    self.q2_lr_scheduler = net_util.get_lr_scheduler(self.q2_optim, self.q2_net.lr_scheduler_spec)
    # the temperature gets its own optimizer, reusing the actor's optim and scheduler specs
    self.alpha_optim = net_util.get_optim(self.log_alpha, self.net.optim_spec)
    self.alpha_lr_scheduler = net_util.get_lr_scheduler(self.alpha_optim, self.net.lr_scheduler_spec)
    net_util.set_global_nets(self, global_nets)
    self.post_init_nets()
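
For context on the comment about mitigating positive bias: during training SAC computes a clipped double-Q target, taking the minimum of the two target critics and subtracting the entropy term. A minimal sketch of that step; calc_q_targets is a hypothetical helper name, self.gamma, the batch keys, and the call convention target_q_net(states, actions) are all assumptions, not part of the gist:

def calc_q_targets(self, batch, next_actions, next_log_probs):
    '''Clipped double-Q target: the min over the two target critics mitigates overestimation bias'''
    with torch.no_grad():
        next_q1 = self.target_q1_net(batch['next_states'], next_actions)
        next_q2 = self.target_q2_net(batch['next_states'], next_actions)
        # soft value: min Q minus the entropy term scaled by the learned temperature
        next_v = torch.min(next_q1, next_q2) - self.alpha * next_log_probs
        q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * next_v
    return q_targets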
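
The learned temperature is optimized through log_alpha against target_entropy using alpha_optim. A minimal sketch of the standard SAC temperature update, assuming log_probs holds the log-probabilities of actions sampled from the current policy (calc_alpha_loss and train_alpha are hypothetical helper names):

def calc_alpha_loss(self, log_probs):
    # drive the policy's entropy toward target_entropy via the temperature
    return -(self.log_alpha * (log_probs.detach() + self.target_entropy)).mean()

def train_alpha(self, alpha_loss):
    # step the temperature optimizer, then refresh the detached alpha used in the other losses
    self.alpha_optim.zero_grad()
    alpha_loss.backward()
    self.alpha_optim.step()
    self.alpha = self.log_alpha.detach().exp()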
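
net_util.copy only hard-copies the online critics into the targets at initialization; afterwards the target critics are typically tracked with a soft (Polyak) update each training step. A minimal sketch in plain torch, with the mixing coefficient beta as a hypothetical argument (the library may expose this through a net_util helper instead):

def update_target_nets(self, beta=0.005):
    # soft update: target <- beta * online + (1 - beta) * target
    with torch.no_grad():
        for online, target in ((self.q1_net, self.target_q1_net), (self.q2_net, self.target_q2_net)):
            for p, tp in zip(online.parameters(), target.parameters()):
                tp.mul_(1.0 - beta).add_(beta * p)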