def _process_log(self, src_dict, dest_dict):
    """Merge one step's log dict into the running log: numeric values are
    accumulated in an AverageValueMeter, everything else is overwritten."""
    for k, v in src_dict.items():
        if isinstance(v, (int, float)):
            dest_dict.setdefault(k, meter_utils.AverageValueMeter())
            dest_dict[k].add(float(v))
        else:
            dest_dict[k] = v
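# _process_log assumes meter_utils.AverageValueMeter exposes add(value) and
# value() -> (mean, std). meter_utils is not shown in this file, so the class
# below is only a sketch of that assumed contract, not the real implementation.
class _AverageValueMeterSketch:
    """Hypothetical stand-in mirroring the interface _process_log relies on."""

    def __init__(self):
        self.n, self.sum, self.sum_sq = 0, 0.0, 0.0

    def add(self, value):
        # accumulate count, sum and sum of squares for a streaming mean/std
        self.n += 1
        self.sum += value
        self.sum_sq += value * value

    def value(self):
        # return (mean, std); (nan, nan) before any samples are added
        if self.n == 0:
            return float('nan'), float('nan')
        mean = self.sum / self.n
        var = max(self.sum_sq / self.n - mean * mean, 0.0)
        return mean, var ** 0.5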
def val(self):
    self.model.eval()
    logs = OrderedDict()
    sum_loss = meter_utils.AverageValueMeter()
    logger.info('Val on validation set...')

    self.batch_timer.clear()
    self.data_timer.clear()
    self.batch_timer.tic()
    self.data_timer.tic()
    for step, batch in enumerate(self.val_data):
        self.data_timer.toc()

        inputs, gts, _ = self.batch_processor(self, batch)
        _, saved_for_loss = self.model(*inputs)
        self.batch_timer.toc()

        loss, saved_for_log = self.model.module.build_loss(saved_for_loss, *gts)
        sum_loss.add(loss.item())
        self._process_log(saved_for_log, logs)

        if step % self.params.print_freq == 0:
            self._print_log(step, logs, 'Validation',
                            max_n_batch=len(self.val_data))

        self.data_timer.tic()
        self.batch_timer.tic()

    mean, std = sum_loss.value()
    logger.info('\n\nValidation loss: mean: {}, std: {}'.format(mean, std))
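# A hedged aside: val() above keeps autograd enabled during its forward
# passes. A common memory-saving variant (not part of the original trainer)
# evaluates under torch.no_grad(); the helper below sketches that pattern
# against a generic model/loader/loss triple, with illustrative names only.
def _eval_mean_loss_no_grad(model, data_loader, build_loss):
    import torch
    model.eval()
    losses = []
    with torch.no_grad():  # no graph is built, so activations are freed eagerly
        for batch_inputs, batch_gts in data_loader:
            _, saved_for_loss = model(*batch_inputs)
            loss, _ = build_loss(saved_for_loss, *batch_gts)
            losses.append(loss.item())
    return sum(losses) / max(len(losses), 1)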
def _train_one_epoch(self):
    self.batch_timer.clear()
    self.data_timer.clear()
    self.batch_timer.tic()
    self.data_timer.tic()

    total_loss = meter_utils.AverageValueMeter()
    for step, batch in enumerate(self.train_data):
        inputs, gts, _ = self.batch_processor(self, batch)
        self.data_timer.toc()

        # forward
        output, saved_for_loss = self.model(*inputs)
        loss, saved_for_log = self.model.module.build_loss(
            saved_for_loss, *gts)

        # backward
        self.optimizer.zero_grad()
        loss.backward()
        total_loss.add(loss.item())

        # clip grad by the infinity norm; clip_grad_norm_ returns the
        # total norm of the gradients measured before clipping
        if not np.isinf(self.params.max_grad_norm):
            max_norm = nn.utils.clip_grad_norm_(self.model.parameters(),
                                                self.params.max_grad_norm,
                                                norm_type=float('inf'))
            saved_for_log['max_grad'] = max_norm

        self.optimizer.step()

        self._process_log(saved_for_log, self.log_values)
        self.batch_timer.toc()

        # print log
        reset = False
        if step % self.params.print_freq == 0:
            self._print_log(step, self.log_values, title='Training',
                            max_n_batch=self.batch_per_epoch)
            reset = True

        if step % self.params.save_freq_step == 0 and step > 0:
            save_to = os.path.join(
                self.params.save_dir,
                'ckpt_{}.h5.ckpt'.format(
                    (self.last_epoch - 1) * self.batch_per_epoch + step))
            self._save_ckpt(save_to)

        if reset:
            self._reset_log(self.log_values)

        self.data_timer.tic()
        self.batch_timer.tic()

    mean_loss, std = total_loss.value()
    return mean_loss
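# _train_one_epoch clips with norm_type=inf, i.e. gradients are rescaled when
# the largest absolute gradient entry exceeds max_grad_norm, and the value
# logged as 'max_grad' is the pre-clipping infinity norm. The demo below is a
# self-contained illustration of that behavior; it is not part of the trainer.
def _demo_inf_norm_clipping():
    import torch
    import torch.nn as nn
    layer = nn.Linear(4, 1)
    layer(torch.randn(2, 4)).sum().backward()
    # clip_grad_norm_ returns the total norm measured before clipping
    pre_clip_max = nn.utils.clip_grad_norm_(layer.parameters(), 1.0,
                                            norm_type=float('inf'))
    return float(pre_clip_max)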
def _val_one_epoch(self, n_batch):
    # remember the current mode so it can be restored after validation
    training_mode = self.model.training
    self.model.eval()

    logs = OrderedDict()
    sum_loss = meter_utils.AverageValueMeter()
    logger.info('Val on validation set...')

    self.batch_timer.clear()
    self.data_timer.clear()
    self.batch_timer.tic()
    self.data_timer.tic()
    for step, batch in enumerate(self.val_data):
        self.data_timer.toc()
        if step > n_batch:
            break

        inputs, gts, _ = self.batch_processor(self, batch)
        _, saved_for_loss = self.model(*inputs)
        self.batch_timer.toc()

        loss, saved_for_log = self.model.module.build_loss(
            saved_for_loss, *gts)
        sum_loss.add(loss.item())
        self._process_log(saved_for_log, logs)

        if step % self.params.print_freq == 0 or step == len(self.val_data) - 1:
            self._print_log(step, logs, 'Validation',
                            max_n_batch=min(n_batch, len(self.val_data)))

        self.data_timer.tic()
        self.batch_timer.tic()

    mean, std = sum_loss.value()
    logger.info('Validation loss: mean: {}, std: {}'.format(mean, std))

    # restore the previous train/eval mode
    self.model.train(mode=training_mode)
    if self.params.subnet_name != 'keypoint_subnet':
        self.model.module.freeze_bn()
    return mean
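# freeze_bn is defined on the wrapped model and its source is not shown here.
# The sketch below illustrates the usual meaning of "freezing" BatchNorm
# (layers stay in eval mode so running statistics stop updating); it is an
# assumption about the method's behavior, not its actual implementation.
def _freeze_bn_sketch(model):
    import torch.nn as nn
    for module in model.modules():
        if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)):
            module.eval()  # keep using stored running mean/var during training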