def train(self, step: int) -> torch.Tensor:
    """Run one optimization step and return the training loss.

    Args:
        step: global training step; used both for periodic logging and as
            the logging x-axis.

    Returns:
        The loss tensor for this step (``None`` only on the NaN early-exit
        branch, preserved from the original behavior).
    """
    # update model
    self.optimizer.zero_grad()
    # flag for logging
    log_flag = step % self.log_interval == 0
    # forward model
    loss, meta = self.forward(*to_device(next(self.train_dataset)), log_flag)
    # check loss nan: NaN is the only value that compares unequal to itself
    if loss != loss:
        log('{} cur step NAN is occured'.format(step))
        return
    loss.backward()
    self.clip_grad()
    self.optimizer.step()
    # logging
    if log_flag:
        # console logging
        self.console_log('train', meta, step)
        # tensorboard logging
        self.tensorboard_log('train', meta, step)
    # fix: the annotation promises a Tensor, but the original never
    # returned the loss on the success path
    return loss
def validate(self, step: int):
    """Run a full validation pass, then log averaged scalar statistics.

    NOTE(review): this definition is shadowed by a second ``validate``
    defined later in the same file — confirm which one is intended.

    Args:
        step: global training step, used as the tensorboard x-axis and in
            the summary console message.
    """
    loss = 0.
    stat = defaultdict(float)
    for i in range(self.valid_max_step):
        # flag for logging: first step of each interval, and the final step
        log_flag = i % self.log_interval == 0 or i == self.valid_max_step - 1
        # forward model; gradients are not needed during validation
        with torch.no_grad():
            batch_loss, meta = self.forward(*to_device(next(self.valid_dataset)),
                                            is_logging=log_flag)
        loss += batch_loss
        # update stat: accumulate only scalar-typed meta entries
        for key, (value, log_type) in meta.items():
            if log_type == LogType.SCALAR:
                stat[key] += value
        # console logging of this step
        if (i + 1) % self.log_interval == 0:
            self.console_log('valid', meta, i + 1)
            # non-scalar meta entries go to tensorboard only
            meta_non_scalar = {
                key: (value, log_type)
                for key, (value, log_type) in meta.items()
                if not log_type == LogType.SCALAR
            }
            try:
                self.tensorboard_log('valid', meta_non_scalar, step)
            except OverflowError:
                # best-effort logging: presumably guards oversized payloads —
                # TODO(review) confirm why OverflowError specifically
                pass
    # averaging stat
    loss /= self.valid_max_step
    for key in stat.keys():
        stat[key] = stat[key] / self.valid_max_step
    # update best valid loss
    if loss < self.best_valid_loss:
        self.best_valid_loss = loss
    # console logging of total stat
    msg = 'step {} / total stat'.format(step)
    for key, value in sorted(stat.items()):
        msg += '\t{}: {:.6f}'.format(key, value)
    log(msg)
    # tensor board logging of scalar stat
    for key, value in stat.items():
        self.writer.add_scalar('valid/{}'.format(key), value, global_step=step)
def validate(self, step: int):
    """Run a full validation pass, then log averaged scalar statistics.

    Args:
        step: global training step, used as the tensorboard x-axis and in
            the summary console message.
    """
    loss = 0.
    # fix: removed dead local `count = 0` (never read or updated)
    stat = defaultdict(float)
    for i in range(self.valid_max_step):
        # forward model; gradients are not needed during validation
        with torch.no_grad():
            batch_loss, meta = self.forward(*to_device(
                next(self.valid_dataset)), is_logging=True)
        loss += batch_loss
        # accumulate scalar-typed meta entries for averaging after the loop
        for key, (value, log_type) in meta.items():
            if log_type == LogType.SCALAR:
                stat[key] += value
        # periodic console logging, plus the final step
        if i % self.log_interval == 0 or i == self.valid_max_step - 1:
            self.console_log('valid', meta, i + 1)
    # averaging stat; 'loss' is replaced below by the batch_loss average
    loss /= self.valid_max_step
    for key in stat.keys():
        if key == 'loss':
            continue
        stat[key] = stat[key] / self.valid_max_step
    stat['loss'] = loss
    # update best valid loss
    if loss < self.best_valid_loss:
        self.best_valid_loss = loss
    # console logging of total stat
    msg = 'step {} / total stat'.format(step)
    for key, value in sorted(stat.items()):
        msg += '\t{}: {:.6f}'.format(key, value)
    log(msg)
    # tensor board logging of scalar stat
    for key, value in stat.items():
        self.writer.add_scalar('valid/{}'.format(key), value, global_step=step)
def repeat(iterable):
    """Cycle over *iterable* forever, yielding each element of each group
    moved to the CUDA device via ``to_device``."""
    while True:
        for group in iterable:
            yield from (to_device(item, 'cuda') for item in group)