class Trainer(BaseTrainer):
    r""" Trainer for person attribute recognition
    """
    def __init__(self, config):
        super(Trainer, self).__init__(config)
        # Datamanager
        self.datamanager, params_data = build_datamanager(
            config['type'], config['data'])

        # model
        self.model, params_model = build_model(
            config,
            num_classes=len(self.datamanager.datasource.get_attribute()),
            device=self.device)

        # losses
        pos_ratio = torch.tensor(
            self.datamanager.datasource.get_weight('train'))
        self.criterion, params_loss = build_losses(
            config,
            pos_ratio=pos_ratio,
            num_attribute=len(self.datamanager.datasource.get_attribute()))

        # optimizer
        self.optimizer, params_optimizers = build_optimizers(
            config, self.model)

        # learning rate scheduler
        self.lr_scheduler, params_lr_scheduler = build_lr_scheduler(
            config, self.optimizer)

        # callback for freezing the backbone
        if config['freeze']['enable']:
            self.freeze = FreezeLayers(self.model, config['freeze']['layers'],
                                       config['freeze']['epochs'])
        else:
            self.freeze = None

        # list of metrics
        self.lst_metrics = ['mA', 'accuracy', 'f1_score']

        # track metric
        self.train_metrics = MetricTracker('loss', *self.lst_metrics)
        self.valid_metrics = MetricTracker('loss', *self.lst_metrics)

        # logging interval (in steps) for loss and accuracy
        self.log_step = (
            max(len(self.datamanager.get_dataloader('train')) // 5, 1),
            max(len(self.datamanager.get_dataloader('val')) // 5, 1))

        # best accuracy and loss
        self.best_loss = None
        self.best_metrics = dict()
        for x in self.lst_metrics:
            self.best_metrics[x] = None

        # print config
        self._print_config(
            params_data=params_data,
            params_model=params_model,
            params_loss=params_loss,
            params_optimizers=params_optimizers,
            params_lr_scheduler=params_lr_scheduler,
            freeze_layers=self.freeze is not None,
            clip_grad_norm_=self.config['clip_grad_norm_']['enable'])

        # send model to device
        self.model.to(self.device)
        self.criterion.to(self.device)

        # summary model
        summary(model=self.model,
                input_data=torch.zeros((self.datamanager.get_batch_size(), 3,
                                        self.datamanager.get_image_size()[0],
                                        self.datamanager.get_image_size()[1])),
                batch_dim=None,
                device='cuda' if self.use_gpu else 'cpu',
                print_func=self.logger.info,
                print_step=False)

        # resume model from last checkpoint
        if config['resume'] != '':
            self._resume_checkpoint(config['resume'], config['only_model'])

    def train(self):
        # begin train
        for epoch in range(self.start_epoch, self.epochs + 1):
            # freeze layer
            if self.freeze is not None:
                self.freeze.on_epoch_begin(epoch)

            # train
            result = self._train_epoch(epoch)

            # valid
            result = self._valid_epoch(epoch)

            # learning rate
            if self.lr_scheduler is not None:
                if self.config['lr_scheduler']['start'] <= epoch:
                    if isinstance(self.lr_scheduler,
                                  torch.optim.lr_scheduler.ReduceLROnPlateau):
                        self.lr_scheduler.step(self.valid_metrics.avg('loss'))
                    else:
                        self.lr_scheduler.step()

            # add scalars to tensorboard
            self.writer.add_scalars('Loss', {
                'Train': self.train_metrics.avg('loss'),
                'Val': self.valid_metrics.avg('loss')
            },
                                    global_step=epoch)

            for metric in self.lst_metrics:
                self.writer.add_scalars(metric, {
                    'Train': self.train_metrics.avg(metric),
                    'Val': self.valid_metrics.avg(metric)
                },
                                        global_step=epoch)

            self.writer.add_scalar('lr',
                                   self.optimizer.param_groups[-1]['lr'],
                                   global_step=epoch)

            # logging result to console
            log = {'epoch': epoch}
            log.update(result)
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))

            # save model
            save_best_loss = False
            if self.best_loss is None or self.best_loss >= self.valid_metrics.avg(
                    'loss'):
                self.best_loss = self.valid_metrics.avg('loss')
                save_best_loss = True

            save_best = dict()
            for metric in self.lst_metrics:
                save_best[metric] = False
                if self.best_metrics[metric] is None or self.best_metrics[
                        metric] <= self.valid_metrics.avg(metric):
                    self.best_metrics[metric] = self.valid_metrics.avg(metric)
                    save_best[metric] = True

            self._save_checkpoint(epoch, save_best_loss, save_best)

            # save logs to drive if using colab
            if self.config['colab']:
                self._save_logs()

        # wait for tensorboard to flush all metrics to file
        self.writer.flush()
        # time.sleep(1*60)
        self.writer.close()
        # save logs to drive if using colab
        if self.config['colab']:
            self._save_logs()
        # plot loss, accuracy and save them to plot.png in saved/logs/<run_id>/plot.png
        plot_loss_accuracy(
            dpath=self.cfg_trainer['log_dir'],
            list_dname=[self.run_id],
            path_folder=self.logs_dir_saved
            if self.config['colab'] else self.logs_dir,
            title=self.run_id + ': ' + self.config['model']['name'] + ", " +
            self.config['loss']['name'] + ", " + self.config['data']['name'])

    def _train_epoch(self, epoch):
        r""" Training step
        """
        raise NotImplementedError

    def _valid_epoch(self, epoch):
        r""" Validation step
        """
        raise NotImplementedError

    def test(self):
        r""" Test model after train
        TODO:
        """
        logger = logging.getLogger('test')

        self.model.eval()
        preds = []
        labels = []

        if self.cfg_trainer['use_tqdm']:
            tqdm_callback = Tqdm(
                total=len(self.datamanager.get_dataloader('test')))
        with torch.no_grad():
            for batch_idx, (data, _labels) in enumerate(
                    self.datamanager.get_dataloader('test')):
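                # NOTE: only the first 5 test batches are evaluated here (debug cap)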
                if batch_idx == 5:
                    break
                data, _labels = data.to(self.device), _labels.to(self.device)

                out = self.model(data)

                _preds = torch.sigmoid(out)
                preds.append(_preds)
                labels.append(_labels)
                if self.cfg_trainer['use_tqdm']:
                    tqdm_callback.update()
                else:
                    total_batches = len(
                        self.datamanager.get_dataloader('test'))
                    if (batch_idx + 1) % (total_batches // 10 + 1) == 0 or (
                            batch_idx + 1) == total_batches:
                        logger.info('Iter {}/{}'.format(
                            batch_idx + 1, total_batches))

        preds = torch.cat(preds, dim=0)
        labels = torch.cat(labels, dim=0)
        preds = preds.cpu().numpy()
        labels = labels.cpu().numpy()

        result_label, result_instance = recognition_metrics(labels, preds)

        log_test(logger.info, self.datamanager.datasource.get_attribute(),
                 self.datamanager.datasource.get_weight('test'), result_label,
                 result_instance)

    def _save_checkpoint(self, epoch, save_best_loss, save_best_metrics):
        r""" Save model to file
        """
        state = {
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'loss': self.criterion.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
            'best_loss': self.best_loss
        }
        for metric in self.lst_metrics:
            state.update({'best_{}'.format(metric): self.best_metrics[metric]})

        filename = os.path.join(self.checkpoint_dir, 'model_last.pth')
        self.logger.info("Saving last model: model_last.pth ...")
        torch.save(state, filename)

        if save_best_loss:
            filename = os.path.join(self.checkpoint_dir, 'model_best_loss.pth')
            self.logger.info(
                "Saving current best loss: model_best_loss.pth ...")
            torch.save(state, filename)

        for metric in self.lst_metrics:
            if save_best_metrics[metric]:
                filename = os.path.join(self.checkpoint_dir,
                                        'model_best_{}.pth'.format(metric))
                self.logger.info(
                    "Saving current best {}: model_best_{}.pth ...".format(
                        metric, metric))
                torch.save(state, filename)

    def _resume_checkpoint(self, resume_path, only_model=False):
        r""" Load model from checkpoint
        """
        if not os.path.exists(resume_path):
            raise FileNotFoundError("Resume path does not exist!")
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=self.map_location)
        self.model.load_state_dict(checkpoint['state_dict'])
        if only_model:
            self.logger.info("Pretrained-model loaded!")
            return
        self.start_epoch = checkpoint['epoch'] + 1
        self.criterion.load_state_dict(checkpoint['loss'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        self.best_loss = checkpoint['best_loss']
        for metric in self.lst_metrics:
            self.best_metrics[metric] = checkpoint['best_{}'.format(metric)]
        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))

    def _print_config(self,
                      params_data=None,
                      params_model=None,
                      params_loss=None,
                      params_optimizers=None,
                      params_lr_scheduler=None,
                      freeze_layers=False,
                      clip_grad_norm_=False):
        r""" print config into log file
        """
        def __prams_to_str(params: dict):
            if params is None:
                return ''
            row_format = "{:>4},  " * len(params)
            return row_format.format(
                *[key + ': ' + str(value) for key, value in params.items()])

        self.logger.info('Run id: %s' % (self.run_id))
        self.logger.info('Data: ' + __prams_to_str(params_data))
        self.logger.info('Model: %s ' % (self.config['model']['name']) +
                         __prams_to_str(params_model))
        if freeze_layers:
            self.logger.info('Freeze layer: %s, at first epoch %d' %
                             (str(self.config['freeze']['layers']),
                              self.config['freeze']['epochs']))
        self.logger.info('Loss: %s ' % (self.config['loss']['name']) +
                         __prams_to_str(params_loss))
        self.logger.info('Optimizer: %s ' %
                         (self.config['optimizer']['name']) +
                         __prams_to_str(params_optimizers))
        if params_lr_scheduler is not None:
            self.logger.info('Lr scheduler: %s ' %
                             (self.config['lr_scheduler']['name']) +
                             __prams_to_str(params_lr_scheduler))
        if clip_grad_norm_:
            self.logger.info('clip_grad_norm_, max_norm: %f' %
                             self.config['clip_grad_norm_']['max_norm'])
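

# A minimal, hypothetical sketch of the config dict the Trainer above expects,
# inferred only from the keys it reads in __init__() and train(); every value
# below is an illustrative placeholder, not taken from the original project.
example_config = {
    'type': 'image',
    'colab': False,
    'resume': '',            # path to a checkpoint, '' to train from scratch
    'only_model': False,     # if True, load only the weights when resuming
    'data': {'name': 'peta'},
    'model': {'name': 'resnet50'},
    'loss': {'name': 'bce_with_logits'},
    'optimizer': {'name': 'adam'},
    'lr_scheduler': {'name': 'reduce_on_plateau', 'start': 1},
    'freeze': {'enable': False, 'layers': [], 'epochs': 1},
    'clip_grad_norm_': {'enable': False, 'max_norm': 10.0},
}

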
class SegmentationTrainer(BaseTrainer):
    def __init__(self, model, criterion, metrics, optimizer, config, lr_scheduler=None):
        super().__init__(model, criterion, metrics, optimizer, config)
        self.lr_scheduler = lr_scheduler
        self.loss_name = 'supervised_loss'

        # Metrics
        # Train
        self.train_loss = MetricTracker(self.loss_name, self.writer)
        self.train_metrics = MetricTracker(*self.metric_names,
                                           self.writer)
        # Validation
        self.valid_loss = MetricTracker(self.loss_name, self.writer)
        self.valid_metrics = MetricTracker(*self.metric_names,
                                           self.writer)
        # Test
        self.test_loss = MetricTracker(self.loss_name, self.writer)
        self.test_metrics = MetricTracker(*self.metric_names,
                                          self.writer)

        if isinstance(self.model, nn.DataParallel):
            self.criterion = nn.DataParallel(self.criterion)

        # Resume checkpoint if path is available in config
        cp_path = self.config['trainer'].get('resume_path')
        if cp_path:
            super()._resume_checkpoint(cp_path)

    def reset_scheduler(self):
        self.train_loss.reset()
        self.train_metrics.reset()
        self.valid_loss.reset()
        self.valid_metrics.reset()
        self.test_loss.reset()
        self.test_metrics.reset()
        # if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
        #     self.lr_scheduler.reset()

    def prepare_train_epoch(self, epoch):
        self.logger.info('EPOCH: {}'.format(epoch))
        self.reset_scheduler()

    def _train_epoch(self, epoch):
        self.model.train()
        self.prepare_train_epoch(epoch)
        for batch_idx, (data, target, image_name) in enumerate(self.train_data_loader):
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            loss = self.criterion(output, target)
            # For debugging: save a checkpoint if the loss becomes NaN
            if torch.isnan(loss):
                super()._save_checkpoint(epoch)

            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update train loss, metrics
            self.train_loss.update(self.loss_name, loss.item())
            for metric in self.metrics:
                self.train_metrics.update(metric.__name__, metric(output, target), n=output.shape[0])

            if batch_idx % self.log_step == 0:
                self.log_for_step(epoch, batch_idx)

            if self.save_for_track and (batch_idx % self.save_for_track == 0):
                save_output(output, image_name, epoch, self.checkpoint_dir)

            if batch_idx == self.len_epoch:
                break

        log = self.train_loss.result()
        log.update(self.train_metrics.result())

        if self.do_validation and (epoch % self.do_validation_interval == 0):
            val_log = self._valid_epoch(epoch)
            log.update(val_log)

        # step lr scheduler
        if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
            self.lr_scheduler.step(self.valid_loss.avg(self.loss_name))

        return log

    @staticmethod
    def get_metric_message(metrics, metric_names):
        metrics_avg = [metrics.avg(name) for name in metric_names]
        message_metrics = ', '.join(['{}: {:.6f}'.format(x, y) for x, y in zip(metric_names, metrics_avg)])
        return message_metrics

    def log_for_step(self, epoch, batch_idx):
        message_loss = 'Train Epoch: {} [{}]/[{}] Dice Loss: {:.6f}'.format(epoch, batch_idx, self.len_epoch,
                                                                            self.train_loss.avg(self.loss_name))

        message_metrics = SegmentationTrainer.get_metric_message(self.train_metrics, self.metric_names)
        self.logger.info(message_loss)
        self.logger.info(message_metrics)

    def _valid_epoch(self, epoch, save_result=False, save_for_visual=False):
        self.model.eval()
        self.valid_loss.reset()
        self.valid_metrics.reset()
        self.logger.info('Validation: ')
        with torch.no_grad():
            for batch_idx, (data, target, image_name) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_loss.update(self.loss_name, loss.item())
                for metric in self.metrics:
                    self.valid_metrics.update(metric.__name__, metric(output, target), n=output.shape[0])

                if save_result:
                    save_output(output, image_name, epoch, os.path.join(self.checkpoint_dir, 'tracker'), percent=1)

                if save_for_visual:
                    save_mask2image(output, image_name, os.path.join(self.checkpoint_dir, 'output'))
                    save_mask2image(target, image_name, os.path.join(self.checkpoint_dir, 'target'))

                if batch_idx % self.log_step == 0:
                    self.logger.debug('{}/{}'.format(batch_idx, len(self.valid_data_loader)))
                    self.logger.debug('{}: {}'.format(self.loss_name, self.valid_loss.avg(self.loss_name)))
                    self.logger.debug(SegmentationTrainer.get_metric_message(self.valid_metrics, self.metric_names))

        log = self.valid_loss.result()
        log.update(self.valid_metrics.result())
        val_log = {'val_{}'.format(k): v for k, v in log.items()}
        return val_log

    def _test_epoch(self, epoch, save_result=False, save_for_visual=False):
        self.model.eval()
        self.test_loss.reset()
        self.test_metrics.reset()
        self.logger.info('Test: ')
        with torch.no_grad():
            for batch_idx, (data, target, image_name) in enumerate(self.test_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
                self.test_loss.update(self.loss_name, loss.item())
                for metric in self.metrics:
                    self.test_metrics.update(metric.__name__, metric(output, target), n=output.shape[0])

                if save_result:
                    save_output(output, image_name, epoch, os.path.join(self.checkpoint_dir, 'tracker'), percent=1)

                if save_for_visual:
                    save_mask2image(output, image_name, os.path.join(self.checkpoint_dir, 'output'))
                    save_mask2image(target, image_name, os.path.join(self.checkpoint_dir, 'target'))

                if batch_idx % self.log_step == 0:
                    self.logger.debug('{}/{}'.format(batch_idx, len(self.test_data_loader)))
                    self.logger.debug('{}: {}'.format(self.loss_name, self.test_loss.avg(self.loss_name)))
                    self.logger.debug(SegmentationTrainer.get_metric_message(self.test_metrics, self.metric_names))

        log = self.test_loss.result()
        log.update(self.test_metrics.result())
        test_log = {'test_{}'.format(k): v for k, v in log.items()}
        return test_log
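

# A minimal sketch of the MetricTracker interface the trainers in this file rely on
# (reset / update / avg / result, with an optional writer); the project's real class
# almost certainly differs, so this is only meant to make the usage above self-contained.
class SimpleMetricTracker:
    def __init__(self, *keys, writer=None):
        self.writer = writer
        self._totals = {k: 0.0 for k in keys}
        self._counts = {k: 0 for k in keys}

    def reset(self):
        for k in self._totals:
            self._totals[k] = 0.0
            self._counts[k] = 0

    def update(self, key, value, n=1):
        # log the raw value and accumulate a running (weighted) sum
        if self.writer is not None:
            self.writer.add_scalar(key, value)
        self._totals[key] += value * n
        self._counts[key] += n

    def avg(self, key):
        return self._totals[key] / max(self._counts[key], 1)

    def result(self):
        return {k: self.avg(k) for k in self._totals}
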
Example #3
class LayerwiseTrainer(BaseTrainer):
    """
    Trainer
    """
    def __init__(self,
                 model: DepthwiseStudent,
                 criterions,
                 metric_ftns,
                 optimizer,
                 config,
                 train_data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 weight_scheduler=None):
        super().__init__(model, None, metric_ftns, optimizer, config)
        self.config = config
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_validation_interval = self.config['trainer'][
            'do_validation_interval']
        self.lr_scheduler = lr_scheduler
        self.weight_scheduler = weight_scheduler
        self.log_step = config['trainer']['log_step']
        if "len_epoch" in self.config['trainer']:
            # iteration-based training
            self.train_data_loader = inf_loop(train_data_loader)
            self.len_epoch = self.config['trainer']['len_epoch']
        else:
            # epoch-based training
            self.len_epoch = len(self.train_data_loader)

        # Metrics
        # Train
        self.train_metrics = MetricTracker(
            'loss',
            'supervised_loss',
            'kd_loss',
            'hint_loss',
            'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.train_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        self.train_teacher_iou_metrics = CityscapesMetricTracker(
            writer=self.writer)
        # Valid
        self.valid_metrics = MetricTracker(
            'loss',
            'supervised_loss',
            'kd_loss',
            'hint_loss',
            'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        # Test
        self.test_metrics = MetricTracker(
            'loss',
            'supervised_loss',
            'kd_loss',
            'hint_loss',
            'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            *['teacher_' + m.__name__ for m in self.metric_ftns],
            writer=self.writer,
        )
        self.test_iou_metrics = CityscapesMetricTracker(writer=self.writer)

        # Tracker for early stop if val miou doesn't increase
        self.val_iou_tracker = EarlyStopTracker('best', 'max', 0.01, 'rel')

        # Use the list of criterions and drop the unused single-criterion property inherited from BaseTrainer
        self.criterions = criterions
        self.criterions = nn.ModuleList(self.criterions).to(self.device)
        if isinstance(self.model, nn.DataParallel):
            self.criterions = nn.DataParallel(self.criterions)
        del self.criterion

        # Resume checkpoint if path is available in config
        if 'resume_path' in self.config['trainer']:
            self.resume(self.config['trainer']['resume_path'])

    def prepare_train_epoch(self, epoch, config=None):
        """
        Prepare before training an epoch i.e. prune new layer, unfreeze some layers, create new optimizer ....
        :param epoch:  int - indicate which epoch the trainer's in
        :param config: a config object that contain pruning_plan, hint, unfreeze information
        :return: 
        """
        # if config is not given (normal training), use the trainer's current config
        # if config is given (e.g. when resuming a checkpoint), use the saved config to replace
        #    layers in the student so that its architecture is identical to the checkpoint
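        # A hypothetical sketch of the expected config['pruning'] layout, inferred only
        # from the lookups below; the names and values are placeholders, not project values:
        # config['pruning'] = {
        #     'pruning_plan': [{'name': 'layer1.0', 'epoch': 1}],  # replaced by depthwise separable conv
        #     'hint': [{'name': 'layer2', 'epoch': 5}],            # outputs used for the hint loss
        #     'unfreeze': [{'name': 'layer1', 'epoch': 1, 'lr': 0.01}],  # trained from this epoch
        #     'args': {},  # kwargs forwarded to self.model.replace()
        # }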
        if config is None:
            config = self.config
        # reset_scheduler
        self.reset_scheduler()
        # if no layer is scheduled to be replaced, unfrozen, or used as a hint,
        # then unfreeze the whole network
        if (epoch == 1) and ((len(config['pruning']['pruning_plan']) +
                              len(config['pruning']['hint']) +
                              len(config['pruning']['unfreeze'])) == 0):
            self.logger.debug(
                'Train a student with identical architecture with teacher')
            # unfreeze
            for param in self.model.student.parameters():
                param.requires_grad = True
            # debug
            self.logger.info(self.model.dump_trainable_params())
            # create optimizer for the network
            self.create_new_optimizer()
            # ignore all below stuff
            return

        # Check whether any layer has an update scheduled for the current epoch
        # list of epochs that have an update on the student network
        epochs = list(
            map(
                lambda x: x['epoch'], config['pruning']['pruning_plan'] +
                config['pruning']['hint'] + config['pruning']['unfreeze']))
        # if there isn't any update then simply return
        if epoch not in epochs:
            self.logger.info('EPOCH: ' + str(epoch))
            self.logger.info('There is no update ...')
            return

        # layers that would be replaced by depthwise separable conv
        replaced_layers = list(
            filter(lambda x: x['epoch'] == epoch,
                   config['pruning']['pruning_plan']))
        # layers whose outputs will be used for the hint loss
        hint_layers = list(
            map(
                lambda x: x['name'],
                filter(lambda x: x['epoch'] == epoch,
                       config['pruning']['hint'])))
        # layers that would be trained in this epoch
        unfreeze_layers = list(
            map(
                lambda x: x['name'],
                filter(lambda x: x['epoch'] == epoch,
                       config['pruning']['unfreeze'])))
        self.logger.info('EPOCH: ' + str(epoch))
        self.logger.info('Replaced layers: ' + str(replaced_layers))
        self.logger.info('Hint layers: ' + str(hint_layers))
        self.logger.info('Unfreeze layers: ' + str(unfreeze_layers))
        # Avoid errors when loading deprecated checkpoints that don't have 'args' in config['pruning']
        if 'args' in config['pruning']:
            kwargs = config['pruning']['args']
        else:
            self.logger.warning('Using deprecated checkpoint...')
            kwargs = config['pruning']['pruner']

        # replace chosen layers with depthwise separable convs
        self.model.replace(replaced_layers, **kwargs)
        # register the layers whose outputs will be used for the hint loss
        self.model.register_hint_layers(hint_layers)
        # unfreeze chosen layers
        self.model.unfreeze(unfreeze_layers)

        if epoch == 1:
            # create a new optimizer to remove the effect of momentum
            self.create_new_optimizer()
        else:
            self.update_optimizer(
                list(filter(lambda x: x['epoch'] == epoch,
                            config['pruning']['unfreeze'])))

        self.logger.info(self.model.dump_trainable_params())
        self.logger.info(self.model.dump_student_teacher_blocks_info())

    def update_optimizer(self, unfreeze_config):
        """
        Update param groups for optimizer with unfreezed layers of this epoch
        :param unfreeze_config - list of arg. Each arg is the dictionary with following format:
            {'name': 'layer1', 'epoch':1, 'lr'(optional): 0.01}
        return: 
        """
        if len(unfreeze_config) > 0:
            self.logger.debug('Updating optimizer for new layer')
        for config in unfreeze_config:
            layer_name = config['name']  # layer that will be unfrozen
            self.logger.debug(
                'Add parameters of layer: {} to optimizer'.format(layer_name))

            layer = self.model.get_block(
                layer_name,
                self.model.student)  # actual layer i.e. nn.Module obj
            # default args for the optimizer; copy so the shared config dict is not mutated
            optimizer_arg = dict(self.config['optimizer']['args'])

            # we can also specify a layer-wise learning rate
            if "lr" in config:
                optimizer_arg['lr'] = config['lr']
            # add the unfrozen layer's parameters to the optimizer
            self.optimizer.add_param_group({
                'params': layer.parameters(),
                **optimizer_arg
            })

    def create_new_optimizer(self):
        """
        Create new optimizer if trainer is in epoch 1 otherwise just run update optimizer
        """
        # Create new optimizer
        self.logger.debug('Creating new optimizer ...')
        self.optimizer = self.config.init_obj(
            'optimizer', optim_module,
            list(
                filter(lambda x: x.requires_grad,
                       self.model.student.parameters())))
        self.lr_scheduler = self.config.init_obj('lr_scheduler',
                                                 optim_module.lr_scheduler,
                                                 self.optimizer)

    def reset_scheduler(self):
        """
        reset all schedulers, metrics, trackers, etc when unfreeze new layer
        :return:
        """
        self.weight_scheduler.reset()  # weight between loss
        self.val_iou_tracker.reset()  # verify val iou would increase each time
        self.train_metrics.reset()  # metrics for loss,... in training phase
        self.valid_metrics.reset()  # metrics for loss,... in validating phase
        self.train_iou_metrics.reset()  # train iou of student
        self.valid_iou_metrics.reset()  # val iou of student
        self.train_teacher_iou_metrics.reset()  # train iou of teacher
        if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
            self.lr_scheduler.reset()

    def _train_epoch(self, epoch):
        """
        Training logic for 1 epoch
        """
        # Prepare the network, i.e. unfreeze new layers, replace layers with depthwise separable convs, ...
        self.prepare_train_epoch(epoch)

        # reset
        # FIXME:
        # as the teacher network contains batchnorm layers and our resources only allow
        # small batch sizes, we ALWAYS keep bn in training mode to avoid instability
        # self.model.train()
        self.train_iou_metrics.reset()
        self.train_teacher_iou_metrics.reset()
        self._clean_cache()

        for batch_idx, (data, target) in enumerate(self.train_data_loader):
            data, target = data.to(self.device), target.to(self.device)

            output_st, output_tc = self.model(data)

            supervised_loss = self.criterions[0](
                output_st, target) / self.accumulation_steps
            kd_loss = self.criterions[1](output_st,
                                         output_tc) / self.accumulation_steps
            teacher_loss = self.criterions[0](output_tc,
                                              target)  # for comparison

            hint_loss = reduce(
                lambda acc, elem: acc + self.criterions[2](elem[0], elem[1]),
                zip(self.model.student_hidden_outputs,
                    self.model.teacher_hidden_outputs),
                0) / self.accumulation_steps

            # Only use hint loss
            loss = hint_loss
            loss.backward()

            if batch_idx % self.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)

            # update metrics
            self.train_metrics.update('loss',
                                      loss.item() * self.accumulation_steps)
            self.train_metrics.update(
                'supervised_loss',
                supervised_loss.item() * self.accumulation_steps)
            self.train_metrics.update('kd_loss',
                                      kd_loss.item() * self.accumulation_steps)
            self.train_metrics.update(
                'hint_loss',
                hint_loss.item() * self.accumulation_steps)
            self.train_metrics.update('teacher_loss', teacher_loss.item())
            self.train_iou_metrics.update(output_st.detach().cpu(),
                                          target.cpu())
            self.train_teacher_iou_metrics.update(output_tc.cpu(),
                                                  target.cpu())

            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output_st, target))

            if batch_idx % self.log_step == 0:
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                # st_masks = visualize.viz_pred_cityscapes(output_st)
                # tc_masks = visualize.viz_pred_cityscapes(output_tc)
                # self.writer.add_image('st_pred', make_grid(st_masks, nrow=8, normalize=False))
                # self.writer.add_image('tc_pred', make_grid(tc_masks, nrow=8, normalize=False))
                self.logger.info(
                    'Train Epoch: {} [{}]/[{}] Loss: {:.6f} mIoU: {:.6f} Teacher mIoU: {:.6f} Supervised Loss: {:.6f} '
                    'Knowledge Distillation loss: '
                    '{:.6f} Hint Loss: {:.6f} Teacher Loss: {:.6f}'.format(
                        epoch,
                        batch_idx,
                        self.len_epoch,
                        self.train_metrics.avg('loss'),
                        self.train_iou_metrics.get_iou(),
                        self.train_teacher_iou_metrics.get_iou(),
                        self.train_metrics.avg('supervised_loss'),
                        self.train_metrics.avg('kd_loss'),
                        self.train_metrics.avg('hint_loss'),
                        self.train_metrics.avg('teacher_loss'),
                    ))

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()
        log.update(
            {'train_teacher_mIoU': self.train_teacher_iou_metrics.get_iou()})
        log.update({'train_student_mIoU': self.train_iou_metrics.get_iou()})

        if self.do_validation and (
            (epoch % self.config["trainer"]["do_validation_interval"]) == 0):
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
            log.update(**{'val_mIoU': self.valid_iou_metrics.get_iou()})
            self.val_iou_tracker.update(self.valid_iou_metrics.get_iou())

        self._teacher_student_iou_gap = self.train_teacher_iou_metrics.get_iou(
        ) - self.train_iou_metrics.get_iou()

        # step lr scheduler
        if (self.lr_scheduler is not None) and (not isinstance(
                self.lr_scheduler, MyOneCycleLR)):
            if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
                self.lr_scheduler.step(self.train_metrics.avg('loss'))
            else:
                self.lr_scheduler.step()
                self.logger.debug('stepped lr')
                for param_group in self.optimizer.param_groups:
                    self.logger.debug(param_group['lr'])

        # anneal weight between losses
        self.weight_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self._clean_cache()
        # FIXME:
        # as the teacher network contains batchnorm layers and our resources only allow
        # small batch sizes, we ALWAYS keep bn in training mode to avoid instability
        # self.model.eval()
        self.model.save_hidden = False  # stop saving hidden output
        self.valid_metrics.reset()
        self.valid_iou_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model.inference(data)
                supervised_loss = self.criterions[0](output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('supervised_loss',
                                          supervised_loss.item())
                self.valid_iou_metrics.update(output.detach().cpu(), target)
                self.logger.debug(
                    str(batch_idx) + " : " +
                    str(self.valid_iou_metrics.get_iou()))

                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(output, target))
        result = self.valid_metrics.result()
        result['mIoU'] = self.valid_iou_metrics.get_iou()

        return result

    def _test_epoch(self, epoch):
        # cleaning up memory
        self._clean_cache()
        # self.model.eval()
        self.model.save_hidden = False
        self.model.cpu()
        self.model.student.to(self.device)

        # prepare before running submission
        self.test_metrics.reset()
        self.test_iou_metrics.reset()
        args = self.config['test']['args']
        save_4_sm = self.config['submission']['save_output']
        path_output = self.config['submission']['path_output']
        if save_4_sm and not os.path.exists(path_output):
            os.mkdir(path_output)
        n_samples = len(self.valid_data_loader)

        with torch.no_grad():
            for batch_idx, (img_name, data,
                            target) in enumerate(self.valid_data_loader):
                self.logger.info('{}/{}'.format(batch_idx, n_samples))
                data, target = data.to(self.device), target.to(self.device)
                output = self.model.inference_test(data, args)
                if save_4_sm:
                    self.save_for_submission(output, img_name[0])
                supervised_loss = self.criterions[0](output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'test')
                self.test_metrics.update('supervised_loss',
                                         supervised_loss.item())
                self.test_iou_metrics.update(output.detach().cpu(), target)

                for met in self.metric_ftns:
                    self.test_metrics.update(met.__name__, met(output, target))

        result = self.test_metrics.result()
        result['mIoU'] = self.test_iou_metrics.get_iou()

        return result

    def save_for_submission(self, output, image_name, img_type=np.uint8):
        args = self.config['submission']
        path_output = args['path_output']
        image_save = '{}.{}'.format(image_name, args['ext'])
        path_save = os.path.join(path_output, image_save)
        result = torch.argmax(output, dim=1)
        result_mapped = self.re_map_for_submission(result)
        if output.size()[0] == 1:
            result_mapped = result_mapped[0]

        save_image(result_mapped.cpu().numpy().astype(img_type), path_save)
        print('Saved output of test data: {}'.format(image_save))

    def re_map_for_submission(self, output):
        mapping = self.valid_data_loader.dataset.id_to_trainid
        cp_output = torch.zeros(output.size())
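        # invert id_to_trainid: wherever the prediction equals train id v, write back the original label id k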
        for k, v in mapping.items():
            cp_output[output == v] = k

        return cp_output

    def _clean_cache(self):
        self.model.student_hidden_outputs = []
        self.model.teacher_hidden_outputs = []
        gc.collect()
        torch.cuda.empty_cache()

    def resume(self, checkpoint_path):
        self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path,
                                map_location=torch.device('cpu'))
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        config = checkpoint['config']  # config of checkpoint
        epoch = checkpoint['epoch']  # stopped epoch

        # load model state from checkpoint
        # first, align the network by replacing depthwise separable for student
        for i in range(1, epoch + 1):
            self.prepare_train_epoch(i, config)
        # load weight
        forgiving_state_restore(self.model, checkpoint['state_dict'])
        self.logger.info("Loaded model's state dict")

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config[
                'optimizer']['type']:
            self.logger.warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.logger.info("Loaded optimizer state dict")
Example #4
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.track_loss = ['loss', 'recon', 'kld', 'lmse', 'contrast', 'cycle', 'cycle_mse', 'cycle_ce', 'pseudo', 'klc']

        self.train_metrics = MetricTracker(*self.track_loss, *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(*self.track_loss, *[m.__name__ for m in self.metric_ftns], writer=self.writer)

        self.pitch_map = {i: n for n, i in enumerate(data_loader.dataset.pitch_map)}
        self.dynamic_map = {i: n for n, i in enumerate(data_loader.dataset.dynamic_map)}
        self.pitchclass_map = {i: n for n, i in enumerate(data_loader.dataset.pitchclass_map)}
        self.tf_map = {v: data_loader.dataset.family_map[k] for k,v in data_loader.dataset.instrument_map.items()}

        self.plot_step = 25

        self.recon_sample = np.random.choice(valid_data_loader.sampler.indices, size=10, replace=False)
        pitches = np.random.choice(82, size=len(self.recon_sample))
        self.sample_to_pitch = {k: v for k, v in zip(self.recon_sample, pitches)}

        self.spec_ext = ExtractSpectrogram(sr=SR, n_fft=NFFT, hop_length=HOP, n_mels=NMEL, mode='mel')
        self.x_max, self.x_min = 9.7666, -36.0437

        self.init_temp = self.model.temperature
        self.min_temp = self.model.min_temperature
        self.decay_rate = self.model.decay_rate

        self.pseudo_train = config['trainer']['pseudo_train']
        self.labeled = config['trainer']['labeled']
        self.labeled_sample = np.random.choice(data_loader.sampler.indices, size=int(len(data_loader.sampler.indices) * self.labeled), replace=False)

        self.freeze_encoder = config['trainer']['freeze_encoder']
        self.pitch_shift = config['trainer']['pitch_shift']

    def data_transform(self, x, **kwargs):
        def get_idx(at_time=0.2, pitch_shift=2):
            compensate_duration = 0.05
            load_duration = at_time + compensate_duration  # add 0.05s more after the targeted time instant
            # desired_idx = int(at_time * SR)
            if pitch_shift != 0:
                pitch_shift = np.random.randint(-pitch_shift, pitch_shift)
            # shift = -2
            scale = 2. ** (pitch_shift / 12.)
            idx_comp = int(compensate_duration * scale**(-1) * SR / HOP)  # the corresponding number of indices to be compensated
            if pitch_shift < 0:
                n_sample = int(scale**(-1) * load_duration * SR)
                # n_sample = int(scale**(-1) * load_duration * SR)
                # assert n_sample > desired_idx
                desired_idx = int((load_duration * SR) / HOP) - idx_comp

            if pitch_shift >= 0:
                n_sample = int(load_duration * SR)
                desired_idx = int((scale**(-1) * n_sample) / HOP) - idx_comp

            return pitch_shift, n_sample, desired_idx
        
        shift, n_sample, desired_idx = get_idx(**kwargs)
        
        x = LoadNpArray(n_sample=n_sample)(x)
        x = PitchShift(shift=shift)(x)

        x = ToTensor()(x)
        x = self.spec_ext(x)
        x = LogCompress()(x)
        x = Clipping(clip_min=self.x_min, clip_max=self.x_max)(x)
        x = MinMaxNorm(x_min=self.x_min, x_max=self.x_max)(x)
        x = x[:, :, desired_idx]
        return x, shift
   
    def get_gumb_temp(self, epoch, init_temp, min_temp, decay_rate):
        temp = np.maximum(init_temp * np.exp(-decay_rate * epoch), min_temp)
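        # e.g. with init_temp=1.0, decay_rate=0.03, min_temp=0.5 (hypothetical values):
        # epoch 10 -> exp(-0.3) ~ 0.74, epoch 30 -> exp(-0.9) ~ 0.41, floored to 0.5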
        return temp

    def get_ps_label(self, yp, ps):
        y_shift = torch.from_numpy(np.array(ps)).unsqueeze(-1).to(self.device)
        y_ps = yp + y_shift
        mask_l = torch.where(y_ps >= 0, torch.ones_like(y_ps), torch.zeros_like(y_ps))
        mask_u = torch.where(y_ps <= 81, torch.ones_like(y_ps), torch.zeros_like(y_ps))
        mask = mask_l * mask_u
        y_ps *= mask
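        # e.g. (hypothetical values): yp=80 shifted by ps=+3 gives y_ps=83, which falls
        # outside [0, 81], so its mask entry is 0 and the shifted label is zeroed;
        # yp=10 shifted by ps=-2 gives y_ps=8 and stays un-masked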
        if self.pitch_shift == 0: assert (y_ps == yp).sum() == len(y_ps)
        return yp, y_ps, mask.float(), torch.ones_like(yp)

    def get_pseudo_label(self, logit, supervised_idx, pitch_label, pitch_shift):
        '''Algorithm for creating pseudo labels for pitch-shifted samples
        '''
        supervised = len(supervised_idx) > 0
        # initialize masks for both the original and the pitch-shifted samples
        m, m_ps = torch.zeros_like(pitch_label).float(), torch.zeros_like(pitch_label).float()
        '''Original samples'''
        # pseudo labels are taken from the inferred categorical distribution
        y_pseudo = torch.argmax(logit, dim=-1, keepdim=True)
        if supervised:
            supervised_idx = supervised_idx.long()
            # replace pseudo with supervised labels
            # NOTE: pseudo labels become the true labels if the supervised portion is 100%
            y_pseudo[supervised_idx] = pitch_label[supervised_idx]
            # only the supervised indices are un-masked for the original samples
            m[supervised_idx] = 1  # cross-entropy induced by pseudo labels will be masked

        '''Pitch-shifted samples'''
        # use pseudo labels for the pitch-shifted samples if pseudo training is enabled
        if self.pseudo_train:
            m_ps += 1

        # always un-mask the supervised labels
        if supervised:
            m_ps[supervised_idx] = 1

        if m_ps.gt(1).any(): print("mask has entry larger than 1 before being multiplied with exclusion mask")

        # further mask the out-of-range pitches based on pseudo labels
        _, y_ps_pseudo, m_ps_ext, _ = self.get_ps_label(y_pseudo, pitch_shift)
        m_ps *= m_ps_ext

        if m_ps.gt(1).any(): print("mask has entry larger than 1 AFTER being multiplied with exclusion mask")

        return y_pseudo, y_ps_pseudo, m, m_ps

    def get_data(self, x, n_semitone=2):
        for i, x_i in enumerate(x):
            x_ps, ps = self.data_transform(x_i, at_time=0.2, pitch_shift=n_semitone)
            x_ori, _ = self.data_transform(x_i, at_time=0.2, pitch_shift=0)
            if i == 0:
                ps_cat = [ps]
                x_ps_cat = x_ps
                x_ori_cat = x_ori
            else:
                ps_cat.append(ps)
                x_ps_cat= torch.cat([x_ps_cat, x_ps])
                x_ori_cat = torch.cat([x_ori_cat, x_ori])

        return x_ori_cat, x_ps_cat, ps_cat

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        torch.manual_seed(1111)
        for batch_idx, (x, idx, y) in enumerate(self.data_loader):
            supervised_idx = torch.from_numpy(np.array([i for i, v in enumerate(idx.numpy()) 
                                 if v in self.labeled_sample], dtype='float')).to(self.device)

            y = torch.stack(y, dim=1).to(self.device)
            yp = y[:, 1:2]
            x1, x2, ps_cat = self.get_data(x, n_semitone=self.pitch_shift)

            self.optimizer.zero_grad()
            x1_hat, h1, mu1, logvar1, z1_t, z1_p, logits1, prob1 = self.model(x1, yp)
            if self.model.gumbel:
                y_pseudo, y_ps_pseudo, m, m_ps = self.get_pseudo_label(logits1, supervised_idx, yp, ps_cat)
            else:
                y_pseudo, y_ps_pseudo, m, m_ps = self.get_ps_label(yp, ps_cat)

            x2_hat, h2, mu2, logvar2, z2_t, z2_p, logits2, prob2 = self.model(x2, y_ps_pseudo)

            # con_loss = self.nt_xent_criterion(mu1, mu2)

            dict_loss = self.criterion(self.model, self.pseudo_train, self.device,
                            x1, x1_hat, x2, x2_hat,
                            mu1, logvar1, z1_t, z1_p, 
                            mu2, logvar2, z2_t, z2_p, 
                            logits1=logits1, logits2=logits2, prob1=prob1, prob2=prob2,
                            epoch=epoch, mask=m_ps.float(), mask_y=m.float(),
                            y=y_pseudo.squeeze(-1), y_ps=y_ps_pseudo.squeeze(-1))

            for k, v in dict_loss.items():
                if torch.isnan(v): print(k)

            for name, p in self.model.named_parameters():
                if torch.isnan(p).any(): print(name)

            if dict_loss['cycle'].requires_grad:
                dict_loss['loss'].backward(retain_graph=True)
            else:
                dict_loss['loss'].backward()
            self.optimizer.step()
            pre_tim_op = copy.deepcopy(list(self.model.timbre_encoder.parameters()))
            pre_pitch_op = copy.deepcopy(list(self.model.pitch_encoder.parameters()))

            if dict_loss['cycle'].requires_grad:
                self.optimizer.zero_grad()
                dict_loss['cycle'].backward()
                if self.freeze_encoder:
                    for i, param in enumerate(self.model.timbre_encoder.parameters()):
                        param.grad[:] = 0
                    for i, param in enumerate(self.model.pitch_encoder.parameters()):
                        if param.grad is not None: param.grad[:] = 0

                self.optimizer.step()

            for name, p in self.model.named_parameters():
                if torch.isnan(p).any(): print(name)

            if self.model.gumbel:
                temp = self.get_gumb_temp(epoch, self.init_temp, self.min_temp, self.decay_rate)
                self.model.set_temperature(temp)
            else:
                temp = 0

            for track, output in zip(self.track_loss, dict_loss):
                assert track == output
                log_val = dict_loss[track].item()
                self.train_metrics.update(track, log_val)

            if batch_idx == self.len_epoch:
                break

            if batch_idx == 0:
                idx_cat = idx
                zt_cat, zp_cat = z1_t, z1_p
                yt_cat, yp_cat, yf_cat, yc_cat, yd_cat = y[:, 0:1], y[:, 1:2], y[:, -1:], y[:, 2:3], y[:, 3:4]
                x_cat, x_hat_cat = x1, x1_hat
                mu1_cat, logvar1_cat = mu1, logvar1
                mu2_cat, logvar2_cat = mu2, logvar2
                if prob1 is not None:
                    yp_hat_cat = torch.argmax(prob1, dim=-1, keepdim=True)
                else:
                    yp_hat_cat = None
            else:
                idx_cat = torch.cat([idx_cat, idx])
                zt_cat, zp_cat = torch.cat([zt_cat, z1_t]), torch.cat([zp_cat, z1_p])
                yt_cat = torch.cat([yt_cat, y[:, 0:1]], dim=0)
                yp_cat = torch.cat([yp_cat, y[:, 1:2]], dim=0)
                yf_cat = torch.cat([yf_cat, y[:, -1:]], dim=0)
                yc_cat = torch.cat([yc_cat, y[:, 2:3]], dim=0)
                yd_cat = torch.cat([yd_cat, y[:, 3:4]], dim=0)
                mu1_cat, logvar1_cat = torch.cat([mu1_cat, mu1]), torch.cat([logvar1_cat, logvar1])
                mu2_cat, logvar2_cat = torch.cat([mu2_cat, mu2]), torch.cat([logvar2_cat, logvar2])
                x_hat_cat = torch.cat([x_hat_cat, x1_hat])
                x_cat = torch.cat([x_cat, x1])
                if prob1 is not None:
                    yp_hat_cat = torch.cat([yp_hat_cat, torch.argmax(prob1, dim=-1, keepdim=True)])
                else:
                    yp_hat_cat = None


        self.writer.set_step(epoch, 'train')
        for track, output in zip(self.track_loss, dict_loss):
            assert track == output
            self.writer.add_scalar(track, self.train_metrics.avg(track))
        for met in self.metric_ftns:
            # if met.__name__ == 'cluster_var':
            #     self.train_metrics.update(met.__name__, met(mu1_cat.cpu(), yp_cat.cpu()))
            # if met.__name__ == 'kl_gauss':
            #     self.train_metrics.update(met.__name__, met(mu1_cat, logvar1_cat, mu2_cat, logvar2_cat).item())
            if met.__name__ == 'f1' and yp_hat_cat is not None:
                self.train_metrics.update(met.__name__, met(yp_hat_cat, yp_cat, n_class=82).item())
            if met.__name__ == 'cluster_acc' and yp_hat_cat is not None:
                self.train_metrics.update(met.__name__, met(yp_hat_cat, yp_cat))
            if met.__name__ == 'nmi' and yp_hat_cat is not None:
                self.train_metrics.update(met.__name__, met(yp_hat_cat, yp_cat))

            self.writer.add_scalar(met.__name__, self.train_metrics.avg(met.__name__))

        log = self.train_metrics.result()

        if epoch % self.plot_step == 0:
            yt_cat = yt_cat.squeeze(-1).detach().cpu().numpy()
            yp_cat = yp_cat.squeeze(-1).detach().cpu().numpy()
            yf_cat = yf_cat.squeeze(-1).detach().cpu().numpy()
            yc_cat = yc_cat.squeeze(-1).detach().cpu().numpy()
            yd_cat = yd_cat.squeeze(-1).detach().cpu().numpy()
            zt_2d = TSNE(n_components=2).fit_transform(mu1_cat.cpu().data.numpy())
            fig, ax = plt.subplots(2, 4, figsize=(4*5, 2*5))

            def plot_and_color(data, ax, label_map, labels, colors=None):
                n_class = len(np.unique(labels))
                if colors is not None:
                    assert n_class == len(colors)
                else:
                    random.seed(1111)
                    colors = ['#'+''.join([random.choice('0123456789ABCDEF') for j in range(6)]) for i in range(n_class)]
                assert len(label_map.items()) == n_class
                for k, v in label_map.items():
                    target_data = data[labels == v]
                    ax.scatter(target_data[:, 0], target_data[:, 1], c=colors[v], label=k, alpha=0.7)

            plot_and_color(zt_2d, ax[0][0], INSTRUMENT_MAP, yt_cat, colors=INSTRUMENT_COLORS)
            plot_and_color(zt_2d, ax[0][2], self.pitch_map, yp_cat, colors=PITCH_COLORS)
            plot_and_color(zt_2d, ax[0][1], FAMILY_MAP, yf_cat, colors=None)
            plot_and_color(zt_2d, ax[0][3], self.dynamic_map, yd_cat, colors=None)
            ax[1][1].imshow(self.model.emb.weight.cpu().data.numpy().T, aspect='auto', origin='lower')

        else:
            fig = None 
            ax = None

        if self.do_validation:
            val_log = self._valid_epoch(epoch, fig, ax)
            log.update(**{'val_'+k : v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        log.update({"gumbel_temp": temp})

        return log

    def _valid_epoch(self, epoch, fig=None, ax=None):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        torch.manual_seed(1111)
        with torch.no_grad():
            for batch_idx, (x, idx, y) in enumerate(self.valid_data_loader):
                y = torch.stack(y, dim=1).to(self.device)
                yp = y[:, 1:2]
                x1, x2, ps_cat = self.get_data(x, n_semitone=self.pitch_shift)

                x1_hat, h1, mu1, logvar1, z1_t, z1_p, logits1, prob1 = self.model(x1, yp)
                y_pseudo, y_ps_pseudo, m, m_ps = self.get_ps_label(yp, ps_cat)

                x2_hat, h2, mu2, logvar2, z2_t, z2_p, logits2, prob2 = self.model(x2, y_ps_pseudo)

                x1_hat_swap, _ = self.model.decode(z2_t, z1_p)
                x2_hat_swap, _ = self.model.decode(z1_t, z2_p)

                dict_loss = self.criterion(self.model, self.pseudo_train, self.device,
                                x1, x1_hat, x2, x2_hat,
                                mu1, logvar1, z1_t, z1_p, 
                                mu2, logvar2, z2_t, z2_p, 
                                logits1=logits1, logits2=logits2, prob1=prob1, prob2=prob2,
                                epoch=epoch, mask=m_ps.float(), mask_y=m.float(),
                                y=y_pseudo.squeeze(-1), y_ps=y_ps_pseudo.squeeze(-1))

                for track, output in zip(self.track_loss, dict_loss):
                    assert track == output
                    log_val = dict_loss[track].item()
                    self.valid_metrics.update(track, log_val)

                if batch_idx == 0:
                    idx_cat = idx
                    zt_cat, zp_cat = z1_t, z1_p
                    yt_cat, yp_cat, yf_cat, yc_cat, yd_cat = y[:, 0:1], y[:, 1:2], y[:, -1:], y[:, 2:3], y[:, 3:4]
                    x_cat, x_hat_cat = x1, x1_hat
                    mu1_cat, logvar1_cat = mu1, logvar1
                    mu2_cat, logvar2_cat = mu2, logvar2
                    h_cat = h1
                    if prob1 is not None:
                        yp_hat_cat = torch.argmax(prob1, dim=-1, keepdim=True)
                    else:
                        yp_hat_cat = None
                else:
                    idx_cat = torch.cat([idx_cat, idx])
                    zt_cat, zp_cat = torch.cat([zt_cat, z1_t]), torch.cat([zp_cat, z1_p])
                    yt_cat = torch.cat([yt_cat, y[:, 0:1]], dim=0)
                    yp_cat = torch.cat([yp_cat, y[:, 1:2]], dim=0)
                    yf_cat = torch.cat([yf_cat, y[:, -1:]], dim=0)
                    yc_cat = torch.cat([yc_cat, y[:, 2:3]], dim=0)
                    yd_cat = torch.cat([yd_cat, y[:, 3:4]], dim=0)
                    mu1_cat, logvar1_cat = torch.cat([mu1_cat, mu1]), torch.cat([logvar1_cat, logvar1])
                    mu2_cat, logvar2_cat = torch.cat([mu2_cat, mu2]), torch.cat([logvar2_cat, logvar2])
                    x_hat_cat = torch.cat([x_hat_cat, x1_hat])
                    x_cat = torch.cat([x_cat, x1])
                    h_cat = torch.cat([h_cat, h1])
                    if prob1 is not None:
                        yp_hat_cat = torch.cat([yp_hat_cat, torch.argmax(prob1, dim=-1, keepdim=True)])
                    else:
                        yp_hat_cat = None

        self.writer.set_step(epoch, 'valid')
        for track, output in zip(self.track_loss, dict_loss):
            assert track == output
            self.writer.add_scalar(track, self.valid_metrics.avg(track))
        for met in self.metric_ftns:
            # if met.__name__ == 'cluster_var':
            #     self.valid_metrics.update(met.__name__, met(mu1_cat.cpu(), yp_cat.cpu()))
            # if met.__name__ == 'kl_gauss':
            #     self.valid_metrics.update(met.__name__, met(mu1_cat, logvar1_cat, mu2_cat, logvar2_cat).item())
            if met.__name__ == 'f1' and yp_hat_cat is not None:
                self.valid_metrics.update(met.__name__, met(yp_hat_cat, yp_cat, n_class=82).item())
            if met.__name__ == 'cluster_acc' and yp_hat_cat is not None:
                self.valid_metrics.update(met.__name__, met(yp_hat_cat, yp_cat))
            if met.__name__ == 'nmi' and yp_hat_cat is not None:
                self.valid_metrics.update(met.__name__, met(yp_hat_cat, yp_cat))

            self.writer.add_scalar(met.__name__, self.valid_metrics.avg(met.__name__))
        
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        if fig is not None:
            idx_cat = idx_cat.squeeze(-1).cpu().data.numpy()
            # keep only the reconstruction samples that appear exactly once in this epoch's data
            target_idx = np.array([np.where(idx_cat == i)[0] for i in self.recon_sample])
            non_empty_idx = np.vstack([(n, i) for n, i in enumerate(target_idx) if len(i) == 1])
            target_idx = np.vstack([i for i in target_idx if len(i) == 1])[:, 0]
            self.recon_sample = [self.recon_sample[i[0]] for i in non_empty_idx]
            # target_idx = np.array([np.where(idx_cat == i)[0] for i in self.recon_sample])[:,0]
            target_pitch = np.array([self.sample_to_pitch[i] for i in self.recon_sample])
            origin = x_cat.cpu().data.numpy()[target_idx]
            output = x_hat_cat.cpu().data.numpy()[target_idx]
            h_cat = h_cat.cpu().data.numpy()[target_idx]
            zt_cat = zt_cat[target_idx]
            zp_target = self.model.emb.weight[target_pitch]
            if self.model.use_hp:
                zp_target = self.model.project_harmonic(zp_target)            
            output_pswap = self.model.decode(zt_cat, zp_target)[0]
            output_pswap = output_pswap.cpu().data.numpy()
            for m, (i, j, k, l) in enumerate(zip(origin, output, h_cat, output_pswap)):
                tmp = np.vstack([i, j])
                tmp_swap = np.vstack([i, l])
                if self.model.decoding == 'sf':
                    tmp_h = np.vstack([i, k])
                if m == 0:
                    pair = tmp
                    pair_swap = tmp_swap
                    if self.model.decoding == 'sf':
                        pair_h = tmp_h
                else:
                    pair = np.vstack([pair, tmp])
                    pair_swap = np.vstack([pair_swap, tmp_swap])
                    if self.model.decoding == 'sf':
                        pair_h = np.vstack([pair_h, tmp_h])
                      
            ax[1][2].imshow(pair.T, aspect='auto', origin='lower', vmin=0, vmax=1)
            for l in range(1, 2*len(self.recon_sample), 2):
                ax[1][2].axvline(x=l+0.5, lw=1.5, c='r')

            ax[1][3].imshow(pair_swap.T, aspect='auto', origin='lower', vmin=0, vmax=1)
            for l in range(1, 2*len(self.recon_sample), 2):
                ax[1][3].axvline(x=l+0.5, lw=1.5, c='r')

            if self.model.decoding == 'sf':
                ax[1][0].imshow(pair_h.T, aspect='auto', origin='lower', vmin=0, vmax=1)
                for l in range(1, 2*len(self.recon_sample), 2):
                    ax[1][0].axvline(x=l+0.5, lw=1.5, c='r')

            self.writer.set_step(epoch, 'train')
            self.writer.add_figure('tsne', fig)
            
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
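
# --- Hedged example (not part of the original code) ------------------------------
# A minimal sketch of a clustering-accuracy metric matching the `met(yp_hat_cat,
# yp_cat)` call signature used above: predicted cluster ids are matched to ground
# truth labels with the Hungarian algorithm before computing accuracy. The function
# name, the scipy dependency and the tensor shapes are assumptions for illustration.
def cluster_acc(y_pred, y_true):
    from scipy.optimize import linear_sum_assignment  # assumption: scipy available
    y_pred = y_pred.squeeze(-1).cpu().numpy().astype(int)
    y_true = y_true.squeeze(-1).cpu().numpy().astype(int)
    n_class = int(max(y_pred.max(), y_true.max())) + 1
    # contingency matrix between predicted and true labels
    cost = np.zeros((n_class, n_class), dtype=np.int64)
    for p, t in zip(y_pred, y_true):
        cost[p, t] += 1
    # find the label assignment that maximises the number of matched samples
    row_ind, col_ind = linear_sum_assignment(cost.max() - cost)
    return cost[row_ind, col_ind].sum() / y_pred.size
# ----------------------------------------------------------------------------------
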
# Example #5
class ClassifierTrainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        torch.manual_seed(1111)
        for batch_idx, (x, idx, gt) in enumerate(self.data_loader):
            x = x.squeeze(1)
            gt = torch.stack(gt, dim=1).to(self.device)

            self.optimizer.zero_grad()
            if self.model.target == 'instrument':
                y = gt[:, 0]
            elif self.model.target == 'pitch':
                y = gt[:, 1]
            else:
                raise ValueError('Unsupported target: {}'.format(self.model.target))

            output = self.model(x)
            loss = self.criterion(output, y)

            loss.backward()
            self.optimizer.step()

            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, y))

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()

        self.writer.set_step(epoch, 'train')
        self.writer.add_scalar('loss', self.train_metrics.avg('loss'))
        for met in self.metric_ftns:
            self.writer.add_scalar(met.__name__,
                                   self.train_metrics.avg(met.__name__))

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        torch.manual_seed(1111)
        with torch.no_grad():
            for batch_idx, (x, idx, gt) in enumerate(self.valid_data_loader):
                x = x.squeeze(1)
                gt = torch.stack(gt, dim=1).to(self.device)

                if self.model.target == 'instrument':
                    y = gt[:, 0]
                elif self.model.target == 'pitch':
                    y = gt[:, 1]
                else:
                    raise ValueError('Unsupported target: {}'.format(self.model.target))
                output = self.model(x)

                loss = self.criterion(output, y)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, y))

        self.writer.set_step(epoch, 'valid')
        self.writer.add_scalar('loss', self.valid_metrics.avg('loss'))
        for met in self.metric_ftns:
            self.writer.add_scalar(met.__name__,
                                   self.valid_metrics.avg(met.__name__))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
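
# --- Hedged example (not part of the original code) ------------------------------
# Minimal sketches, under the assumption that each entry of `metric_ftns` follows
# the `met(output, target)` call signature used by ClassifierTrainer above and that
# `inf_loop` simply re-yields a finite DataLoader forever for iteration-based
# training. Names are illustrative only.
from itertools import repeat

def accuracy(output, target):
    # fraction of samples whose arg-max prediction equals the target label
    with torch.no_grad():
        pred = torch.argmax(output, dim=1)
        return torch.sum(pred == target).item() / len(target)

def inf_loop(data_loader):
    # wrap a finite loader so enumerate() can run for an arbitrary number of steps
    for loader in repeat(data_loader):
        yield from loader
# ----------------------------------------------------------------------------------
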
class BaseTrainer:
    """
    Base class for all trainers
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 lr_scheduler,
                 config,
                 trainloader,
                 validloader=None,
                 len_epoch=None):
        self.config = config
        self.logger = config.get_logger('trainer',
                                        config['trainer']['verbosity'])
        self.trainloader = trainloader
        self.validloader = validloader

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.trainloader)
        else:
            # iteration-based training
            self.trainloader = inf_loop(trainloader)
            self.len_epoch = len_epoch

        # setup GPU device if available, move model into configured device
        n_gpu_use = torch.cuda.device_count()
        self.device = torch.device('cuda:0' if n_gpu_use > 0 else 'cpu')
        self.model = model.to(self.device)
        self.model = torch.nn.DataParallel(model)

        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.lr_scheduler = lr_scheduler

        cfg_trainer = config['trainer']
        self.epochs = cfg_trainer['epochs']
        self.log_step = cfg_trainer['log_step']
        self.save_period = cfg_trainer['save_period']
        self.monitor = cfg_trainer.get('monitor', 'off')

        self.start_epoch = 1
        self.checkpoint_dir = config.save_dir

        # configuration to monitor model performance and save best
        if self.monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = self.monitor.split()
            assert self.mnt_mode in ['min', 'max']

            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = cfg_trainer.get('early_stop', inf)

        # setup visualization writer instance
        self.writer = TensorboardWriter(config.log_dir, self.logger,
                                        cfg_trainer['tensorboard'])
        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

        if config.resume is not None:
            self._resume_checkpoint(config.resume)

    @abstractmethod
    def _train_step(self, batch):
        """
        Training logic for a step

        :param batch: batch of current step
        :return:
            loss: loss tensor attached to the autograd graph (used for the backward pass)
            mets: dict of metrics computed between output and target
        """
        raise NotImplementedError

    @abstractmethod
    def _valid_step(self, batch):
        """
        Valid logic for a step

        :param batch: batch of current step
        :return:
            loss: loss tensor (no backward pass is performed during validation)
            mets: dict of metrics computed between output and target
        """
        raise NotImplementedError

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()

        tic = time.time()
        datatime = batchtime = 0
        for batch_idx, batch in enumerate(self.trainloader):
            datatime += time.time() - tic
            # -------------------------------------------------------------------------
            loss, mets = self._train_step(batch)
            # -------------------------------------------------------------------------
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            batchtime += time.time() - tic
            tic = time.time()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for key, val in mets.items():
                self.train_metrics.update(key, val)

            if batch_idx % self.log_step == 0:
                processed_percent = batch_idx / self.len_epoch * 100
                self.logger.debug(
                    'Train Epoch:{} [{}/{}]({:.0f}%)\tTime:{:5.2f}/{:<5.2f}\tLoss:({:.4f}){:.4f}'
                    .format(epoch, batch_idx, self.len_epoch,
                            processed_percent, datatime, batchtime,
                            loss.item(), self.train_metrics.avg('loss')))
                datatime = batchtime = 0

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()
        log = {'train_' + k: v for k, v in log.items()}

        if self.validloader is not None:
            val_log = self._valid_epoch(epoch)
            log.update(**{'valid_' + k: v for k, v in val_log.items()})
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        for batch_idx, batch in enumerate(self.validloader):
            # -------------------------------------------------------------------------
            loss, mets = self._valid_step(batch)
            # -------------------------------------------------------------------------
            self.writer.set_step(
                (epoch - 1) * len(self.validloader) + batch_idx, 'valid')
            self.valid_metrics.update('loss', loss.item())
            for key, val in mets.items():
                self.valid_metrics.update(key, val)

        return self.valid_metrics.result()

    def train(self):
        """
        Full training logic
        """
        not_improved_count = 0
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            # save logged informations into log dict
            lr = self.optimizer.param_groups[0]['lr']
            log = {'epoch': epoch, 'lr': lr}
            log.update(result)

            # print logged informations to the screen
            for key, value in log.items():
                self.logger.info('    {:20s}: {}'.format(str(key), value))

            # evaluate model performance according to configured metric, save best checkpoint as model_best
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] <= self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] >= self.mnt_best)
                except KeyError:
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False

                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1

                if not_improved_count > self.early_stop:
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if self.lr_scheduler is not None:
                if isinstance(self.lr_scheduler, ReduceLROnPlateau):
                    self.lr_scheduler.step(log[self.mnt_metric])
                else:
                    self.lr_scheduler.step()

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, save_best=best)

            # add histogram of model parameters to the tensorboard
            self.writer.set_step(epoch)
            for name, p in self.model.named_parameters():
                self.writer.add_histogram(name, p, bins='auto')

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        state = {
            'epoch': epoch,
            'model': self.model.module.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'lr_scheduler': self.lr_scheduler.state_dict(),
            'monitor_best': self.mnt_best
        }
        filename = str(self.checkpoint_dir / 'chkpt_{:03d}.pth'.format(epoch))
        torch.save(state, filename)
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def _resume_checkpoint(self, resume_path):
        """
        Resume from saved checkpoints

        :param resume_path: Checkpoint path to be resumed
        """
        resume_path = str(resume_path)
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        try:
            self.start_epoch = checkpoint['epoch'] + 1
            self.model.module.load_state_dict(checkpoint['model'])
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            self.mnt_best = checkpoint['monitor_best']
        except KeyError:
            self.model.module.load_state_dict(checkpoint)

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
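
# --- Hedged example (not part of the original code) ------------------------------
# A minimal sketch of a concrete trainer filling in the two abstract hooks of
# BaseTrainer above: `_train_step` returns a loss still attached to the autograd
# graph, `_valid_step` one computed under no_grad, both together with a dict of
# scalar metrics. The class name and the batch layout are assumptions for illustration.
class SupervisedTrainer(BaseTrainer):
    def _train_step(self, batch):
        x, y = batch
        x, y = x.to(self.device), y.to(self.device)
        output = self.model(x)
        loss = self.criterion(output, y)
        mets = {m.__name__: m(output, y) for m in self.metric_ftns}
        return loss, mets

    def _valid_step(self, batch):
        with torch.no_grad():
            x, y = batch
            x, y = x.to(self.device), y.to(self.device)
            output = self.model(x)
            loss = self.criterion(output, y)
            mets = {m.__name__: m(output, y) for m in self.metric_ftns}
        return loss, mets
# ----------------------------------------------------------------------------------
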
# Example #7
class Trainer(BaseTrainer):
    def __init__(self, config):
        super(Trainer, self).__init__(config)
        self.datamanager = DataManger(config["data"])

        # model
        self.model = Baseline(
            num_classes=self.datamanager.datasource.get_num_classes("train")
        )

        # summary model
        summary(
            self.model,
            input_size=(3, 256, 128),
            batch_size=config["data"]["batch_size"],
            device="cpu",
        )

        # losses
        cfg_losses = config["losses"]
        self.criterion = Softmax_Triplet_loss(
            num_class=self.datamanager.datasource.get_num_classes("train"),
            margin=cfg_losses["margin"],
            epsilon=cfg_losses["epsilon"],
            use_gpu=self.use_gpu,
        )

        self.center_loss = CenterLoss(
            num_classes=self.datamanager.datasource.get_num_classes("train"),
            feature_dim=2048,
            use_gpu=self.use_gpu,
        )

        # optimizer
        cfg_optimizer = config["optimizer"]
        self.optimizer = torch.optim.Adam(
            self.model.parameters(),
            lr=cfg_optimizer["lr"],
            weight_decay=cfg_optimizer["weight_decay"],
        )

        self.optimizer_centerloss = torch.optim.SGD(
            self.center_loss.parameters(), lr=0.5
        )

        # learning rate scheduler
        cfg_lr_scheduler = config["lr_scheduler"]
        self.lr_scheduler = WarmupMultiStepLR(
            self.optimizer,
            milestones=cfg_lr_scheduler["steps"],
            gamma=cfg_lr_scheduler["gamma"],
            warmup_factor=cfg_lr_scheduler["factor"],
            warmup_iters=cfg_lr_scheduler["iters"],
            warmup_method=cfg_lr_scheduler["method"],
        )

        # track metric
        self.train_metrics = MetricTracker("loss", "accuracy")
        self.valid_metrics = MetricTracker("loss", "accuracy")

        # save best accuracy for function _save_checkpoint
        self.best_accuracy = None

        # send model to device
        self.model.to(self.device)

        self.scaler = GradScaler()

        # resume model from last checkpoint
        if config["resume"] != "":
            self._resume_checkpoint(config["resume"])

    def train(self):
        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)

            if self.lr_scheduler is not None:
                self.lr_scheduler.step()

            result = self._valid_epoch(epoch)

            # add scalars to tensorboard
            self.writer.add_scalars(
                "Loss",
                {
                    "Train": self.train_metrics.avg("loss"),
                    "Val": self.valid_metrics.avg("loss"),
                },
                global_step=epoch,
            )
            self.writer.add_scalars(
                "Accuracy",
                {
                    "Train": self.train_metrics.avg("accuracy"),
                    "Val": self.valid_metrics.avg("accuracy"),
                },
                global_step=epoch,
            )

            # logging result to console
            log = {"epoch": epoch}
            log.update(result)
            for key, value in log.items():
                self.logger.info("    {:15s}: {}".format(str(key), value))

            # save model
            if (
                self.best_accuracy is None
                or self.best_accuracy < self.valid_metrics.avg("accuracy")
            ):
                self.best_accuracy = self.valid_metrics.avg("accuracy")
                self._save_checkpoint(epoch, save_best=True)
            else:
                self._save_checkpoint(epoch, save_best=False)

            # save logs
            self._save_logs(epoch)

    def _train_epoch(self, epoch):
        """Training step"""
        self.model.train()
        self.train_metrics.reset()
        with tqdm(total=len(self.datamanager.get_dataloader("train"))) as epoch_pbar:
            epoch_pbar.set_description(f"Epoch {epoch}")
            for batch_idx, (data, labels, _) in enumerate(
                self.datamanager.get_dataloader("train")
            ):
                # push data to device
                data, labels = data.to(self.device), labels.to(self.device)

                # zero gradient
                self.optimizer.zero_grad()
                self.optimizer_centerloss.zero_grad()

                with autocast():
                    # forward batch
                    score, feat = self.model(data)

                    # calculate loss and accuracy
                    loss = (
                        self.criterion(score, feat, labels)
                        + self.center_loss(feat, labels) * self.config["losses"]["beta"]
                    )
                    _, preds = torch.max(score.data, dim=1)

                # backward parameters
                # loss.backward()
                self.scaler.scale(loss).backward()

                # rescale the center-loss parameter gradients by 1/beta so the loss
                # weighting does not affect the magnitude of their dedicated SGD update
                for param in self.center_loss.parameters():
                    param.grad.data *= 1.0 / self.config["losses"]["beta"]

                # optimize
                # self.optimizer.step()
                self.scaler.step(self.optimizer)
                self.optimizer_centerloss.step()

                self.scaler.update()

                # update loss and accuracy in MetricTracker
                self.train_metrics.update("loss", loss.item())
                self.train_metrics.update(
                    "accuracy",
                    torch.sum(preds == labels.data).double().item() / data.size(0),
                )

                # update process bar
                epoch_pbar.set_postfix(
                    {
                        "train_loss": self.train_metrics.avg("loss"),
                        "train_acc": self.train_metrics.avg("accuracy"),
                    }
                )
                epoch_pbar.update(1)
        return self.train_metrics.result()

    def _valid_epoch(self, epoch):
        """Validation step"""
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            with tqdm(total=len(self.datamanager.get_dataloader("val"))) as epoch_pbar:
                epoch_pbar.set_description(f"Epoch {epoch}")
                for batch_idx, (data, labels, _) in enumerate(
                    self.datamanager.get_dataloader("val")
                ):
                    # push data to device
                    data, labels = data.to(self.device), labels.to(self.device)

                    with autocast():
                        # forward batch
                        score, feat = self.model(data)

                        # calculate loss and accuracy
                        loss = (
                            self.criterion(score, feat, labels)
                            + self.center_loss(feat, labels)
                            * self.config["losses"]["beta"]
                        )
                        _, preds = torch.max(score.data, dim=1)

                    # update loss and accuracy in MetricTracker
                    self.valid_metrics.update("loss", loss.item())
                    self.valid_metrics.update(
                        "accuracy",
                        torch.sum(preds == labels.data).double().item() / data.size(0),
                    )

                    # update process bar
                    epoch_pbar.set_postfix(
                        {
                            "val_loss": self.valid_metrics.avg("loss"),
                            "val_acc": self.valid_metrics.avg("accuracy"),
                        }
                    )
                    epoch_pbar.update(1)
        return self.valid_metrics.result()

    def _save_checkpoint(self, epoch, save_best=True):
        """save model to file"""
        state = {
            "epoch": epoch,
            "state_dict": self.model.state_dict(),
            "center_loss": self.center_loss.state_dict(),
            "optimizer": self.optimizer.state_dict(),
            "optimizer_centerloss": self.optimizer_centerloss.state_dict(),
            "lr_scheduler": self.lr_scheduler.state_dict(),
            "best_accuracy": self.best_accuracy,
        }
        filename = os.path.join(self.checkpoint_dir, "model_last.pth")
        self.logger.info("Saving last model: model_last.pth ...")
        torch.save(state, filename)
        if save_best:
            filename = os.path.join(self.checkpoint_dir, "model_best.pth")
            self.logger.info("Saving current best: model_best.pth ...")
            torch.save(state, filename)

    def _resume_checkpoint(self, resume_path):
        """Load model from checkpoint"""
        if not os.path.exists(resume_path):
            raise FileNotFoundError("Resume path does not exist!")
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=self.map_location)
        self.start_epoch = checkpoint["epoch"] + 1
        self.model.load_state_dict(checkpoint["state_dict"])
        self.center_loss.load_state_dict(checkpoint["center_loss"])
        self.optimizer.load_state_dict(checkpoint["optimizer"])
        self.optimizer_centerloss.load_state_dict(checkpoint["optimizer_centerloss"])
        self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
        self.best_accuracy = checkpoint["best_accuracy"]
        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch)
        )

    def _save_logs(self, epoch):
        """Save logs from google colab to google drive"""
        if os.path.isdir(self.logs_dir_saved):
            shutil.rmtree(self.logs_dir_saved)
        shutil.copytree(self.logs_dir, self.logs_dir_saved)
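
# --- Hedged example (not part of the original code) ------------------------------
# A minimal sketch of reloading the `model_best.pth` checkpoint written by
# `_save_checkpoint` above for offline evaluation. The helper name and the way
# `num_classes` is obtained are assumptions for illustration.
def load_best_model(checkpoint_dir, num_classes, device="cpu"):
    checkpoint = torch.load(os.path.join(checkpoint_dir, "model_best.pth"),
                            map_location=device)
    model = Baseline(num_classes=num_classes)
    model.load_state_dict(checkpoint["state_dict"])
    model.to(device)
    model.eval()
    return model
# ----------------------------------------------------------------------------------
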
# Example #8
class ClassificationTrainer(LayerwiseTrainer):
    def __init__(self,
                 model,
                 criterions,
                 metric_ftns,
                 optimizer,
                 config,
                 train_data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 weight_scheduler=None,
                 test_data_loader=None):
        super().__init__(model, criterions, metric_ftns, optimizer, config,
                         train_data_loader, valid_data_loader, lr_scheduler,
                         weight_scheduler)
        self.train_teacher_metrics = MetricTracker(
            *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            'supervised_loss',
            'kd_loss',
            'hint_loss',
            'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            *['teacher_' + m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.test_data_loader = test_data_loader

    def _train_epoch(self, epoch):
        self.prepare_train_epoch(epoch)

        self.model.train()
        self._clean_cache()

        for batch_idx, (data, target) in enumerate(self.train_data_loader):
            data, target = data.to(self.device), target.to(self.device)

            output_st, output_tc = self.model(data)

            supervised_loss = self.criterions[0](
                output_st, target) / self.accumulation_steps
            kd_loss = self.criterions[1](output_st,
                                         output_tc) / self.accumulation_steps

            # accumulate the hint loss over paired student/teacher hidden activations
            hint_loss = reduce(
                lambda acc, elem: acc + self.criterions[2](elem[0], elem[1]),
                zip(self.model.student_hidden_outputs,
                    self.model.teacher_hidden_outputs),
                torch.tensor(0)) / self.accumulation_steps
            teacher_loss = self.criterions[0](output_tc,
                                              target)  # for comparison

            # only the knowledge-distillation loss is backpropagated here
            loss = kd_loss
            loss.backward()

            if (batch_idx + 1) % self.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            # update metrics
            self.train_metrics.update('loss',
                                      loss.item() * self.accumulation_steps)
            self.train_metrics.update(
                'supervised_loss',
                supervised_loss.item() * self.accumulation_steps)
            self.train_metrics.update('kd_loss',
                                      kd_loss.item() * self.accumulation_steps)
            self.train_metrics.update(
                'hint_loss',
                hint_loss.item() * self.accumulation_steps)
            self.train_metrics.update('teacher_loss', teacher_loss.item())

            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output_st, target),
                                          data.shape[0])

            for met in self.metric_ftns:
                self.train_teacher_metrics.update(met.__name__,
                                                  met(output_tc, target),
                                                  data.shape[0])

            if batch_idx % self.log_step == 0:
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                self.logger.info(
                    'Train Epoch: {} [{}]/[{}] acc: {:.6f} teacher_acc: {:.6f} Loss: {:.6f} Supervised Loss: {:.6f} '
                    'Knowledge Distillation loss: {:.6f} Hint Loss: {:.6f} Teacher Loss: {:.6f}'
                    .format(
                        epoch,
                        batch_idx,
                        self.len_epoch,
                        self.train_metrics.avg('accuracy'),
                        self.train_teacher_metrics.avg('accuracy'),
                        self.train_metrics.avg('loss'),
                        self.train_metrics.avg('supervised_loss'),
                        self.train_metrics.avg('kd_loss'),
                        self.train_metrics.avg('hint_loss'),
                        self.train_metrics.avg('teacher_loss'),
                    ))

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()

        if self.do_validation and ((epoch % self.do_validation_interval) == 0):
            # clean cache to prevent out-of-memory with 1 gpu
            self._clean_cache()
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if (self.lr_scheduler is not None) and (not isinstance(
                self.lr_scheduler, MyOneCycleLR)):
            if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
                self.lr_scheduler.step(self.train_metrics.avg('loss'))
            else:
                self.lr_scheduler.step()

        self.weight_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output, output_tc = self.model(data)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(output,
                                                  target), data.shape[0])
                for met in self.metric_ftns:
                    self.valid_metrics.update('teacher_' + met.__name__,
                                              met(output_tc, target),
                                              data.shape[0])

        return self.valid_metrics.result()
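
# --- Hedged example (not part of the original code) ------------------------------
# A minimal sketch of the three-criterion list ClassificationTrainer above expects:
# criterions[0](logits, target) for the supervised loss, criterions[1](student_logits,
# teacher_logits) for knowledge distillation and criterions[2](student_hidden,
# teacher_hidden) for the hint loss. The KDLoss class and the temperature value are
# assumptions for illustration.
import torch.nn as nn
import torch.nn.functional as F

class KDLoss(nn.Module):
    def __init__(self, temperature=4.0):
        super().__init__()
        self.temperature = temperature

    def forward(self, student_logits, teacher_logits):
        # KL divergence between temperature-softened distributions, scaled by T^2
        # so gradient magnitudes stay comparable to the supervised loss
        t = self.temperature
        return F.kl_div(F.log_softmax(student_logits / t, dim=1),
                        F.softmax(teacher_logits / t, dim=1),
                        reduction='batchmean') * (t * t)

# Example wiring (hypothetical): criterions = [nn.CrossEntropyLoss(), KDLoss(), nn.MSELoss()]
# ----------------------------------------------------------------------------------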