class Trainer(BaseTrainer): r""" Trainer for person attribute recognition """ def __init__(self, config): super(Trainer, self).__init__(config) # Datamanager self.datamanager, params_data = build_datamanager( config['type'], config['data']) # model self.model, params_model = build_model( config, num_classes=len(self.datamanager.datasource.get_attribute()), device=self.device) # losses pos_ratio = torch.tensor( self.datamanager.datasource.get_weight('train')) self.criterion, params_loss = build_losses( config, pos_ratio=pos_ratio, num_attribute=len(self.datamanager.datasource.get_attribute())) # optimizer self.optimizer, params_optimizers = build_optimizers( config, self.model) # learing rate scheduler self.lr_scheduler, params_lr_scheduler = build_lr_scheduler( config, self.optimizer) # callbacks for freeze backbone if config['freeze']['enable']: self.freeze = FreezeLayers(self.model, config['freeze']['layers'], config['freeze']['epochs']) else: self.freeze = None # list of metrics self.lst_metrics = ['mA', 'accuracy', 'f1_score'] # track metric self.train_metrics = MetricTracker('loss', *self.lst_metrics) self.valid_metrics = MetricTracker('loss', *self.lst_metrics) # step log loss and accuracy self.log_step = (len(self.datamanager.get_dataloader('train')) // 5, len(self.datamanager.get_dataloader('val')) // 5) self.log_step = (self.log_step[0] if self.log_step[0] > 0 else 1, self.log_step[1] if self.log_step[1] > 0 else 1) # best accuracy and loss self.best_loss = None self.best_metrics = dict() for x in self.lst_metrics: self.best_metrics[x] = None # print config self._print_config( params_data=params_data, params_model=params_model, params_loss=params_loss, params_optimizers=params_optimizers, params_lr_scheduler=params_lr_scheduler, freeze_layers=False if self.freeze == None else True, clip_grad_norm_=self.config['clip_grad_norm_']['enable']) # send model to device self.model.to(self.device) self.criterion.to(self.device) # summary model summary(model=self.model, input_data=torch.zeros((self.datamanager.get_batch_size(), 3, self.datamanager.get_image_size()[0], self.datamanager.get_image_size()[1])), batch_dim=None, device='cuda' if self.use_gpu else 'cpu', print_func=self.logger.info, print_step=False) # resume model from last checkpoint if config['resume'] != '': self._resume_checkpoint(config['resume'], config['only_model']) def train(self): # begin train for epoch in range(self.start_epoch, self.epochs + 1): # freeze layer if self.freeze != None: self.freeze.on_epoch_begin(epoch) # train result = self._train_epoch(epoch) # valid result = self._valid_epoch(epoch) # learning rate if self.lr_scheduler is not None: if self.config['lr_scheduler']['start'] <= epoch: if isinstance(self.lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau): self.lr_scheduler.step(self.valid_metrics.avg('loss')) else: self.lr_scheduler.step() # add scalars to tensorboard self.writer.add_scalars('Loss', { 'Train': self.train_metrics.avg('loss'), 'Val': self.valid_metrics.avg('loss') }, global_step=epoch) for metric in self.lst_metrics: self.writer.add_scalars(metric, { 'Train': self.train_metrics.avg(metric), 'Val': self.valid_metrics.avg(metric) }, global_step=epoch) self.writer.add_scalar('lr', self.optimizer.param_groups[-1]['lr'], global_step=epoch) # logging result to console log = {'epoch': epoch} log.update(result) for key, value in log.items(): self.logger.info(' {:15s}: {}'.format(str(key), value)) # save model save_best_loss = False if self.best_loss == None or self.best_loss >= 
self.valid_metrics.avg( 'loss'): self.best_loss = self.valid_metrics.avg('loss') save_best_loss = True save_best = dict() for metric in self.lst_metrics: save_best[metric] = False if self.best_metrics[metric] == None or self.best_metrics[ metric] <= self.valid_metrics.avg(metric): self.best_metrics[metric] = self.valid_metrics.avg(metric) save_best[metric] = True self._save_checkpoint(epoch, save_best_loss, save_best) # save logs to drive if using colab if self.config['colab']: self._save_logs() # wait for tensorboard flush all metrics to file self.writer.flush() # time.sleep(1*60) self.writer.close() # save logs to drive if using colab if self.config['colab']: self._save_logs() # plot loss, accuracy and save them to plot.png in saved/logs/<run_id>/plot.png plot_loss_accuracy( dpath=self.cfg_trainer['log_dir'], list_dname=[self.run_id], path_folder=self.logs_dir_saved if self.config['colab'] == True else self.logs_dir, title=self.run_id + ': ' + self.config['model']['name'] + ", " + self.config['loss']['name'] + ", " + self.config['data']['name']) def _train_epoch(self, epoch): r""" Training step """ raise NotImplementedError def _valid_epoch(self, epoch): r""" Validation step """ raise NotImplementedError def test(self): r""" Test model after train TODO: """ logger = logging.getLogger('test') self.model.eval() preds = [] labels = [] if self.cfg_trainer['use_tqdm']: tqdm_callback = Tqdm( total=len(self.datamanager.get_dataloader('test'))) with torch.no_grad(): for batch_idx, (data, _labels) in enumerate( self.datamanager.get_dataloader('test')): if batch_idx == 5: break data, _labels = data.to(self.device), _labels.to(self.device) out = self.model(data) _preds = torch.sigmoid(out) preds.append(_preds) labels.append(_labels) if self.cfg_trainer['use_tqdm']: tqdm_callback.update() else: if (batch_idx + 1) % ( len(self.datamanager.get_dataloader('test')) // 10 + 1) or (batch_idx + 1) == len( self.datamanager.get_dataloader('test')) - 1: logger.info('Iter {}/{}'.format( batch_idx + 1, len(self.datamanager.get_dataloader('test')))) preds = torch.cat(preds, dim=0) labels = torch.cat(labels, dim=0) preds = preds.cpu().numpy() labels = labels.cpu().numpy() result_label, result_instance = recognition_metrics(labels, preds) log_test(logger.info, self.datamanager.datasource.get_attribute(), self.datamanager.datasource.get_weight('test'), result_label, result_instance) def _save_checkpoint(self, epoch, save_best_loss, save_best_metrics): r""" Save model to file """ state = { 'epoch': epoch, 'state_dict': self.model.state_dict(), 'loss': self.criterion.state_dict(), 'optimizer': self.optimizer.state_dict(), 'lr_scheduler': self.lr_scheduler.state_dict(), 'best_loss': self.best_loss } for metric in self.lst_metrics: state.update({'best_{}'.format(metric): self.best_metrics[metric]}) filename = os.path.join(self.checkpoint_dir, 'model_last.pth') self.logger.info("Saving last model: model_last.pth ...") torch.save(state, filename) if save_best_loss: filename = os.path.join(self.checkpoint_dir, 'model_best_loss.pth') self.logger.info( "Saving current best loss: model_best_loss.pth ...") torch.save(state, filename) for metric in self.lst_metrics: if save_best_metrics[metric]: filename = os.path.join(self.checkpoint_dir, 'model_best_{}.pth'.format(metric)) self.logger.info( "Saving current best {}: model_best_{}.pth ...".format( metric, metric)) torch.save(state, filename) def _resume_checkpoint(self, resume_path, only_model=False): r""" Load model from checkpoint """ if not os.path.exists(resume_path): 
raise FileExistsError("Resume path not exist!") self.logger.info("Loading checkpoint: {} ...".format(resume_path)) checkpoint = torch.load(resume_path, map_location=self.map_location) self.model.load_state_dict(checkpoint['state_dict']) if only_model: self.logger.info("Pretrained-model loaded!") return self.start_epoch = checkpoint['epoch'] + 1 self.criterion.load_state_dict(checkpoint['loss']) self.optimizer.load_state_dict(checkpoint['optimizer']) self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler']) self.best_loss = checkpoint['best_loss'] for metric in self.lst_metrics: self.best_metrics[metric] = checkpoint['best_{}'.format(metric)] self.logger.info( "Checkpoint loaded. Resume training from epoch {}".format( self.start_epoch)) def _print_config(self, params_data=None, params_model=None, params_loss=None, params_optimizers=None, params_lr_scheduler=None, freeze_layers=False, clip_grad_norm_=False): r""" print config into log file """ def __prams_to_str(params: dict): if params == None: return '' row_format = "{:>4}, " * len(params) return row_format.format( *[key + ': ' + str(value) for key, value in params.items()]) self.logger.info('Run id: %s' % (self.run_id)) self.logger.info('Data: ' + __prams_to_str(params_data)) self.logger.info('Model: %s ' % (self.config['model']['name']) + __prams_to_str(params_model)) if freeze_layers: self.logger.info('Freeze layer: %s, at first epoch %d' % (str(self.config['freeze']['layers']), self.config['freeze']['epochs'])) self.logger.info('Loss: %s ' % (self.config['loss']['name']) + __prams_to_str(params_loss)) self.logger.info('Optimizer: %s ' % (self.config['optimizer']['name']) + __prams_to_str(params_optimizers)) if params_lr_scheduler != None: self.logger.info('Lr scheduler: %s ' % (self.config['lr_scheduler']['name']) + __prams_to_str(params_lr_scheduler)) if clip_grad_norm_: self.logger.info('clip_grad_norm_, max_norm: %f' % self.config['clip_grad_norm_']['max_norm'])
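# For reference: the trainers in this file only rely on a small MetricTracker
# surface (reset / update / avg / result). The sketch below is the familiar
# pytorch-template implementation; it is an assumption for illustration, not
# this repo's verified source. Note that some call sites below pass the writer
# positionally, so the exact signature may differ in this codebase.
import pandas as pd

class MetricTracker:
    def __init__(self, *keys, writer=None):
        self.writer = writer
        # one row per tracked key; running totals give streaming averages
        self._data = pd.DataFrame(index=keys,
                                  columns=['total', 'counts', 'average'])
        self.reset()

    def reset(self):
        for col in self._data.columns:
            self._data[col].values[:] = 0

    def update(self, key, value, n=1):
        if self.writer is not None:
            self.writer.add_scalar(key, value)
        self._data.total[key] += value * n
        self._data.counts[key] += n
        self._data.average[key] = self._data.total[key] / self._data.counts[key]

    def avg(self, key):
        return self._data.average[key]

    def result(self):
        return dict(self._data.average)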
class SegmentationTrainer(BaseTrainer):

    def __init__(self, model, criterion, metrics, optimizer, config, lr_scheduler=None):
        super().__init__(model, criterion, metrics, optimizer, config)
        self.lr_scheduler = lr_scheduler
        self.loss_name = 'supervised_loss'

        # Metric trackers
        # Train
        self.train_loss = MetricTracker(self.loss_name, self.writer)
        self.train_metrics = MetricTracker(*self.metric_names, self.writer)
        # Validation
        self.valid_loss = MetricTracker(self.loss_name, self.writer)
        self.valid_metrics = MetricTracker(*self.metric_names, self.writer)
        # Test
        self.test_loss = MetricTracker(self.loss_name, self.writer)
        self.test_metrics = MetricTracker(*self.metric_names, self.writer)

        if isinstance(self.model, nn.DataParallel):
            self.criterion = nn.DataParallel(self.criterion)

        # Resume checkpoint if a path is available in the config
        cp_path = self.config['trainer'].get('resume_path')
        if cp_path:
            super()._resume_checkpoint(cp_path)

    def reset_scheduler(self):
        self.train_loss.reset()
        self.train_metrics.reset()
        self.valid_loss.reset()
        self.valid_metrics.reset()
        self.test_loss.reset()
        self.test_metrics.reset()
        # if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
        #     self.lr_scheduler.reset()

    def prepare_train_epoch(self, epoch):
        self.logger.info('EPOCH: {}'.format(epoch))
        self.reset_scheduler()

    def _train_epoch(self, epoch):
        self.model.train()
        self.prepare_train_epoch(epoch)

        for batch_idx, (data, target, image_name) in enumerate(self.train_data_loader):
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            loss = self.criterion(output, target)

            # For debugging: dump a checkpoint if the loss becomes NaN
            if torch.isnan(loss):
                super()._save_checkpoint(epoch)

            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update train loss and metrics
            self.train_loss.update(self.loss_name, loss.item())
            for metric in self.metrics:
                self.train_metrics.update(metric.__name__,
                                          metric(output, target),
                                          n=output.shape[0])

            if batch_idx % self.log_step == 0:
                self.log_for_step(epoch, batch_idx)

            if self.save_for_track and (batch_idx % self.save_for_track == 0):
                save_output(output, image_name, epoch, self.checkpoint_dir)

            if batch_idx == self.len_epoch:
                break

        log = self.train_loss.result()
        log.update(self.train_metrics.result())

        if self.do_validation and (epoch % self.do_validation_interval == 0):
            val_log = self._valid_epoch(epoch)
            log.update(val_log)

        # step the lr scheduler on the validation loss
        if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
            self.lr_scheduler.step(self.valid_loss.avg(self.loss_name))

        return log

    @staticmethod
    def get_metric_message(metrics, metric_names):
        metrics_avg = [metrics.avg(name) for name in metric_names]
        message_metrics = ', '.join(
            ['{}: {:.6f}'.format(x, y) for x, y in zip(metric_names, metrics_avg)])
        return message_metrics

    def log_for_step(self, epoch, batch_idx):
        message_loss = 'Train Epoch: {} [{}]/[{}] Dice Loss: {:.6f}'.format(
            epoch, batch_idx, self.len_epoch, self.train_loss.avg(self.loss_name))
        message_metrics = SegmentationTrainer.get_metric_message(
            self.train_metrics, self.metric_names)
        self.logger.info(message_loss)
        self.logger.info(message_metrics)

    def _valid_epoch(self, epoch, save_result=False, save_for_visual=False):
        self.model.eval()
        self.valid_loss.reset()
        self.valid_metrics.reset()
        self.logger.info('Validation: ')
        with torch.no_grad():
            for batch_idx, (data, target, image_name) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_loss.update(self.loss_name, loss.item())
                for metric in self.metrics:
                    self.valid_metrics.update(metric.__name__,
                                              metric(output, target),
                                              n=output.shape[0])
                if save_result:
                    save_output(output, image_name, epoch,
                                os.path.join(self.checkpoint_dir, 'tracker'),
                                percent=1)
                if save_for_visual:
                    save_mask2image(output, image_name,
                                    os.path.join(self.checkpoint_dir, 'output'))
                    save_mask2image(target, image_name,
                                    os.path.join(self.checkpoint_dir, 'target'))
                if batch_idx % self.log_step == 0:
                    self.logger.debug('{}/{}'.format(batch_idx, len(self.valid_data_loader)))
                    self.logger.debug('{}: {}'.format(
                        self.loss_name, self.valid_loss.avg(self.loss_name)))
                    self.logger.debug(SegmentationTrainer.get_metric_message(
                        self.valid_metrics, self.metric_names))

        log = self.valid_loss.result()
        log.update(self.valid_metrics.result())
        val_log = {'val_{}'.format(k): v for k, v in log.items()}
        return val_log

    def _test_epoch(self, epoch, save_result=False, save_for_visual=False):
        self.model.eval()
        self.test_loss.reset()
        self.test_metrics.reset()
        self.logger.info('Test: ')
        with torch.no_grad():
            for batch_idx, (data, target, image_name) in enumerate(self.test_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
                self.test_loss.update(self.loss_name, loss.item())
                for metric in self.metrics:
                    self.test_metrics.update(metric.__name__,
                                             metric(output, target),
                                             n=output.shape[0])
                if save_result:
                    save_output(output, image_name, epoch,
                                os.path.join(self.checkpoint_dir, 'tracker'),
                                percent=1)
                if save_for_visual:
                    save_mask2image(output, image_name,
                                    os.path.join(self.checkpoint_dir, 'output'))
                    save_mask2image(target, image_name,
                                    os.path.join(self.checkpoint_dir, 'target'))
                if batch_idx % self.log_step == 0:
                    self.logger.debug('{}/{}'.format(batch_idx, len(self.test_data_loader)))
                    self.logger.debug('{}: {}'.format(
                        self.loss_name, self.test_loss.avg(self.loss_name)))
                    self.logger.debug(SegmentationTrainer.get_metric_message(
                        self.test_metrics, self.metric_names))

        log = self.test_loss.result()
        log.update(self.test_metrics.result())
        test_log = {'test_{}'.format(k): v for k, v in log.items()}
        return test_log
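# The trainers in this file assume two small helpers. First, a plateau
# scheduler with an explicit reset() (see the commented-out reset above and
# LayerwiseTrainer.reset_scheduler below). A minimal sketch, assuming
# PyTorch's private ReduceLROnPlateau._reset() helper, which clears the best
# value, cooldown and bad-epoch counters:
from torch.optim.lr_scheduler import ReduceLROnPlateau

class MyReduceLROnPlateau(ReduceLROnPlateau):
    def reset(self):
        # called whenever training is effectively restarted (e.g. after
        # unfreezing new layers) so stale plateau statistics are discarded
        self._reset()

# Second, iteration-based training uses inf_loop; the usual pytorch-template
# helper simply re-iterates the DataLoader forever:
from itertools import repeat

def inf_loop(data_loader):
    """Wrap a DataLoader so iteration never ends."""
    for loader in repeat(data_loader):
        yield from loader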
class LayerwiseTrainer(BaseTrainer):
    """Trainer that prunes and distills a student network layer by layer."""

    def __init__(self, model: DepthwiseStudent, criterions, metric_ftns, optimizer,
                 config, train_data_loader, valid_data_loader=None,
                 lr_scheduler=None, weight_scheduler=None):
        super().__init__(model, None, metric_ftns, optimizer, config)
        self.config = config
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_validation_interval = self.config['trainer']['do_validation_interval']
        self.lr_scheduler = lr_scheduler
        self.weight_scheduler = weight_scheduler
        self.log_step = config['trainer']['log_step']
        if "len_epoch" in self.config['trainer']:
            # iteration-based training
            self.train_data_loader = inf_loop(train_data_loader)
            self.len_epoch = self.config['trainer']['len_epoch']
        else:
            # epoch-based training
            self.len_epoch = len(self.train_data_loader)

        # Metric trackers
        # Train
        self.train_metrics = MetricTracker(
            'loss', 'supervised_loss', 'kd_loss', 'hint_loss', 'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.train_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        self.train_teacher_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        # Valid
        self.valid_metrics = MetricTracker(
            'loss', 'supervised_loss', 'kd_loss', 'hint_loss', 'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        # Test
        self.test_metrics = MetricTracker(
            'loss', 'supervised_loss', 'kd_loss', 'hint_loss', 'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            *['teacher_' + m.__name__ for m in self.metric_ftns],
            writer=self.writer,
        )
        self.test_iou_metrics = CityscapesMetricTracker(writer=self.writer)

        # Tracker for early stopping when the validation mIoU stops increasing
        self.val_iou_tracker = EarlyStopTracker('best', 'max', 0.01, 'rel')

        # Use only the list of criterions; remove the unused single criterion
        self.criterions = criterions
        self.criterions = nn.ModuleList(self.criterions).to(self.device)
        if isinstance(self.model, nn.DataParallel):
            self.criterions = nn.DataParallel(self.criterions)
        del self.criterion

        # Resume checkpoint if a path is available in the config
        if 'resume_path' in self.config['trainer']:
            self.resume(self.config['trainer']['resume_path'])

    def prepare_train_epoch(self, epoch, config=None):
        """
        Prepare before training an epoch, i.e. prune new layers, unfreeze some
        layers, create a new optimizer, ...

        :param epoch: int, the epoch the trainer is in
        :param config: a config object containing pruning_plan, hint and unfreeze information
        :return:
        """
        # If config is not set (normal training), use the trainer's current
        # config. If it is set (when resuming a checkpoint), use the saved
        # config to replace layers in the student so that its architecture is
        # identical to the saved checkpoint.
        if config is None:
            config = self.config

        # reset schedulers and metric trackers
        self.reset_scheduler()

        # If no layer will be replaced, unfrozen or set as a hint, unfreeze
        # the whole network and train it as-is.
        if (epoch == 1) and ((len(config['pruning']['pruning_plan']) +
                              len(config['pruning']['hint']) +
                              len(config['pruning']['unfreeze'])) == 0):
            self.logger.debug('Train a student with an architecture identical to the teacher')
            # unfreeze
            for param in self.model.student.parameters():
                param.requires_grad = True
            # debug
            self.logger.info(self.model.dump_trainable_params())
            # create an optimizer for the network
            self.create_new_optimizer()
            # skip everything below
            return

        # Check whether any layer receives an update in the current epoch:
        # gather the list of epochs that update the student network.
        epochs = list(map(lambda x: x['epoch'],
                          config['pruning']['pruning_plan'] +
                          config['pruning']['hint'] +
                          config['pruning']['unfreeze']))
        # if there is no update, simply return
        if epoch not in epochs:
            self.logger.info('EPOCH: ' + str(epoch))
            self.logger.info('There is no update ...')
            return

        # layers that will be replaced by depthwise separable convolutions
        replaced_layers = list(filter(lambda x: x['epoch'] == epoch,
                                      config['pruning']['pruning_plan']))
        # layers whose outputs will be used for the hint loss
        hint_layers = list(map(lambda x: x['name'],
                               filter(lambda x: x['epoch'] == epoch,
                                      config['pruning']['hint'])))
        # layers that will be trained in this epoch
        unfreeze_layers = list(map(lambda x: x['name'],
                                   filter(lambda x: x['epoch'] == epoch,
                                          config['pruning']['unfreeze'])))
        self.logger.info('EPOCH: ' + str(epoch))
        self.logger.info('Replaced layers: ' + str(replaced_layers))
        self.logger.info('Hint layers: ' + str(hint_layers))
        self.logger.info('Unfreeze layers: ' + str(unfreeze_layers))

        # Avoid errors when loading deprecated checkpoints that have no 'args'
        # in config['pruning'].
        if 'args' in config['pruning']:
            kwargs = config['pruning']['args']
        else:
            self.logger.warning('Using deprecated checkpoint...')
            kwargs = config['pruning']['pruner']

        # replace those layers with depthwise separable convolutions
        self.model.replace(replaced_layers, **kwargs)
        # register which layers' outputs will be used for the hint loss
        self.model.register_hint_layers(hint_layers)
        # unfreeze the chosen layers
        self.model.unfreeze(unfreeze_layers)

        if epoch == 1:
            # create a new optimizer to remove the effect of momentum
            self.create_new_optimizer()
        else:
            self.update_optimizer(list(filter(lambda x: x['epoch'] == epoch,
                                              config['pruning']['unfreeze'])))

        self.logger.info(self.model.dump_trainable_params())
        self.logger.info(self.model.dump_student_teacher_blocks_info())

    def update_optimizer(self, unfreeze_config):
        """
        Update the optimizer's param groups with the layers unfrozen this epoch.

        :param unfreeze_config: list of dicts of the following format:
            {'name': 'layer1', 'epoch': 1, 'lr' (optional): 0.01}
        :return:
        """
        if len(unfreeze_config) > 0:
            self.logger.debug('Updating optimizer for new layers')
        for config in unfreeze_config:
            layer_name = config['name']  # layer that will be unfrozen
            self.logger.debug('Add parameters of layer: {} to optimizer'.format(layer_name))
            # the actual layer, i.e. an nn.Module object
            layer = self.model.get_block(layer_name, self.model.student)
            # default args for the optimizer
            optimizer_arg = self.config['optimizer']['args']
            # a layer-wise learning rate can also be specified
            if "lr" in config:
                optimizer_arg['lr'] = config['lr']
            # add the unfrozen layer's parameters to the optimizer
            self.optimizer.add_param_group({'params': layer.parameters(),
                                            **optimizer_arg})

    def create_new_optimizer(self):
        """Create a new optimizer (used at epoch 1; later epochs use update_optimizer)."""
        self.logger.debug('Creating new optimizer ...')
        self.optimizer = self.config.init_obj(
            'optimizer', optim_module,
            list(filter(lambda x: x.requires_grad, self.model.student.parameters())))
        self.lr_scheduler = self.config.init_obj('lr_scheduler',
                                                 optim_module.lr_scheduler,
                                                 self.optimizer)

    def reset_scheduler(self):
        """
        Reset all schedulers, metrics, trackers, etc. when unfreezing a new layer.

        :return:
        """
        self.weight_scheduler.reset()  # weight between losses
        self.val_iou_tracker.reset()  # tracks that val IoU keeps increasing
        self.train_metrics.reset()  # training-phase losses and metrics
        self.valid_metrics.reset()  # validation-phase losses and metrics
        self.train_iou_metrics.reset()  # train IoU of the student
        self.valid_iou_metrics.reset()  # val IoU of the student
        self.train_teacher_iou_metrics.reset()  # train IoU of the teacher
        if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
            self.lr_scheduler.reset()

    def _train_epoch(self, epoch):
        """Training logic for one epoch."""
        # Prepare the network, i.e. unfreeze new layers, replace layers with
        # depthwise separable convolutions, ...
        self.prepare_train_epoch(epoch)

        # FIXME: the teacher network contains batch-norm layers and our
        # resources do not allow a large batch size, so we ALWAYS keep BN in
        # training mode to avoid instability with small batches.
        # self.model.train()
        self.train_iou_metrics.reset()
        self.train_teacher_iou_metrics.reset()
        self._clean_cache()

        for batch_idx, (data, target) in enumerate(self.train_data_loader):
            data, target = data.to(self.device), target.to(self.device)
            output_st, output_tc = self.model(data)

            supervised_loss = self.criterions[0](output_st, target) / self.accumulation_steps
            kd_loss = self.criterions[1](output_st, output_tc) / self.accumulation_steps
            teacher_loss = self.criterions[0](output_tc, target)  # for comparison
            hint_loss = reduce(
                lambda acc, elem: acc + self.criterions[2](elem[0], elem[1]),
                zip(self.model.student_hidden_outputs,
                    self.model.teacher_hidden_outputs),
                0) / self.accumulation_steps

            # Only the hint loss is used for optimization
            loss = hint_loss
            loss.backward()

            if batch_idx % self.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)

            # update metrics
            self.train_metrics.update('loss', loss.item() * self.accumulation_steps)
            self.train_metrics.update('supervised_loss',
                                      supervised_loss.item() * self.accumulation_steps)
            self.train_metrics.update('kd_loss',
                                      kd_loss.item() * self.accumulation_steps)
            self.train_metrics.update('hint_loss',
                                      hint_loss.item() * self.accumulation_steps)
            self.train_metrics.update('teacher_loss', teacher_loss.item())
            self.train_iou_metrics.update(output_st.detach().cpu(), target.cpu())
            self.train_teacher_iou_metrics.update(output_tc.cpu(), target.cpu())

            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output_st, target))

            if batch_idx % self.log_step == 0:
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                # st_masks = visualize.viz_pred_cityscapes(output_st)
                # tc_masks = visualize.viz_pred_cityscapes(output_tc)
                # self.writer.add_image('st_pred', make_grid(st_masks, nrow=8, normalize=False))
                # self.writer.add_image('tc_pred', make_grid(tc_masks, nrow=8, normalize=False))
                self.logger.info(
                    'Train Epoch: {} [{}]/[{}] Loss: {:.6f} mIoU: {:.6f} Teacher mIoU: {:.6f} '
                    'Supervised Loss: {:.6f} Knowledge Distillation loss: '
                    '{:.6f} Hint Loss: {:.6f} Teacher Loss: {:.6f}'.format(
                        epoch,
                        batch_idx,
                        self.len_epoch,
                        self.train_metrics.avg('loss'),
                        self.train_iou_metrics.get_iou(),
                        self.train_teacher_iou_metrics.get_iou(),
                        self.train_metrics.avg('supervised_loss'),
                        self.train_metrics.avg('kd_loss'),
                        self.train_metrics.avg('hint_loss'),
                        self.train_metrics.avg('teacher_loss'),
                    ))

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()
        log.update({'train_teacher_mIoU': self.train_teacher_iou_metrics.get_iou()})
        log.update({'train_student_mIoU': self.train_iou_metrics.get_iou()})

        if self.do_validation and ((epoch % self.config["trainer"]["do_validation_interval"]) == 0):
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
            log.update(**{'val_mIoU': self.valid_iou_metrics.get_iou()})
            self.val_iou_tracker.update(self.valid_iou_metrics.get_iou())

        self._teacher_student_iou_gap = (self.train_teacher_iou_metrics.get_iou()
                                         - self.train_iou_metrics.get_iou())

        # step the lr scheduler
        if (self.lr_scheduler is not None) and (not isinstance(self.lr_scheduler, MyOneCycleLR)):
            if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
                self.lr_scheduler.step(self.train_metrics.avg('loss'))
            else:
                self.lr_scheduler.step()
            self.logger.debug('stepped lr')
            for param_group in self.optimizer.param_groups:
                self.logger.debug(param_group['lr'])

        # anneal the weight between losses
        self.weight_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self._clean_cache()
        # FIXME: as in _train_epoch, BN is ALWAYS kept in training mode to
        # avoid instability with small batch sizes.
        # self.model.eval()
        self.model.save_hidden = False  # stop saving hidden outputs
        self.valid_metrics.reset()
        self.valid_iou_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model.inference(data)
                supervised_loss = self.criterions[0](output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('supervised_loss', supervised_loss.item())
                self.valid_iou_metrics.update(output.detach().cpu(), target)
                self.logger.debug(
                    str(batch_idx) + " : " + str(self.valid_iou_metrics.get_iou()))
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))

        result = self.valid_metrics.result()
        result['mIoU'] = self.valid_iou_metrics.get_iou()
        return result

    def _test_epoch(self, epoch):
        # clean up memory
        self._clean_cache()
        # self.model.eval()
        self.model.save_hidden = False
        self.model.cpu()
        self.model.student.to(self.device)

        # prepare before running the submission
        self.test_metrics.reset()
        self.test_iou_metrics.reset()
        args = self.config['test']['args']
        save_4_sm = self.config['submission']['save_output']
        path_output = self.config['submission']['path_output']
        if save_4_sm and not os.path.exists(path_output):
            os.mkdir(path_output)
        n_samples = len(self.valid_data_loader)
        with torch.no_grad():
            for batch_idx, (img_name, data, target) in enumerate(self.valid_data_loader):
                self.logger.info('{}/{}'.format(batch_idx, n_samples))
                data, target = data.to(self.device), target.to(self.device)
                output = self.model.inference_test(data, args)
                if save_4_sm:
                    self.save_for_submission(output, img_name[0])
                supervised_loss = self.criterions[0](output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx, 'test')
                self.test_metrics.update('supervised_loss', supervised_loss.item())
                self.test_iou_metrics.update(output.detach().cpu(), target)
                for met in self.metric_ftns:
                    self.test_metrics.update(met.__name__, met(output, target))

        result = self.test_metrics.result()
        result['mIoU'] = self.test_iou_metrics.get_iou()
        return result

    def save_for_submission(self, output, image_name, img_type=np.uint8):
        args = self.config['submission']
        path_output = args['path_output']
        image_save = '{}.{}'.format(image_name, args['ext'])
        path_save = os.path.join(path_output, image_save)
        result = torch.argmax(output, dim=1)
        result_mapped = self.re_map_for_submission(result)
        if output.size()[0] == 1:
            result_mapped = result_mapped[0]
        save_image(result_mapped.cpu().numpy().astype(img_type), path_save)
        print('Saved output of test data: {}'.format(image_save))

    def re_map_for_submission(self, output):
        # map training ids back to the original label ids
        mapping = self.valid_data_loader.dataset.id_to_trainid
        cp_output = torch.zeros(output.size())
        for k, v in mapping.items():
            cp_output[output == v] = k
        return cp_output

    def _clean_cache(self):
        self.model.student_hidden_outputs, self.model.teacher_hidden_outputs = list(), list()
        gc.collect()
        torch.cuda.empty_cache()

    def resume(self, checkpoint_path):
        self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path, map_location=torch.device('cpu'))
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']
        config = checkpoint['config']  # config of the checkpoint
        epoch = checkpoint['epoch']  # epoch at which training stopped

        # Load the model state from the checkpoint: first align the student
        # architecture by replaying the layer replacements of every epoch up
        # to the stopping point, then restore the weights.
        for i in range(1, epoch + 1):
            self.prepare_train_epoch(i, config)
        forgiving_state_restore(self.model, checkpoint['state_dict'])
        self.logger.info("Loaded model's state dict")

        # Load the optimizer state only when the optimizer type is unchanged.
        if checkpoint['config']['optimizer']['type'] != self.config['optimizer']['type']:
            self.logger.warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.logger.info("Loaded optimizer state dict")
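# LayerwiseTrainer._train_epoch above divides every loss term by
# accumulation_steps and only steps the optimizer every accumulation_steps
# batches. A minimal standalone sketch of that gradient-accumulation pattern
# (the names here are illustrative, not from this repo); note the trainer
# above steps on batch_idx % steps == 0, whereas the common variant uses
# (batch_idx + 1) so the first step happens after a full accumulation window:
def accumulate_train(model, criterion, optimizer, loader, accumulation_steps=4):
    model.train()
    optimizer.zero_grad()
    for batch_idx, (data, target) in enumerate(loader):
        # scale so the summed gradients match one large-batch gradient
        loss = criterion(model(data), target) / accumulation_steps
        loss.backward()  # gradients accumulate in .grad between steps
        if (batch_idx + 1) % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()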
class Trainer(BaseTrainer):
    """Trainer class."""

    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.track_loss = ['loss', 'recon', 'kld', 'lmse', 'contrast',
                           'cycle', 'cycle_mse', 'cycle_ce', 'pseudo', 'klc']
        self.train_metrics = MetricTracker(
            *self.track_loss, *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            *self.track_loss, *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

        self.pitch_map = {i: n for n, i in enumerate(data_loader.dataset.pitch_map)}
        self.dynamic_map = {i: n for n, i in enumerate(data_loader.dataset.dynamic_map)}
        self.pitchclass_map = {i: n for n, i in enumerate(data_loader.dataset.pitchclass_map)}
        self.tf_map = {v: data_loader.dataset.family_map[k]
                       for k, v in data_loader.dataset.instrument_map.items()}
        self.plot_step = 25
        self.recon_sample = np.random.choice(valid_data_loader.sampler.indices,
                                             size=10, replace=False)
        pitches = np.random.choice(82, size=len(self.recon_sample))
        self.sample_to_pitch = {k: v for k, v in zip(self.recon_sample, pitches)}
        self.spec_ext = ExtractSpectrogram(sr=SR, n_fft=NFFT, hop_length=HOP,
                                           n_mels=NMEL, mode='mel')
        self.x_max, self.x_min = 9.7666, -36.0437
        self.init_temp = self.model.temperature
        self.min_temp = self.model.min_temperature
        self.decay_rate = self.model.decay_rate
        self.pseudo_train = config['trainer']['pseudo_train']
        self.labeled = config['trainer']['labeled']
        self.labeled_sample = np.random.choice(
            data_loader.sampler.indices,
            size=int(len(data_loader.sampler.indices) * self.labeled),
            replace=False)
        self.freeze_encoder = config['trainer']['freeze_encoder']
        self.pitch_shift = config['trainer']['pitch_shift']

    def data_transform(self, x, **kwargs):

        def get_idx(at_time=0.2, pitch_shift=2):
            compensate_duration = 0.05
            # load 0.05s more after the targeted time instant
            load_duration = at_time + compensate_duration
            # desired_idx = int(at_time * SR)
            if pitch_shift != 0:
                pitch_shift = np.random.randint(-pitch_shift, pitch_shift)
                # shift = -2
            # scale must also cover pitch_shift == 0 (scale == 1)
            scale = 2. ** (pitch_shift / 12.)
            # the corresponding number of frame indices to be compensated
            idx_comp = int(compensate_duration * scale ** (-1) * SR / HOP)
            if pitch_shift < 0:
                n_sample = int(scale ** (-1) * load_duration * SR)
                # assert n_sample > desired_idx
                desired_idx = int((load_duration * SR) / HOP) - idx_comp
            if pitch_shift >= 0:
                n_sample = int(load_duration * SR)
                desired_idx = int((scale ** (-1) * n_sample) / HOP) - idx_comp
            return pitch_shift, n_sample, desired_idx

        shift, n_sample, desired_idx = get_idx(**kwargs)
        x = LoadNpArray(n_sample=n_sample)(x)
        x = PitchShift(shift=shift)(x)
        x = ToTensor()(x)
        x = self.spec_ext(x)
        x = LogCompress()(x)
        x = Clipping(clip_min=self.x_min, clip_max=self.x_max)(x)
        x = MinMaxNorm(x_min=self.x_min, x_max=self.x_max)(x)
        x = x[:, :, desired_idx]
        return x, shift

    def get_gumb_temp(self, epoch, init_temp, min_temp, decay_rate):
        temp = np.maximum(init_temp * np.exp(-decay_rate * epoch), min_temp)
        return temp

    def get_ps_label(self, yp, ps):
        y_shift = torch.from_numpy(np.array(ps)).unsqueeze(-1).to(self.device)
        y_ps = yp + y_shift
        mask_l = torch.where(y_ps >= 0, torch.ones_like(y_ps), torch.zeros_like(y_ps))
        mask_u = torch.where(y_ps <= 81, torch.ones_like(y_ps), torch.zeros_like(y_ps))
        mask = mask_l * mask_u
        y_ps *= mask
        if self.pitch_shift == 0:
            assert (y_ps == yp).sum() == len(y_ps)
        return yp, y_ps, mask.float(), torch.ones_like(yp)

    def get_pseudo_label(self, logit, supervised_idx, pitch_label, pitch_shift):
        """Create pseudo labels for pitch-shifted samples."""
        supervised = True if len(supervised_idx) > 0 else False
        # initialize masks for both the original and the pitch-shifted samples
        m, m_ps = torch.zeros_like(pitch_label).float(), torch.zeros_like(pitch_label).float()

        # --- Original samples ---
        # pseudo labels are taken from the inferred categorical distribution
        y_pseudo = torch.argmax(logit, dim=-1, keepdim=True)
        if supervised:
            supervised_idx = supervised_idx.long()
            # replace pseudo labels with supervised labels
            # NOTE: pseudo labels become the true labels if the supervised portion is 100%
            y_pseudo[supervised_idx] = pitch_label[supervised_idx]
            # only the supervised indices are un-masked for the original samples;
            # the cross-entropy induced by pseudo labels will be masked out
            m[supervised_idx] = 1

        # --- Pitch-shifted samples ---
        # exploit pseudo labels when pseudo-training is enabled
        if self.pseudo_train:
            m_ps += 1
        # un-mask supervised labels regardless
        if supervised:
            m_ps[supervised_idx] = 1
        if m_ps.gt(1).any():
            print("mask has entry larger than 1 before being multiplied with exclusion mask")
        # further mask the out-of-range pitches based on the pseudo labels
        _, y_ps_pseudo, m_ps_ext, _ = self.get_ps_label(y_pseudo, pitch_shift)
        m_ps *= m_ps_ext
        if m_ps.gt(1).any():
            print("mask has entry larger than 1 AFTER being multiplied with exclusion mask")
        return y_pseudo, y_ps_pseudo, m, m_ps

    def get_data(self, x, n_semitone=2):
        for i, x_i in enumerate(x):
            x_ps, ps = self.data_transform(x_i, at_time=0.2, pitch_shift=n_semitone)
            x_ori, _ = self.data_transform(x_i, at_time=0.2, pitch_shift=0)
            if i == 0:
                ps_cat = [ps]
                x_ps_cat = x_ps
                x_ori_cat = x_ori
            else:
                ps_cat.append(ps)
                x_ps_cat = torch.cat([x_ps_cat, x_ps])
                x_ori_cat = torch.cat([x_ori_cat, x_ori])
        return x_ori_cat, x_ps_cat, ps_cat

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains the average loss and metrics in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        torch.manual_seed(1111)
        for batch_idx, (x, idx, y) in enumerate(self.data_loader):
            supervised_idx = torch.from_numpy(
                np.array([i for i, v in enumerate(idx.numpy())
                          if v in self.labeled_sample], dtype='float')).to(self.device)
            y = torch.stack(y, dim=1).to(self.device)
            yp = y[:, 1:2]
            x1, x2, ps_cat = self.get_data(x, n_semitone=self.pitch_shift)

            self.optimizer.zero_grad()
            x1_hat, h1, mu1, logvar1, z1_t, z1_p, logits1, prob1 = self.model(x1, yp)
            if self.model.gumbel:
                y_pseudo, y_ps_pseudo, m, m_ps = self.get_pseudo_label(
                    logits1, supervised_idx, yp, ps_cat)
            else:
                y_pseudo, y_ps_pseudo, m, m_ps = self.get_ps_label(yp, ps_cat)
            x2_hat, h2, mu2, logvar2, z2_t, z2_p, logits2, prob2 = self.model(x2, y_ps_pseudo)
            # con_loss = self.nt_xent_criterion(mu1, mu2)
            dict_loss = self.criterion(self.model, self.pseudo_train, self.device,
                                       x1, x1_hat, x2, x2_hat,
                                       mu1, logvar1, z1_t, z1_p,
                                       mu2, logvar2, z2_t, z2_p,
                                       logits1=logits1, logits2=logits2,
                                       prob1=prob1, prob2=prob2, epoch=epoch,
                                       mask=m_ps.float(), mask_y=m.float(),
                                       y=y_pseudo.squeeze(-1),
                                       y_ps=y_ps_pseudo.squeeze(-1))

            # debug: report any NaN loss terms or parameters
            for k, v in dict_loss.items():
                if torch.isnan(v):
                    print(k)
            for name, p in self.model.named_parameters():
                if torch.isnan(p).any():
                    print(name)

            if dict_loss['cycle'].requires_grad:
                dict_loss['loss'].backward(retain_graph=True)
            else:
                dict_loss['loss'].backward()
            self.optimizer.step()

            pre_tim_op = copy.deepcopy(list(self.model.timbre_encoder.parameters()))
            pre_pitch_op = copy.deepcopy(list(self.model.pitch_encoder.parameters()))
            if dict_loss['cycle'].requires_grad:
                self.optimizer.zero_grad()
                dict_loss['cycle'].backward()
                if self.freeze_encoder:
                    for i, param in enumerate(self.model.timbre_encoder.parameters()):
                        param.grad[:] = 0
                    for i, param in enumerate(self.model.pitch_encoder.parameters()):
                        if param.grad is not None:
                            param.grad[:] = 0
                self.optimizer.step()

            for name, p in self.model.named_parameters():
                if torch.isnan(p).any():
                    print(name)

            if self.model.gumbel:
                temp = self.get_gumb_temp(epoch, self.init_temp, self.min_temp,
                                          self.decay_rate)
                self.model.set_temperature(temp)
            else:
                temp = 0

            for track, output in zip(self.track_loss, dict_loss):
                assert track == output
                log_val = dict_loss[track].item()
                self.train_metrics.update(track, log_val)

            if batch_idx == self.len_epoch:
                break

            if batch_idx == 0:
                idx_cat = idx
                zt_cat, zp_cat = z1_t, z1_p
                yt_cat, yp_cat, yf_cat, yc_cat, yd_cat = (
                    y[:, 0:1], y[:, 1:2], y[:, -1:], y[:, 2:3], y[:, 3:4])
                x_cat, x_hat_cat = x1, x1_hat
                mu1_cat, logvar1_cat = mu1, logvar1
                mu2_cat, logvar2_cat = mu2, logvar2
                if prob1 is not None:
                    yp_hat_cat = torch.argmax(prob1, dim=-1, keepdim=True)
                else:
                    yp_hat_cat = None
            else:
                idx_cat = torch.cat([idx_cat, idx])
                zt_cat, zp_cat = torch.cat([zt_cat, z1_t]), torch.cat([zp_cat, z1_p])
                yt_cat = torch.cat([yt_cat, y[:, 0:1]], dim=0)
                yp_cat = torch.cat([yp_cat, y[:, 1:2]], dim=0)
                yf_cat = torch.cat([yf_cat, y[:, -1:]], dim=0)
                yc_cat = torch.cat([yc_cat, y[:, 2:3]], dim=0)
                yd_cat = torch.cat([yd_cat, y[:, 3:4]], dim=0)
                mu1_cat, logvar1_cat = torch.cat([mu1_cat, mu1]), torch.cat([logvar1_cat, logvar1])
                mu2_cat, logvar2_cat = torch.cat([mu2_cat, mu2]), torch.cat([logvar2_cat, logvar2])
                x_hat_cat = torch.cat([x_hat_cat, x1_hat])
                x_cat = torch.cat([x_cat, x1])
                if prob1 is not None:
                    yp_hat_cat = torch.cat([yp_hat_cat,
                                            torch.argmax(prob1, dim=-1, keepdim=True)])
                else:
                    yp_hat_cat = None

        self.writer.set_step(epoch, 'train')
        for track, output in zip(self.track_loss, dict_loss):
            assert track == output
            self.writer.add_scalar(track, self.train_metrics.avg(track))
        for met in self.metric_ftns:
            # if met.__name__ == 'cluster_var':
            #     self.train_metrics.update(met.__name__, met(mu1_cat.cpu(), yp_cat.cpu()))
            # if met.__name__ == 'kl_gauss':
            #     self.train_metrics.update(met.__name__, met(mu1_cat, logvar1_cat, mu2_cat, logvar2_cat).item())
            if met.__name__ == 'f1' and yp_hat_cat is not None:
                self.train_metrics.update(met.__name__,
                                          met(yp_hat_cat, yp_cat, n_class=82).item())
            if met.__name__ == 'cluster_acc' and yp_hat_cat is not None:
                self.train_metrics.update(met.__name__, met(yp_hat_cat, yp_cat))
            if met.__name__ == 'nmi' and yp_hat_cat is not None:
                self.train_metrics.update(met.__name__, met(yp_hat_cat, yp_cat))
            self.writer.add_scalar(met.__name__, self.train_metrics.avg(met.__name__))

        log = self.train_metrics.result()

        if epoch % self.plot_step == 0:
            yt_cat = yt_cat.squeeze(-1).detach().cpu().numpy()
            yp_cat = yp_cat.squeeze(-1).detach().cpu().numpy()
            yf_cat = yf_cat.squeeze(-1).detach().cpu().numpy()
            yc_cat = yc_cat.squeeze(-1).detach().cpu().numpy()
            yd_cat = yd_cat.squeeze(-1).detach().cpu().numpy()
            zt_2d = TSNE(n_components=2).fit_transform(mu1_cat.cpu().data.numpy())
            fig, ax = plt.subplots(2, 4, figsize=(4 * 5, 2 * 5))

            def plot_and_color(data, ax, label_map, labels, colors=None):
                n_class = len(np.unique(labels))
                if colors is not None:
                    assert n_class == len(colors)
                else:
                    random.seed(1111)
                    colors = ['#' + ''.join([random.choice('0123456789ABCDEF')
                                             for j in range(6)])
                              for i in range(n_class)]
                assert len(label_map.items()) == n_class
                for k, v in label_map.items():
                    target_data = data[labels == v]
                    ax.scatter(target_data[:, 0], target_data[:, 1],
                               c=colors[v], label=k, alpha=0.7)

            plot_and_color(zt_2d, ax[0][0], INSTRUMENT_MAP, yt_cat, colors=INSTRUMENT_COLORS)
            plot_and_color(zt_2d, ax[0][2], self.pitch_map, yp_cat, colors=PITCH_COLORS)
            plot_and_color(zt_2d, ax[0][1], FAMILY_MAP, yf_cat, colors=None)
            plot_and_color(zt_2d, ax[0][3], self.dynamic_map, yd_cat, colors=None)
            ax[1][1].imshow(self.model.emb.weight.cpu().data.numpy().T,
                            aspect='auto', origin='lower')
        else:
            fig = None
            ax = None

        if self.do_validation:
            val_log = self._valid_epoch(epoch, fig, ax)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        log.update({"gumbel_temp": temp})
        return log

    def _valid_epoch(self, epoch, fig=None, ax=None):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        torch.manual_seed(1111)
        with torch.no_grad():
            for batch_idx, (x, idx, y) in enumerate(self.valid_data_loader):
                y = torch.stack(y, dim=1).to(self.device)
                yp = y[:, 1:2]
                x1, x2, ps_cat = self.get_data(x, n_semitone=self.pitch_shift)
                x1_hat, h1, mu1, logvar1, z1_t, z1_p, logits1, prob1 = self.model(x1, yp)
                y_pseudo, y_ps_pseudo, m, m_ps = self.get_ps_label(yp, ps_cat)
                x2_hat, h2, mu2, logvar2, z2_t, z2_p, logits2, prob2 = self.model(x2, y_ps_pseudo)
                x1_hat_swap, _ = self.model.decode(z2_t, z1_p)
                x2_hat_swap, _ = self.model.decode(z1_t, z2_p)
                dict_loss = self.criterion(self.model, self.pseudo_train, self.device,
                                           x1, x1_hat, x2, x2_hat,
                                           mu1, logvar1, z1_t, z1_p,
                                           mu2, logvar2, z2_t, z2_p,
                                           logits1=logits1, logits2=logits2,
                                           prob1=prob1, prob2=prob2, epoch=epoch,
                                           mask=m_ps.float(), mask_y=m.float(),
                                           y=y_pseudo.squeeze(-1),
                                           y_ps=y_ps_pseudo.squeeze(-1))

                for track, output in zip(self.track_loss, dict_loss):
                    assert track == output
                    log_val = dict_loss[track].item()
                    self.valid_metrics.update(track, log_val)

                if batch_idx == 0:
                    idx_cat = idx
                    zt_cat, zp_cat = z1_t, z1_p
                    yt_cat, yp_cat, yf_cat, yc_cat, yd_cat = (
                        y[:, 0:1], y[:, 1:2], y[:, -1:], y[:, 2:3], y[:, 3:4])
                    x_cat, x_hat_cat = x1, x1_hat
                    mu1_cat, logvar1_cat = mu1, logvar1
                    mu2_cat, logvar2_cat = mu2, logvar2
                    h_cat = h1
                    if prob1 is not None:
                        yp_hat_cat = torch.argmax(prob1, dim=-1, keepdim=True)
                    else:
                        yp_hat_cat = None
                else:
                    idx_cat = torch.cat([idx_cat, idx])
                    zt_cat, zp_cat = torch.cat([zt_cat, z1_t]), torch.cat([zp_cat, z1_p])
                    yt_cat = torch.cat([yt_cat, y[:, 0:1]], dim=0)
                    yp_cat = torch.cat([yp_cat, y[:, 1:2]], dim=0)
                    yf_cat = torch.cat([yf_cat, y[:, -1:]], dim=0)
                    yc_cat = torch.cat([yc_cat, y[:, 2:3]], dim=0)
                    yd_cat = torch.cat([yd_cat, y[:, 3:4]], dim=0)
                    mu1_cat, logvar1_cat = torch.cat([mu1_cat, mu1]), torch.cat([logvar1_cat, logvar1])
                    mu2_cat, logvar2_cat = torch.cat([mu2_cat, mu2]), torch.cat([logvar2_cat, logvar2])
                    x_hat_cat = torch.cat([x_hat_cat, x1_hat])
                    x_cat = torch.cat([x_cat, x1])
                    h_cat = torch.cat([h_cat, h1])
                    if prob1 is not None:
                        yp_hat_cat = torch.cat([yp_hat_cat,
                                                torch.argmax(prob1, dim=-1, keepdim=True)])
                    else:
                        yp_hat_cat = None

        self.writer.set_step(epoch, 'valid')
        for track, output in zip(self.track_loss, dict_loss):
            assert track == output
            self.writer.add_scalar(track, self.valid_metrics.avg(track))
        for met in self.metric_ftns:
            # if met.__name__ == 'cluster_var':
            #     self.valid_metrics.update(met.__name__, met(mu1_cat.cpu(), yp_cat.cpu()))
            # if met.__name__ == 'kl_gauss':
            #     self.valid_metrics.update(met.__name__, met(mu1_cat, logvar1_cat, mu2_cat, logvar2_cat).item())
            if met.__name__ == 'f1' and yp_hat_cat is not None:
                self.valid_metrics.update(met.__name__,
                                          met(yp_hat_cat, yp_cat, n_class=82).item())
            if met.__name__ == 'cluster_acc' and yp_hat_cat is not None:
                self.valid_metrics.update(met.__name__, met(yp_hat_cat, yp_cat))
            if met.__name__ == 'nmi' and yp_hat_cat is not None:
                self.valid_metrics.update(met.__name__, met(yp_hat_cat, yp_cat))
            self.writer.add_scalar(met.__name__, self.valid_metrics.avg(met.__name__))

        # add histograms of model parameters to TensorBoard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        if fig is not None:
            idx_cat = idx_cat.squeeze(-1).cpu().data.numpy()
            target_idx = np.array([np.where(idx_cat == i)[0] for i in self.recon_sample])
            non_empty_idx = np.vstack([(n, i) for n, i in enumerate(target_idx) if len(i) == 1])
            target_idx = np.vstack([i for i in target_idx if len(i) == 1])[:, 0]
            self.recon_sample = [self.recon_sample[i[0]] for i in non_empty_idx]
            # target_idx = np.array([np.where(idx_cat == i)[0] for i in self.recon_sample])[:, 0]
            target_pitch = np.array([self.sample_to_pitch[i] for i in self.recon_sample])
            origin = x_cat.cpu().data.numpy()[target_idx]
            output = x_hat_cat.cpu().data.numpy()[target_idx]
            h_cat = h_cat.cpu().data.numpy()[target_idx]
            zt_cat = zt_cat[target_idx]
            zp_target = self.model.emb.weight[target_pitch]
            if self.model.use_hp:
                zp_target = self.model.project_harmonic(zp_target)
            output_pswap = self.model.decode(zt_cat, zp_target)[0]
            output_pswap = output_pswap.cpu().data.numpy()

            # stack (input, reconstruction) pairs side by side for plotting
            for m, (i, j, k, l) in enumerate(zip(origin, output, h_cat, output_pswap)):
                tmp = np.vstack([i, j])
                tmp_swap = np.vstack([i, l])
                if self.model.decoding == 'sf':
                    tmp_h = np.vstack([i, k])
                if m == 0:
                    pair = tmp
                    pair_swap = tmp_swap
                    if self.model.decoding == 'sf':
                        pair_h = tmp_h
                else:
                    pair = np.vstack([pair, tmp])
                    pair_swap = np.vstack([pair_swap, tmp_swap])
                    if self.model.decoding == 'sf':
                        pair_h = np.vstack([pair_h, tmp_h])

            ax[1][2].imshow(pair.T, aspect='auto', origin='lower', vmin=0, vmax=1)
            for l in range(1, 2 * len(self.recon_sample), 2):
                ax[1][2].axvline(x=l + 0.5, lw=1.5, c='r')
            ax[1][3].imshow(pair_swap.T, aspect='auto', origin='lower', vmin=0, vmax=1)
            for l in range(1, 2 * len(self.recon_sample), 2):
                ax[1][3].axvline(x=l + 0.5, lw=1.5, c='r')
            if self.model.decoding == 'sf':
                ax[1][0].imshow(pair_h.T, aspect='auto', origin='lower', vmin=0, vmax=1)
                for l in range(1, 2 * len(self.recon_sample), 2):
                    ax[1][0].axvline(x=l + 0.5, lw=1.5, c='r')

            self.writer.set_step(epoch, 'train')
            self.writer.add_figure('tsne', fig)

        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
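# get_gumb_temp above implements the standard exponential anneal of the
# Gumbel-softmax temperature: temp(epoch) = max(init_temp * exp(-decay_rate *
# epoch), min_temp). A quick worked example with illustrative settings (these
# values are for demonstration and are not taken from this repo's config):
import numpy as np

init_temp, decay_rate, min_temp = 1.0, 0.03, 0.5
for epoch in (0, 10, 20, 50):
    temp = np.maximum(init_temp * np.exp(-decay_rate * epoch), min_temp)
    print(epoch, round(float(temp), 2))  # 0 -> 1.0, 10 -> 0.74, 20 -> 0.55, 50 -> 0.5 (floor)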
class ClassifierTrainer(BaseTrainer):
    """Trainer class."""

    def __init__(self, model, criterion, metric_ftns, optimizer, config,
                 data_loader, valid_data_loader=None, lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains the average loss and metrics in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        torch.manual_seed(1111)
        for batch_idx, (x, idx, gt) in enumerate(self.data_loader):
            x = x.squeeze(1)
            gt = torch.stack(gt, dim=1).to(self.device)
            self.optimizer.zero_grad()
            if self.model.target == 'instrument':
                y = gt[:, 0]
            elif self.model.target == 'pitch':
                y = gt[:, 1]
            else:
                raise ValueError('Unknown target: {}'.format(self.model.target))
            output = self.model(x)
            loss = self.criterion(output, y)
            loss.backward()
            self.optimizer.step()

            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, y))

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()
        self.writer.set_step(epoch, 'train')
        self.writer.add_scalar('loss', self.train_metrics.avg('loss'))
        for met in self.metric_ftns:
            self.writer.add_scalar(met.__name__, self.train_metrics.avg(met.__name__))

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        torch.manual_seed(1111)
        with torch.no_grad():
            for batch_idx, (x, idx, gt) in enumerate(self.valid_data_loader):
                x = x.squeeze(1)
                gt = torch.stack(gt, dim=1).to(self.device)
                if self.model.target == 'instrument':
                    y = gt[:, 0]
                elif self.model.target == 'pitch':
                    y = gt[:, 1]
                else:
                    raise ValueError('Unknown target: {}'.format(self.model.target))
                output = self.model(x)
                loss = self.criterion(output, y)
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, y))

        self.writer.set_step(epoch, 'valid')
        self.writer.add_scalar('loss', self.valid_metrics.avg('loss'))
        for met in self.metric_ftns:
            self.writer.add_scalar(met.__name__, self.valid_metrics.avg(met.__name__))

        # add histograms of model parameters to TensorBoard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
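# The metric_ftns consumed by ClassifierTrainer are plain callables with the
# signature met(output, target) returning a float. A typical accuracy metric
# compatible with the calls above (a sketch for illustration, not this repo's
# verified code):
import torch

def accuracy(output, target):
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)  # predicted class per sample
        return torch.sum(pred == target).item() / len(target)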
class BaseTrainer:
    """Base class for all trainers."""

    def __init__(self, model, criterion, metric_ftns, optimizer, lr_scheduler,
                 config, trainloader, validloader=None, len_epoch=None):
        self.config = config
        self.logger = config.get_logger('trainer', config['trainer']['verbosity'])
        self.trainloader = trainloader
        self.validloader = validloader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.trainloader)
        else:
            # iteration-based training
            self.trainloader = inf_loop(trainloader)
            self.len_epoch = len_epoch

        # setup the GPU device if available; move the model onto the configured device
        n_gpu_use = torch.cuda.device_count()
        self.device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        self.model = model.to(self.device)
        self.model = torch.nn.DataParallel(model)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.log_step = cfg_trainer['log_step']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        self.start_epoch = 1
        self.checkpoint_dir = config.save_dir

        # configuration to monitor model performance and save the best checkpoint
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        # setup the visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])

        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_step(self, batch):
        """
        Training logic for a step

        :param batch: batch of the current step
        :return:
            loss: scalar loss tensor attached to the graph, used for backward()
            mets: dict of metrics computed between output and target
        """
        raise NotImplementedError

    @abstractmethod
    def _valid_step(self, batch):
        """
        Validation logic for a step

        :param batch: batch of the current step
        :return:
            loss: scalar loss tensor detached from the graph
            mets: dict of metrics computed between output and target
        """
        raise NotImplementedError

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains the average loss and metrics in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        tic = time.time()
        datatime = batchtime = 0
        for batch_idx, batch in enumerate(self.trainloader):
            datatime += time.time() - tic
            # -----------------------------------------------------------------
            loss, mets = self._train_step(batch)
            # -----------------------------------------------------------------
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            batchtime += time.time() - tic
            tic = time.time()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for key, val in mets.items():
                self.train_metrics.update(key, val)

            if batch_idx % self.log_step == 0:
                processed_percent = batch_idx / self.len_epoch * 100
                self.logger.debug(
                    'Train Epoch:{} [{}/{}]({:.0f}%)\tTime:{:5.2f}/{:<5.2f}\tLoss:({:.4f}){:.4f}'
                    .format(epoch, batch_idx, self.len_epoch, processed_percent,
                            datatime, batchtime, loss.item(),
                            self.train_metrics.avg('loss')))
                datatime = batchtime = 0

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()
        log = {'train_' + k: v for k, v in log.items()}

        if self.validloader is not None:
            val_log = self._valid_epoch(epoch)
            log.update(**{'valid_' + k: v for k, v in val_log.items()})
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        for batch_idx, batch in enumerate(self.validloader):
            # -----------------------------------------------------------------
            loss, mets = self._valid_step(batch)
            # -----------------------------------------------------------------
            self.writer.set_step((epoch - 1) * len(self.validloader) + batch_idx, 'valid')
            self.valid_metrics.update('loss', loss.item())
            for key, val in mets.items():
                self.valid_metrics.update(key, val)
        return self.valid_metrics.result()

    def train(self):
        """Full training logic."""
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged information into the log dict
            lr = self.optimizer.param_groups[0]['lr']
            log = {'epoch': epoch, 'lr': lr}
            log.update(result)

            # print logged information to the screen
            for key, value in log.items():
                self.logger.info('    {:20s}: {}'.format(str(key), value))

            # evaluate model performance according to the configured metric,
            # save the best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not,
                    # according to the specified metric (mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn't improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if self.lr_scheduler is not None:
                if isinstance(self.lr_scheduler, ReduceLROnPlateau):
                    self.lr_scheduler.step(log[self.mnt_metric])
                else:
                    self.lr_scheduler.step()

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

            # add histograms of model parameters to TensorBoard
            self.writer.set_step(epoch)
            for name, p in self.model.named_parameters():
                self.writer.add_histogram(name, p, bins='auto')

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Save a checkpoint

        :param epoch: current epoch number
        :param save_best: if True, also save the checkpoint as 'model_best.pth'
        """
        state = {
            'epoch': epoch,
            'model': self.model.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
            'monitor_best': self.mnt_best
        }
        filename = str(self.checkpoint_dir / 'chkpt_{:03d}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from a saved checkpoint

        :param resume_path: checkpoint path to resume from
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        try:
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.module.load_state_dict(checkpoint['model'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            self.mnt_best = checkpoint['monitor_best']
        except KeyError:
            # fall back to a raw state_dict checkpoint (weights only)
            self.model.module.load_state_dict(checkpoint)
        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
import os
import shutil

import torch
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

# Project-local names (DataManger, Baseline, Softmax_Triplet_loss, CenterLoss,
# WarmupMultiStepLR, MetricTracker, summary) are assumed to be imported from
# this repository's modules.


class Trainer(BaseTrainer):
    def __init__(self, config):
        super(Trainer, self).__init__(config)
        self.datamanager = DataManger(config["data"])

        # model
        self.model = Baseline(
            num_classes=self.datamanager.datasource.get_num_classes("train"))

        # summary model
        summary(
            self.model,
            input_size=(3, 256, 128),
            batch_size=config["data"]["batch_size"],
            device="cpu",
        )

        # losses
        cfg_losses = config["losses"]
        self.criterion = Softmax_Triplet_loss(
            num_class=self.datamanager.datasource.get_num_classes("train"),
            margin=cfg_losses["margin"],
            epsilon=cfg_losses["epsilon"],
            use_gpu=self.use_gpu,
        )
        self.center_loss = CenterLoss(
            num_classes=self.datamanager.datasource.get_num_classes("train"),
            feature_dim=2048,
            use_gpu=self.use_gpu,
        )

        # optimizer
        cfg_optimizer = config["optimizer"]
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=cfg_optimizer["lr"],
            weight_decay=cfg_optimizer["weight_decay"],
        )
        self.optimizer_centerloss = torch.optim.SGD(
            self.center_loss.parameters(), lr=0.5)

        # learning rate scheduler
        cfg_lr_scheduler = config["lr_scheduler"]
        self.lr_scheduler = WarmupMultiStepLR(
            self.optimizer,
            milestones=cfg_lr_scheduler["steps"],
            gamma=cfg_lr_scheduler["gamma"],
            warmup_factor=cfg_lr_scheduler["factor"],
            warmup_iters=cfg_lr_scheduler["iters"],
            warmup_method=cfg_lr_scheduler["method"],
        )

        # track metric
        self.train_metrics = MetricTracker("loss", "accuracy")
        self.valid_metrics = MetricTracker("loss", "accuracy")

        # save best accuracy for function _save_checkpoint
        self.best_accuracy = None

        # send model to device
        self.model.to(self.device)

        # gradient scaler for mixed-precision training
        self.scaler = GradScaler()

        # resume model from last checkpoint
        if config["resume"] != "":
            self._resume_checkpoint(config["resume"])

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            result = self._valid_epoch(epoch)

            # add scalars to tensorboard
            self.writer.add_scalars(
                "Loss",
                {
                    "Train": self.train_metrics.avg("loss"),
                    "Val": self.valid_metrics.avg("loss"),
                },
                global_step=epoch,
            )
            self.writer.add_scalars(
                "Accuracy",
                {
                    "Train": self.train_metrics.avg("accuracy"),
                    "Val": self.valid_metrics.avg("accuracy"),
                },
                global_step=epoch,
            )

            # logging result to console
            log = {"epoch": epoch}
            log.update(result)
            for key, value in log.items():
                self.logger.info("    {:15s}: {}".format(str(key), value))

            # save model
            if (self.best_accuracy is None
                    or self.best_accuracy < self.valid_metrics.avg("accuracy")):
                self.best_accuracy = self.valid_metrics.avg("accuracy")
                self._save_checkpoint(epoch, save_best=True)
            else:
                self._save_checkpoint(epoch, save_best=False)

            # save logs
            self._save_logs(epoch)

    def _train_epoch(self, epoch):
        """Training step"""
        self.model.train()
        self.train_metrics.reset()
        with tqdm(total=len(self.datamanager.get_dataloader("train"))) as epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch}")
            for batch_idx, (data, labels, _) in enumerate(
                    self.datamanager.get_dataloader("train")):
                # push data to device
                data, labels = data.to(self.device), labels.to(self.device)

                # zero gradient
                self.optimizer.zero_grad()
                self.optimizer_centerloss.zero_grad()

                with autocast():
                    # forward batch
                    score, feat = self.model(data)

                    # calculate loss and accuracy
                    loss = (
                        self.criterion(score, feat, labels)
                        + self.center_loss(feat, labels) * self.config["losses"]["beta"]
                    )
                    _, preds = torch.max(score.data, dim=1)

                # backward parameters (scaled for mixed precision)
                # loss.backward()
                self.scaler.scale(loss).backward()

                # the total loss multiplied the center term by beta, so divide
                # the center parameters' gradients by beta before their own
                # optimizer step (note: these gradients are still scaled by
                # the GradScaler at this point)
                for param in self.center_loss.parameters():
                    param.grad.data *= 1.0 / self.config["losses"]["beta"]

                # optimize
                # self.optimizer.step()
                self.scaler.step(self.optimizer)
                self.optimizer_centerloss.step()
                self.scaler.update()

                # update loss and accuracy in MetricTracker
                self.train_metrics.update("loss", loss.item())
                self.train_metrics.update(
                    "accuracy",
                    torch.sum(preds == labels.data).double().item() / data.size(0),
                )

                # update progress bar
                epoch_pbar.set_postfix(
                    {
                        "train_loss": self.train_metrics.avg("loss"),
                        "train_acc": self.train_metrics.avg("accuracy"),
                    }
                )
                epoch_pbar.update(1)
        return self.train_metrics.result()

    def _valid_epoch(self, epoch):
        """Validation step"""
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            with tqdm(total=len(self.datamanager.get_dataloader("val"))) as epoch_pbar:
                epoch_pbar.set_description(f"Epoch {epoch}")
                for batch_idx, (data, labels, _) in enumerate(
                        self.datamanager.get_dataloader("val")):
                    # push data to device
                    data, labels = data.to(self.device), labels.to(self.device)

                    with autocast():
                        # forward batch
                        score, feat = self.model(data)

                        # calculate loss and accuracy
                        loss = (
                            self.criterion(score, feat, labels)
                            + self.center_loss(feat, labels) * self.config["losses"]["beta"]
                        )
                        _, preds = torch.max(score.data, dim=1)

                    # update loss and accuracy in MetricTracker
                    self.valid_metrics.update("loss", loss.item())
                    self.valid_metrics.update(
                        "accuracy",
                        torch.sum(preds == labels.data).double().item() / data.size(0),
                    )

                    # update progress bar
                    epoch_pbar.set_postfix(
                        {
                            "val_loss": self.valid_metrics.avg("loss"),
                            "val_acc": self.valid_metrics.avg("accuracy"),
                        }
                    )
                    epoch_pbar.update(1)
        return self.valid_metrics.result()

    def _save_checkpoint(self, epoch, save_best=True):
        """Save model to file"""
        state = {
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "center_loss": self.center_loss.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "optimizer_centerloss": self.optimizer_centerloss.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "best_accuracy": self.best_accuracy,
        }
        filename = os.path.join(self.checkpoint_dir, "model_last.pth")
        self.logger.info("Saving last model: model_last.pth ...")
        torch.save(state, filename)
        if save_best:
            filename = os.path.join(self.checkpoint_dir, "model_best.pth")
            self.logger.info("Saving current best: model_best.pth ...")
            torch.save(state, filename)

    def _resume_checkpoint(self, resume_path):
        """Load model from checkpoint"""
        if not os.path.exists(resume_path):
            raise FileNotFoundError("Resume path does not exist!")
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=self.map_location)
        self.start_epoch = checkpoint["epoch"] + 1
        self.model.load_state_dict(checkpoint["state_dict"])
        self.center_loss.load_state_dict(checkpoint["center_loss"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.optimizer_centerloss.load_state_dict(checkpoint["optimizer_centerloss"])
        self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        self.best_accuracy = checkpoint["best_accuracy"]
        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))

    def _save_logs(self, epoch):
        """Save logs from Google Colab to Google Drive"""
        if os.path.isdir(self.logs_dir_saved):
            shutil.rmtree(self.logs_dir_saved)
        destination = shutil.copytree(self.logs_dir, self.logs_dir_saved)
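# ---------------------------------------------------------------------------
# Illustration (not part of the original file): a minimal config sketch
# listing only the keys this Trainer reads directly. All values are
# placeholders, not the project's defaults; the base class is assumed to
# consume further keys (epochs, checkpoint/log dirs, verbosity) not shown.
example_config = {
    "data": {"batch_size": 32},               # also consumed by DataManger
    "losses": {"margin": 0.3, "epsilon": 0.1, "beta": 0.0005},
    "optimizer": {"lr": 3.5e-4, "weight_decay": 5e-4},
    "lr_scheduler": {
        "steps": [40, 70],                    # milestones (epochs)
        "gamma": 0.1,                         # decay factor
        "factor": 0.01,                       # warmup_factor
        "iters": 10,                          # warmup_iters
        "method": "linear",                   # warmup_method
    },
    "resume": "",                             # checkpoint path, '' = fresh run
}
# usage sketch: trainer = Trainer(example_config); trainer.train()
# ---------------------------------------------------------------------------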
from functools import reduce

import torch

# Project-local names (LayerwiseTrainer, MetricTracker, MyOneCycleLR,
# MyReduceLROnPlateau) are assumed to be imported from this repository's
# modules.


class ClassificationTrainer(LayerwiseTrainer):
    def __init__(self, model, criterions, metric_ftns, optimizer, config,
                 train_data_loader, valid_data_loader=None, lr_scheduler=None,
                 weight_scheduler=None, test_data_loader=None):
        super().__init__(model, criterions, metric_ftns, optimizer, config,
                         train_data_loader, valid_data_loader, lr_scheduler,
                         weight_scheduler)
        self.train_teacher_metrics = MetricTracker(
            *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', 'supervised_loss', 'kd_loss', 'hint_loss', 'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            *['teacher_' + m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.test_data_loader = test_data_loader

    def _train_epoch(self, epoch):
        self.prepare_train_epoch(epoch)
        self.model.train()
        self._clean_cache()

        for batch_idx, (data, target) in enumerate(self.train_data_loader):
            data, target = data.to(self.device), target.to(self.device)

            output_st, output_tc = self.model(data)
            supervised_loss = self.criterions[0](
                output_st, target) / self.accumulation_steps
            kd_loss = self.criterions[1](output_st,
                                         output_tc) / self.accumulation_steps
            # summed hint loss over paired student/teacher hidden activations
            hint_loss = reduce(
                lambda acc, elem: acc + self.criterions[2](elem[0], elem[1]),
                zip(self.model.student_hidden_outputs,
                    self.model.teacher_hidden_outputs),
                torch.tensor(0.0, device=self.device)) / self.accumulation_steps
            teacher_loss = self.criterions[0](output_tc, target)  # for comparison

            # only the knowledge-distillation loss is backpropagated
            loss = kd_loss
            loss.backward()

            if (batch_idx + 1) % self.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)

            # update metrics
            self.train_metrics.update('loss',
                                      loss.item() * self.accumulation_steps)
            self.train_metrics.update(
                'supervised_loss',
                supervised_loss.item() * self.accumulation_steps)
            self.train_metrics.update(
                'kd_loss', kd_loss.item() * self.accumulation_steps)
            self.train_metrics.update(
                'hint_loss', hint_loss.item() * self.accumulation_steps)
            self.train_metrics.update('teacher_loss', teacher_loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__,
                                          met(output_st, target),
                                          data.shape[0])
            for met in self.metric_ftns:
                self.train_teacher_metrics.update(met.__name__,
                                                  met(output_tc, target),
                                                  data.shape[0])

            if batch_idx % self.log_step == 0:
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                self.logger.info(
                    'Train Epoch: {} [{}]/[{}] acc: {:.6f} teacher_acc: {:.6f} Loss: {:.6f} Supervised Loss: {:.6f} '
                    'Knowledge Distillation loss: {:.6f} Hint Loss: {:.6f} Teacher Loss: {:.6f}'
                    .format(
                        epoch,
                        batch_idx,
                        self.len_epoch,
                        self.train_metrics.avg('accuracy'),
                        self.train_teacher_metrics.avg('accuracy'),
                        self.train_metrics.avg('loss'),
                        self.train_metrics.avg('supervised_loss'),
                        self.train_metrics.avg('kd_loss'),
                        self.train_metrics.avg('hint_loss'),
                        self.train_metrics.avg('teacher_loss'),
                    ))
            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()

        if self.do_validation and ((epoch % self.do_validation_interval) == 0):
            # clean cache to prevent out-of-memory with 1 gpu
            self._clean_cache()
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if (self.lr_scheduler is not None) and (not isinstance(
                self.lr_scheduler, MyOneCycleLR)):
            if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
                self.lr_scheduler.step(self.train_metrics.avg('loss'))
            else:
                self.lr_scheduler.step()
        self.weight_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output, output_tc = self.model(data)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(output, target),
                                              data.shape[0])
                for met in self.metric_ftns:
                    self.valid_metrics.update('teacher_' + met.__name__,
                                              met(output_tc, target),
                                              data.shape[0])
        return self.valid_metrics.result()
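# ---------------------------------------------------------------------------
# Illustration (not part of the original file): the functools.reduce
# expression in _train_epoch above computes a summed hint loss over paired
# student/teacher hidden activations. A minimal sketch of the equivalent
# explicit form, assuming `hint_criterion` is an elementwise loss such as
# torch.nn.MSELoss(); the function name and arguments are hypothetical.
def summed_hint_loss(hint_criterion, student_hiddens, teacher_hiddens,
                     accumulation_steps):
    """Sum hint_criterion over aligned hidden-state pairs, then rescale
    for gradient accumulation."""
    total = sum(hint_criterion(s, t)
                for s, t in zip(student_hiddens, teacher_hiddens))
    return total / accumulation_steps
# ---------------------------------------------------------------------------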