def run_batch(self, data: dict) -> Dict[bool, int]:
    # Switch between training and evaluation mode (affects dropout, batch norm, etc.).
    if self.train:
        self.model.train()
    else:
        self.model.eval()
    self.optimizer.zero_grad()

    # Track how often each operation type occurs in the data.
    if 'type' in data:
        for opt in data['type']:
            self.opt_count_dict[opt] += 1

    # Fresh hidden states for encoder and decoder, then a forward pass and the loss.
    encoder_hidden, decoder_hidden = init_hidden_states(self.param)
    model_output = self.model.forward(data, encoder_hidden, decoder_hidden, self.step)
    loss, loss_count = self.run_loss(data, model_output)

    if self.train and not self.print_only:
        loss.backward()  # calculates the gradients
        if self.max_grad_norm > 0:
            # Clip the global gradient norm across all optimizer parameter groups.
            params = chain.from_iterable(
                group['params'] for group in self.optimizer.param_groups)
            torch.nn.utils.clip_grad_norm_(params, self.max_grad_norm)
        self.optimizer.step()
        if self.step % self.lr_step == 0:
            self.update_optimizer_scheduled()

    # Count how many items in the batch stop (True) vs. continue (False).
    local_stop_counter = dict()
    if 'stops' in data:
        local_stop_counter[True] = int(
            torch.sum(data['stops']).detach().cpu().item())
        local_stop_counter[False] = data['stops'].numel() - local_stop_counter[True]

    self.loss_computer += loss.item(), loss_count
    self.step_elapsed += 1
    self.step += 1
    return local_stop_counter
def initialize(self):
    # Reset hidden states and clear any images accumulated by a previous run.
    self.encoder_hidden, self.decoder_hidden = init_hidden_states(self.param)
    self.images = []
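# A minimal usage sketch (hypothetical, not part of the original code): assuming
# a trainer object exposing initialize() / run_batch() as above and a `batches`
# iterable of data dicts, one epoch could aggregate the per-batch stop counts
# like this. The names `trainer`, `batches`, and `stop_totals` are illustrative
# assumptions only.
#
#     trainer.initialize()
#     stop_totals = {True: 0, False: 0}
#     for batch in batches:
#         counts = trainer.run_batch(batch)  # {True: n_stopped, False: n_continuing}
#         for key, value in counts.items():
#             stop_totals[key] += value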