Code Example #1
 def training_step(self, xs=None, ys=None, loss=None, retain_graph=False, global_net=None):
     '''
     Takes a single training step: one forward and one backward pass. Both xs and ys are lists of the same length, with one element per environment.
     '''
     self.train()
     self.zero_grad()
     self.optim.zero_grad()
     if loss is None:
         outs = self(xs)
         total_loss = torch.tensor(0.0, device=self.device)
         for out, y in zip(outs, ys):
             loss = self.loss_fn(out, y)
             total_loss += loss
         loss = total_loss
     assert not torch.isnan(loss).any(), loss
     if net_util.to_assert_trained():
         assert_trained = net_util.gen_assert_trained(self.model_body)
     loss.backward(retain_graph=retain_graph)
     if self.clip_grad:
         logger.debug(f'Clipping gradient: {self.clip_grad_val}')
         torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
     if global_net is None:
         self.optim.step()
     else:  # distributed training with global net
         net_util.push_global_grad(self, global_net)
         self.optim.step()
         net_util.pull_global_param(self, global_net)
     if net_util.to_assert_trained():
         assert_trained(self.model_body, loss)
     logger.debug(f'Net training_step loss: {loss}')
     return loss
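
A brief usage sketch of the multi-environment call, assuming `net` is an instance of the class above, constructed with a standard supervised `self.loss_fn` (e.g. MSE). The `net` variable, batch size, and dimensions below are illustrative, not from the source:

 # Minimal sketch (not from the source): `net` exposes the training_step above
 # and was built with an MSE-style loss_fn; dims are illustrative.
 import torch

 num_envs, in_dim, out_dim = 3, 4, 2
 xs = [torch.randn(8, in_dim) for _ in range(num_envs)]   # one input batch per environment
 ys = [torch.randn(8, out_dim) for _ in range(num_envs)]  # one target batch per environment
 loss = net.training_step(xs=xs, ys=ys)                   # sums the per-environment losses
 print(loss.item())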
Code Example #2
File: convnet.py  Project: krishpop/SLM-Lab
 def training_step(self,
                   x=None,
                   y=None,
                   loss=None,
                   retain_graph=False,
                   global_net=None):
     '''Takes a single training step: one forward and one backward pass'''
     self.train()
     self.zero_grad()
     self.optim.zero_grad()
     if loss is None:
         out = self(x)
         loss = self.loss_fn(out, y)
     assert not torch.isnan(loss).any(), loss
     if net_util.to_assert_trained():
         assert_trained = net_util.gen_assert_trained(self.conv_model)
     loss.backward(retain_graph=retain_graph)
     if self.clip_grad:
         logger.debug(f'Clipping gradient: {self.clip_grad_val}')
         torch.nn.utils.clip_grad_norm_(self.parameters(),
                                        self.clip_grad_val)
     if global_net is None:
         self.optim.step()
     else:  # distributed training with global net
         net_util.push_global_grad(self, global_net)
         self.optim.step()
         net_util.pull_global_param(self, global_net)
     if net_util.to_assert_trained():
         assert_trained(self.conv_model, loss)
     logger.debug(f'Net training_step loss: {loss}')
     return loss
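
A brief usage sketch of the convolutional variant, assuming `net` is an instance of the class above that takes NCHW image batches and was built with a classification-style `self.loss_fn` (e.g. cross-entropy). The shapes and the `net` / `global_net` variables are illustrative, not from the source:

 # Minimal sketch (not from the source): `net` exposes the training_step above.
 import torch

 x = torch.randn(16, 3, 84, 84)        # batch of 16 RGB 84x84 frames
 y = torch.randint(0, 4, (16,))        # integer class/action targets
 loss = net.training_step(x=x, y=y)    # local update: optim.step() on this net

 # Distributed (A3C-style) variant: as the helper names suggest, local gradients
 # are pushed to a shared global_net and its updated parameters are pulled back.
 # loss = net.training_step(x=x, y=y, global_net=global_net)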
Code Example #3
 def training_step(self, x=None, y=None, loss=None, retain_graph=False, global_net=None):
     '''
     Takes a single training step: one forward and one backward pass
     For most RL usage, we have custom, often complicated, loss functions. Compute the loss value, put it in a PyTorch tensor, and pass it in via the loss argument
     '''
     self.train()
     self.zero_grad()
     self.optim.zero_grad()
     if loss is None:
         out = self(x)
         loss = self.loss_fn(out, y)
     assert not torch.isnan(loss).any(), loss
     if net_util.to_assert_trained():
         # to accommodate split model in inherited classes
         model = getattr(self, 'model', None) or getattr(self, 'model_body')
         assert_trained = net_util.gen_assert_trained(model)
     loss.backward(retain_graph=retain_graph)
     if self.clip_grad:
         logger.debug(f'Clipping gradient: {self.clip_grad_val}')
         torch.nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
     if global_net is None:
         self.optim.step()
     else:  # distributed training with global net
         net_util.push_global_grad(self, global_net)
         self.optim.step()
         net_util.pull_global_param(self, global_net)
     if net_util.to_assert_trained():
         model = getattr(self, 'model', None) or getattr(self, 'model_body')
         assert_trained(model, loss)
     logger.debug(f'Net training_step loss: {loss}')
     return loss
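
A brief usage sketch of the custom-loss path described in the docstring: the loss is computed outside the net as a scalar PyTorch tensor connected to the net's computation graph, then passed in via the loss argument so the method only runs the backward pass and optimizer step. The `net` variable and the REINFORCE-style loss below are illustrative assumptions, not from the source:

 # Minimal sketch (not from the source): `net` exposes the training_step above
 # and its forward pass returns action logits; data tensors are illustrative.
 import torch

 states = torch.randn(32, 4)                # batch of 32 four-dimensional states
 actions = torch.randint(0, 2, (32,))       # sampled discrete actions
 returns = torch.randn(32)                  # estimated returns/advantages

 logits = net(states)                       # forward pass builds the graph
 log_probs = torch.log_softmax(logits, dim=-1)
 chosen = log_probs.gather(1, actions.unsqueeze(1)).squeeze(1)
 pg_loss = -(chosen * returns).mean()       # REINFORCE-style policy-gradient loss
 loss = net.training_step(loss=pg_loss)     # backward + optimizer step only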