Example #1
    def test(self, test_loader):
        self.model.eval()
        test_metrics = MetricTracker(
            'test_loss',
            *['test_' + m.__name__ for m in self.metric_ftns],
            writer=self.writer)

        with torch.no_grad():
            for batch_idx, batch in enumerate(test_loader):
                input_variables, input_lengths = getattr(batch, 'source')
                target = getattr(batch, 'target')
                output, _, sequence_info = self.model(
                    input=input_variables,
                    input_lens=input_lengths,
                    target=target)
                loss = self.criterion(output, target)

                # set writer step
                self.writer.set_step(
                    (self.epochs - 1) * len(test_loader) + batch_idx, 'test')

                # update test metrics
                test_metrics.update('test_loss', loss.item())
                for metric in self.metric_ftns:
                    test_metrics.update('test_' + metric.__name__,
                                        metric(output, target, sequence_info))

        for name, p in self.model.named_parameters():
            self.writer.add_histogram('test_' + name, p, bins='auto')
        for key, value in test_metrics.result().items():
            self.logger.info('    {:15s}: {}'.format(str(key), value))
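None of the examples include the MetricTracker class itself. As a point of reference, below is a minimal sketch that is consistent with the calls used throughout this page (a constructor taking metric names plus an optional writer, update(key, value, n=1), result() and reset()). It is an assumption reconstructed from the usage, not any of these projects' actual class, and it omits the append_histogram and write_to_logger helpers that some later examples call.

import pandas as pd


class MetricTracker:
    """Keeps running averages for named metrics; optionally mirrors updates to a writer."""

    def __init__(self, *keys, writer=None):
        self.writer = writer
        self._data = pd.DataFrame(index=keys, columns=['total', 'counts', 'average'])
        self.reset()

    def reset(self):
        for col in self._data.columns:
            self._data[col].values[:] = 0

    def update(self, key, value, n=1):
        if self.writer is not None:
            self.writer.add_scalar(key, value)
        self._data.loc[key, 'total'] += value * n
        self._data.loc[key, 'counts'] += n
        self._data.loc[key, 'average'] = self._data.loc[key, 'total'] / self._data.loc[key, 'counts']

    def avg(self, key):
        return self._data.loc[key, 'average']

    def result(self):
        return dict(self._data['average'])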
Example #2
    def forward(self, items: collections.abc.Mapping, tracker: MetricTracker,
                storage: SharedStorage):
        losses = []
        for loss_cls in self.losses:
            loss = loss_cls(items, tracker, storage)
            losses.append(loss)
            if self.items_len_key:
                n = len(items[self.items_len_key])
            else:
                n = 1
            tracker.update(loss_cls.__name__, loss.item(), n=n)

        if self.weights:
            return sum(loss * weight
                       for loss, weight in zip(losses, self.weights))
        else:
            return sum(losses)
Example #3
    def forward(self, items: collections.abc.Mapping, tracker: MetricTracker,
                storage: SharedStorage):
        features, target = items[self.output_key], items[self.target_key]
        n = len(features)

        batch_size = features.size(0)
        same_target = (target.eq(target.view(batch_size, 1)))
        norms = features.square().sum(dim=1, keepdim=True).expand(
            batch_size, batch_size)
        distmat = norms + norms.t()  # a^2 + b^2
        distmat.addmm_(beta=1, alpha=-2, mat1=features,
                       mat2=features.T)  # a^2 + b^2 - 2ab
        distmat = distmat.clamp(min=1e-12).sqrt()  # euclid

        pos_dists = distmat[same_target & ~torch.diagflat(
            torch.ones(n, dtype=torch.bool, device=distmat.device))]
        neg_dists = distmat[~same_target]
        hard_pos_dists, _ = distmat[same_target].view(batch_size,
                                                      -1).max(dim=1)
        hard_neg_dists, _ = neg_dists.view(batch_size, -1).min(dim=1)
        if self.track_distances:
            pos_mean = hard_pos_dists.mean()
            neg_mean = hard_neg_dists.mean()
            tracker.append_histogram("batch_hard_dist_ap", hard_pos_dists)
            tracker.append_histogram("batch_hard_dist_an", hard_neg_dists)
            tracker.append_histogram("batch_dist_pos", pos_dists)
            tracker.append_histogram("batch_dist_neg", neg_dists)
            tracker.append_histogram("batch_hard_delta",
                                     hard_pos_dists - hard_neg_dists)
            tracker.update("batch_hard_dist_ap_mean", pos_mean, n=n)
            tracker.update("batch_hard_dist_an_mean", neg_mean, n=n)
            tracker.update("batch_hard_dist_ap_mean",
                           pos_dists.mean().item(),
                           n=n)
            tracker.update("batch_hard_dist_an_mean",
                           neg_dists.mean().item(),
                           n=n)
            tracker.update("batch_hard_delta_mean", pos_mean - neg_mean, n=n)

        if self.margin:
            loss = self.ranking_loss(
                hard_neg_dists, hard_pos_dists,
                torch.ones(batch_size,
                           device=features.device,
                           dtype=features.dtype))
        else:
            loss = self.ranking_loss(
                hard_neg_dists - hard_pos_dists,
                torch.ones(batch_size,
                           device=features.device,
                           dtype=features.dtype))
        return loss
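The distance computation above expands ||a − b||² = ||a||² + ||b||² − 2·a·b in place via addmm_. The following self-contained snippet, which is not part of the original class and assumes only torch, makes that algebra explicit and checks it against torch.cdist:

import torch

def pairwise_euclidean(features, eps=1e-12):
    # ||a - b||^2 = ||a||^2 + ||b||^2 - 2 a.b, clamped before the sqrt for stability
    n = features.size(0)
    norms = features.square().sum(dim=1, keepdim=True).expand(n, n)
    distmat = norms + norms.t()
    distmat = distmat.addmm(features, features.T, beta=1, alpha=-2)
    return distmat.clamp(min=eps).sqrt()

x = torch.randn(8, 16)
assert torch.allclose(pairwise_euclidean(x), torch.cdist(x, x), atol=1e-3)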
Example #4
class Tester(BaseTester):
    """
    Tester class
    """
    def __init__(self, model, criterion, metric_ftns, plot_ftns, config, data_loader):
        super().__init__(model, criterion, metric_ftns, plot_ftns, config)
        self.config = config
        self.data_loader = data_loader

        self.test_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns])

    def _test(self):
        """
        Test logic

        :return: A log that contains information about testing
        """
        self.model.eval()
        self.test_metrics.reset()
        with torch.no_grad():
            outputs = []
            targets = []
            for batch_idx, (data, target) in enumerate(tqdm(self.data_loader)):
                data, target = data.to(self.device, non_blocking=self.non_blocking), target.to(self.device, non_blocking=self.non_blocking)

                output = self.model(data)
                loss = self.criterion(output, target)

                outputs.append(output)
                targets.append(target)

                self.test_metrics.update('loss', loss.item())

            outputs = torch.cat(outputs)
            targets = torch.cat(targets)
            for met in self.metric_ftns:
                self.test_metrics.update(met.__name__, met(outputs, targets))
            for plt in self.plot_ftns:
                image_path = self.config.log_dir / (plt.__name__ + '.png')
                torchvision.utils.save_image(plt(outputs, targets).float(), image_path, normalize=True)

        return self.test_metrics.result()
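The tester above saves the output of each function in plot_ftns as an image via torchvision.utils.save_image, so each plot_ftn must map (outputs, targets) to an image-like tensor. A hypothetical example of such a function (name, signature and implementation are assumptions for illustration; outputs are taken to be class logits and targets class indices) could look like:

import torch

def confusion_matrix(outputs, targets, num_classes=None):
    # Hypothetical plot_ftn: returns the confusion matrix as a (1, C, C) tensor so it
    # can be passed directly to torchvision.utils.save_image (normalize=True rescales it).
    preds = outputs.argmax(dim=1)
    if num_classes is None:
        num_classes = int(torch.max(preds.max(), targets.max()).item()) + 1
    cm = torch.zeros(num_classes, num_classes)
    for t, p in zip(targets.view(-1), preds.view(-1)):
        cm[int(t), int(p)] += 1
    return cm.unsqueeze(0)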
Example #5
    def on_epoch_end(self, tracker: MetricTracker, storage: SharedStorage):
        if self.train:
            qf = gf = storage.get_data("features")
            qpids = gpids = storage.get_data("pids")
            qcamids = gcamids = storage.get_data("camids")
            prefix = ""
        else:
            qf = storage.get_data("qf")
            gf = storage.get_data("gf")
            qpids = storage.get_data("qpids")
            gpids = storage.get_data("gpids")
            qcamids = storage.get_data("qcamids")
            gcamids = storage.get_data("gcamids")
            prefix = "valid_"

        distmat = storage.get_data("distmat")
        if distmat is None:
            distmat = compute_distances(qf, gf)
            storage.set_data("distmat", distmat)
        same_pid = gpids.eq(qpids.reshape(-1, 1))
        same_cam = gcamids.eq(qcamids.reshape(-1, 1))
        negative: torch.Tensor = ~same_pid
        positive: torch.Tensor = same_pid
        positive_same_cam = torch.logical_and(same_pid, same_cam)
        positive_diff_cam = torch.logical_and(same_pid, ~same_cam)

        if self.train:
            # filter out identical instances from positive distances
            same_image = torch.diagflat(torch.ones(qf.size(0), dtype=torch.bool, device=qf.device))
            positive.logical_and_(~same_image)
            positive_same_cam.logical_and_(~same_image)

        tracker.update(prefix + "global_dist_pos_same_cam_mean", distmat[positive_same_cam].mean().item())
        tracker.update(prefix + "global_dist_pos_diff_cam_mean", distmat[positive_diff_cam].mean().item())
        tracker.update(prefix + "global_dist_pos_mean", distmat[positive].mean().item())
        tracker.update(prefix + "global_dist_neg_mean", distmat[negative].mean().item())
        tracker.append_histogram(prefix + "global_dist_pos_same_cam", distmat[positive_same_cam])
        tracker.append_histogram(prefix + "global_dist_pos_diff_cam", distmat[positive_diff_cam])
        tracker.append_histogram(prefix + "global_dist_pos", distmat[positive])
        tracker.append_histogram(prefix + "global_dist_neg", distmat[negative])
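compute_distances is not defined in this excerpt. Under the assumption that it returns a query-by-gallery Euclidean distance matrix (which is what the masking logic above requires), a minimal sketch is:

import torch

def compute_distances(qf: torch.Tensor, gf: torch.Tensor) -> torch.Tensor:
    # qf: (num_query, dim), gf: (num_gallery, dim) -> (num_query, num_gallery) distances
    return torch.cdist(qf, gf, p=2)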
Example #6
class Eval:
    def __init__(self, models, criterion, metrics, device):
        self.criterion = criterion
        self.models = models
        self.device = device
        self.metrics = metrics
        self.valid_metrics = MetricTracker('loss',
                                           *[m.__name__ for m in self.metrics],
                                           writer=None)
        self.logger = logging.getLogger()

    def eval(self, valid_data_loader):
        for model in self.models:
            model.eval()
        self.valid_metrics.reset()
        outputs = []
        targets = []
        with torch.no_grad():
            tk = tqdm(enumerate(valid_data_loader),
                      total=len(valid_data_loader))
            for batch_idx, (data, target) in tk:
                data, target = data.to(self.device), target.to(self.device)
                for model in self.models:
                    output = model(data)
                    output2 = model(data.flip(-1))
                    loss = self.criterion(output, target)

                    outputs.append(
                        (output.sigmoid().detach().cpu().numpy() +
                         output2.sigmoid().detach().cpu().numpy()) / 2)
                    targets.append(target.cpu().numpy())

                self.valid_metrics.update('loss', loss.item())
                tk.set_description("loss: %.6f" % loss.item())
        outputs = np.concatenate(outputs)
        targets = np.concatenate(targets)
        for met in self.metrics:
            self.valid_metrics.update(met.__name__, met(outputs, targets))
        self.logger.info(self.valid_metrics.result())
        return self.valid_metrics.result()
Example #7
    def on_epoch_end(self, tracker: MetricTracker, storage: SharedStorage):
        qpids = storage.get_data("qpids")
        gpids = storage.get_data("gpids")
        qcamids = storage.get_data("qcamids")
        gcamids = storage.get_data("gcamids")

        distmat = storage.get_data("distmat")
        if distmat is None:
            qf = storage.get_data("qf")
            gf = storage.get_data("gf")
            distmat = compute_distances(qf, gf)
            storage.set_data("distmat", distmat)

        all_cmc, all_AP, all_INP = evaluate(distmat, qpids, gpids, qcamids, gcamids)
        r1 = all_cmc[0].item()
        mAP = all_AP.mean().item()
        mINP = all_INP.mean().item()

        tracker.update("r1", r1)
        tracker.update("mAP", mAP)
        tracker.update("mINP", mINP)
Example #8
class Trainer(BaseTrainer):
	"""
	Trainer class
	"""
	def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
				 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
		super().__init__(model, criterion, metric_ftns, optimizer, config)
		self.config = config
		self.data_loader = data_loader
		if len_epoch is None:
			# epoch-based training
			self.len_epoch = len(self.data_loader)
		else:
			# iteration-based training
			self.data_loader = inf_loop(data_loader)
			self.len_epoch = len_epoch
		self.valid_data_loader = valid_data_loader
		self.do_validation = self.valid_data_loader is not None
		self.lr_scheduler = lr_scheduler
		self.log_step = int(np.sqrt(data_loader.batch_size))

		self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
		self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

		if hasattr(self.data_loader, 'n_valid_samples'):
			validation_samples=self.data_loader.n_valid_samples
		else:
			validation_samples=self.valid_data_loader.n_samples
		self.heatmap_sample_indices=np.sort(np.random.randint(validation_samples, size=min(16, validation_samples)))

	def _train_epoch(self, epoch):
		"""
		Training logic for an epoch

		:param epoch: Integer, current training epoch.
		:return: A log that contains average loss and metric in this epoch.
		"""
		self.model.train()
		self.train_metrics.reset()
		for batch_idx, (data, target) in enumerate(self.data_loader):
			data, target = data.to(self.device), target.to(self.device)

			self.optimizer.zero_grad()
			output = self.model(data)
			loss = self.criterion(output, target)
			loss.backward()
			self.optimizer.step()

			self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
			self.train_metrics.update('loss', loss.item())
			for met in self.metric_ftns:
				self.train_metrics.update(met.__name__, met(output, target))

			if batch_idx % self.log_step == 0:
				self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
					epoch,
					self._progress(batch_idx),
					loss.item()))
				# self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

			if batch_idx == self.len_epoch:
				break
		log = self.train_metrics.result()

		if self.do_validation:
			val_log = self._valid_epoch(epoch)
			log.update(**{'val_'+k : v for k, v in val_log.items()})

		if self.lr_scheduler is not None:
			self.lr_scheduler.step()
		return log

	def _valid_epoch(self, epoch):
		"""
		Validate after training an epoch

		:param epoch: Integer, current training epoch.
		:return: A log that contains information about validation
		"""
		self.model.eval()
		self.valid_metrics.reset()
		y=[]
		y_hat=[]
		with torch.no_grad():
			for batch_idx, (data, target) in enumerate(self.valid_data_loader):
				data, target = data.to(self.device), target.to(self.device)

				output = self.model(data)
				loss = self.criterion(output, target)

				self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
				self.valid_metrics.update('loss', loss.item())
				for met in self.metric_ftns:
					self.valid_metrics.update(met.__name__, met(output, target))
				# self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

				y.append(target.cpu().numpy())
				y_hat.append(output.detach().cpu().numpy())
		y=np.concatenate(y)
		y_hat=np.concatenate(y_hat)
		self._do_validation_visualizations(epoch, y, y_hat)

		# # add histogram of model parameters to the tensorboard
		# for name, p in self.model.named_parameters():
		#     self.writer.add_histogram(name, p, bins='auto')

		return self.valid_metrics.result()

	def _do_validation_visualizations(self, epoch, y, y_hat):
		# show multi class AUC-ROC
		roc=get_auc_roc_curve(y, y_hat, len(self.data_loader.dataset.classes), labels=self.data_loader.dataset.classes)
		self.writer.add_figure("metric/roc", roc)
		# show confusion_matrix
		cm=get_confusion_matrix_figure(y, y_hat, len(self.data_loader.dataset.classes), labels=self.data_loader.dataset.classes)
		self.writer.add_figure("metric/confusion_matrix", cm)
		self._do_heatmaps(epoch)
		return

	def _do_heatmaps(self, epoch):
		images=[]
		targets=[]
		for idx in self.heatmap_sample_indices:
			img,label = self.valid_data_loader.dataset.__getitem__(idx)
			images.append(img)
			targets.append(label)
		gradcam, gradcam_pp = get_heatmap_tensors(images, self.model, self.config,
												  self.checkpoint_dir, epoch,
												  save_images_to_dir=True)
		self.writer.add_images("saliency/gradcam", gradcam)
		self.writer.add_images("saliency/gradcam_pp", gradcam_pp)

		return

	def _progress(self, batch_idx):
		base = '[{}/{} ({:.0f}%)]'
		if hasattr(self.data_loader, 'n_samples'):
			current = batch_idx * self.data_loader.batch_size
			total = self.data_loader.n_samples
		else:
			current = batch_idx
			total = self.len_epoch
		return base.format(current, total, 100.0 * current / total)
Example #9
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 test_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_inference = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.test_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, data in enumerate(self.data_loader):
            self.optimizer.zero_grad()
            input_ids, attention_masks, text_lengths, labels = data

            if 'cuda' == self.device.type:
                input_ids = input_ids.cuda()
                if attention_masks is not None:
                    attention_masks = attention_masks.cuda()
                text_lengths = text_lengths.cuda()
                labels = labels.cuda()
            preds, embedding = self.model(input_ids, attention_masks,
                                          text_lengths)
            preds = preds.squeeze()
            loss = self.criterion[0](preds, labels)
            loss.backward()
            self.optimizer.step()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())

            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(preds, labels))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.3f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.do_inference:
            test_log = self._inference_epoch(epoch)
            log.update(**{'test_' + k: v for k, v in test_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, data in enumerate(self.valid_data_loader):
                input_ids, attention_masks, text_lengths, labels = data

                if 'cuda' == self.device.type:
                    input_ids = input_ids.cuda()
                    if attention_masks is not None:
                        attention_masks = attention_masks.cuda()
                    text_lengths = text_lengths.cuda()
                    labels = labels.cuda()
                preds, embedding = self.model(input_ids, attention_masks,
                                              text_lengths)
                preds = preds.squeeze()
                if self.add_graph:
                    input_model = self.model.module if (len(
                        self.config.config['device_id']) > 1) else self.model
                    self.writer.writer.add_graph(
                        input_model,
                        [input_ids, attention_masks, text_lengths])
                    self.add_graph = False
                loss = self.criterion[0](preds, labels)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(preds, labels))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _inference_epoch(self, epoch):
        """
        Inference after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.test_metrics.reset()
        with torch.no_grad():
            for batch_idx, data in enumerate(self.test_data_loader):
                input_ids, attention_masks, text_lengths, labels = data

                if 'cuda' == self.device.type:
                    input_ids = input_ids.cuda()
                    if attention_masks is not None:
                        attention_masks = attention_masks.cuda()
                    text_lengths = text_lengths.cuda()
                    labels = labels.cuda()
                preds, embedding = self.model(input_ids, attention_masks,
                                              text_lengths)
                preds = preds.squeeze()
                loss = self.criterion[0](preds, labels)

                self.writer.set_step(
                    (epoch - 1) * len(self.test_data_loader) + batch_idx,
                    'test')
                self.test_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.test_metrics.update(met.__name__, met(preds, labels))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.test_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
Example #10
class Trainer(BaseTrainer):
    """
    Class implementation for trainers.
    The class is inherited from the class BaseTrainer.
    """
    def __init__(self,
                 model,
                 criterion,
                 metricFunction,
                 optimizer,
                 configuration,
                 device,
                 dataLoader,
                 validationDataLoader=None,
                 learningRateScheduler=None,
                 epochLength=None):
        """
        Method to initialize an object of type Trainer.

        Parameters
        ----------
        self                    : Trainer
                                  Instance of the class
        model                   : torch.nn.Module
                                  Model to be trained
        criterion               : callable
                                  Criterion to be evaluated (This is usually the loss function to be minimized)
        metricFunction          : callable
                                  Metric functions to evaluate model performance
        optimizer               : torch.optim
                                  Optimizer to be used during training
        device                  : torch.device
                                  Device on which the training would be performed
        dataLoader              : torch.utils.data.DataLoader
                                  Dataset sampler to load training data for model training
        validationDataLoader    : torch.utils.data.DataLoader
                                  Dataset sampler to load validation data for model validation (Default value: None)
        learningRateScheduler   : torch.optim.lr_scheduler
                                  Method to adjust learning rate (Default value: None)
        epochLength             : int
                                  Number of batches per epoch for iteration-based training (Default value: None)

        Returns
        -------
        self    : Trainer
                  Initialized object of class Trainer
        """
        # Initialize BaseTrainer class
        super().__init__(model, criterion, metricFunction, optimizer,
                         configuration)

        # Save trainer configuration, device, dataLoaders, learningRateScheduler and loggingStep
        self.configuration = configuration
        self.device = device
        self.dataLoader = dataLoader
        if epochLength is None:
            self.epochLength = len(self.dataLoader)
        else:
            self.dataLoader = infinte_loop(dataLoader)
            self.epochLength = epochLength
        self.validationDataLoader = validationDataLoader
        self.performValidation = (self.validationDataLoader is not None)
        self.learningRateScheduler = learningRateScheduler
        self.loggingStep = int(np.sqrt(dataLoader.batch_size))

        # Set up training and validation metrics
        self.trainingMetrics = MetricTracker(
            "loss",
            *[
                individualMetricFunction.__name__
                for individualMetricFunction in self.metricFunction
            ],
            writer=self.writer)
        self.validationMetrics = MetricTracker(
            "loss",
            *[
                individualMetricFunction.__name__
                for individualMetricFunction in self.metricFunction
            ],
            writer=self.writer)

    def train_epoch(self, epoch):
        """
        Method to train a single epoch.

        Parameters
        ----------
        self    : Trainer
                  Instance of the class
        epoch   : int
                  Current epoch number

        Returns
        -------
        log     : dict
                  Average of all the metrics in a dictionary
        """
        # Set the model to training mode and start training the model
        self.model.train()
        self.trainingMetrics.reset()
        for batchId, (data, target) in enumerate(self.dataLoader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # Update training metrics
            self.writer.set_step((epoch - 1) * self.epochLength + batchId)
            self.trainingMetrics.update("loss", loss.item())
            for individualMetric in self.metricFunction:
                self.trainingMetrics.update(individualMetric.__name__,
                                            individualMetric(output, target))

            if batchId % self.loggingStep == 0:
                self.logger.debug("Training Epoch: {} {} Loss: {}".format(
                    epoch, self.progress(batchId), loss.item()))
                self.writer.add_image(
                    "input", make_grid(data.cpu(), nrow=8, normalize=True))

            if batchId == self.epochLength:
                break

        log = self.trainingMetrics.result()

        if self.performValidation:
            validationLog = self.validate_epoch(epoch)
            log.update(**{"val_" + key: value
                          for key, value in validationLog.items()})

        if self.learningRateScheduler is not None:
            self.learningRateScheduler.step()

        return log

    def validate_epoch(self, epoch):
        """
        Method to validate a single epoch.

        Parameters
        ----------
        self    : Trainer
                  Instance of the class
        epoch   : int
                  Current epoch number

        Returns
        -------
        log     : dict
                  Average of all the metrics in a dictionary
        """
        # Set the model to evaluation mode and start validating the model
        self.model.eval()
        self.validationMetrics.reset()

        with torch.no_grad():
            for batchId, (data,
                          target) in enumerate(self.validationDataLoader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                # Update training metrics
                self.writer.set_step(
                    (epoch - 1) * len(self.validationDataLoader) + batchId,
                    "valid")
                self.validationMetrics.update("loss", loss.item())
                for individualMetric in self.metricFunction:
                    self.validationMetrics.update(
                        individualMetric.__name__,
                        individualMetric(output, target))
                self.writer.add_image(
                    "input", make_grid(data.cpu(), nrow=8, normalize=True))

        # Update TensorBoardWriter
        for name, parameter in self.model.named_parameters():
            self.writer.add_histogram(name, parameter, bins="auto")

        return self.validationMetrics.result()

    def progress(self, batchId):
        """
        Method to calculate progress of training or validation.

        Parameters
        ----------
        self    : Trainer
                  Instance of the class
        batchId : int
                  Current batch ID

        Returns
        -------
        progress    : str
                      Amount of progress
        """
        base = "[{}/{} ({:.0f}%)]"

        if hasattr(self.dataLoader, "numberOfSamples"):
            current = batchId * self.dataLoader.batch_size
            total = self.dataLoader.numberOfSamples
        else:
            current = batchId
            total = self.epochLength

        return base.format(current, total, 100.0 * current / total)
Example #11
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 test_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 overfit_single_batch=False):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader if not overfit_single_batch else None
        self.test_data_loader = test_data_loader if not overfit_single_batch else None
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.overfit_single_batch = overfit_single_batch

        # -------------------------------------------------
        # add flexibility to allow no metric in config.json
        self.log_loss = ['loss', 'nll', 'kl']
        if self.metric_ftns is None:
            self.train_metrics = MetricTracker(*self.log_loss,
                                               writer=self.writer)
            self.valid_metrics = MetricTracker(*self.log_loss,
                                               writer=self.writer)
            # also create a (metric-less) test tracker so _test_epoch does not fail
            self.test_metrics = MetricTracker(writer=self.writer)
        # -------------------------------------------------
        else:
            self.train_metrics = MetricTracker(
                *self.log_loss,
                *[m.__name__ for m in self.metric_ftns],
                writer=self.writer)
            self.valid_metrics = MetricTracker(
                *self.log_loss,
                *[m.__name__ for m in self.metric_ftns],
                writer=self.writer)
            self.test_metrics = MetricTracker(
                *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()

        # ----------------
        # add logging grad
        dict_grad = {}
        for name, p in self.model.named_parameters():
            if p.requires_grad and 'bias' not in name:
                dict_grad[name] = np.zeros(self.len_epoch)
        # ----------------

        for batch_idx, batch in enumerate(self.data_loader):
            x, x_reversed, x_mask, x_seq_lengths = batch

            x = x.to(self.device)
            x_reversed = x_reversed.to(self.device)
            x_mask = x_mask.to(self.device)
            x_seq_lengths = x_seq_lengths.to(self.device)

            self.optimizer.zero_grad()
            x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq = \
                self.model(x, x_reversed, x_seq_lengths)
            kl_annealing_factor = \
                determine_annealing_factor(self.config['trainer']['min_anneal_factor'],
                                           self.config['trainer']['anneal_update'],
                                           epoch - 1, self.len_epoch, batch_idx)
            kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, loss = \
                self.criterion(x, x_recon, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq, kl_annealing_factor, x_mask)
            loss.backward()

            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), 10)
            # ------------
            # accumulate gradients that are to be logged later after epoch ends
            for name, p in self.model.named_parameters():
                if p.requires_grad and 'bias' not in name:
                    val = 0 if p.grad is None else p.grad.abs().mean().item()
                    dict_grad[name][batch_idx] = val
            # ------------

            self.optimizer.step()

            for l_i, l_i_val in zip(self.log_loss, [loss, nll_m, kl_m]):
                self.train_metrics.update(l_i, l_i_val.item())
            if self.metric_ftns is not None:
                for met in self.metric_ftns:
                    if met.__name__ == 'bound_eval':
                        self.train_metrics.update(
                            met.__name__,
                            met([x_recon, mu_q_seq, logvar_q_seq],
                                [x, mu_p_seq, logvar_p_seq],
                                mask=x_mask))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            if batch_idx == self.len_epoch or self.overfit_single_batch:
                break

        # ---------------------------------------------------
        if self.writer is not None:
            self.writer.set_step(epoch, 'train')
            # log losses
            for l_i in self.log_loss:
                self.train_metrics.write_to_logger(l_i)
            # log metrics
            if self.metric_ftns is not None:
                for met in self.metric_ftns:
                    if met.__name__ == 'bound_eval':
                        self.train_metrics.write_to_logger(met.__name__)
            # log gradients
            for name, p in dict_grad.items():
                self.writer.add_histogram(name + '/grad', p, bins='auto')
            # log parameters
            for name, p in self.model.named_parameters():
                self.writer.add_histogram(name, p, bins='auto')
            # log kl annealing factors
            self.writer.add_scalar('anneal_factor', kl_annealing_factor)
        # ---------------------------------------------------

        if epoch % 50 == 0:
            fig = create_reconstruction_figure(x, torch.sigmoid(x_recon))
            # debug_fig = create_debug_figure(x, x_reversed, x_mask)
            # debug_fig_loss = create_debug_loss_figure(kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, x_mask)
            self.writer.set_step(epoch, 'train')
            self.writer.add_figure('reconstruction', fig)
            # self.writer.add_figure('debug', debug_fig)
            # self.writer.add_figure('debug_loss', debug_fig_loss)

        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.do_test and epoch % 50 == 0:
            test_log = self._test_epoch(epoch)
            log.update(**{'test_' + k: v for k, v in test_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.valid_data_loader):
                x, x_reversed, x_mask, x_seq_lengths = batch

                x = x.to(self.device)
                x_reversed = x_reversed.to(self.device)
                x_mask = x_mask.to(self.device)
                x_seq_lengths = x_seq_lengths.to(self.device)

                x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq = \
                    self.model(x, x_reversed, x_seq_lengths)
                kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, loss = \
                    self.criterion(x, x_recon, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq, 1, x_mask)

                for l_i, l_i_val in zip(self.log_loss, [loss, nll_m, kl_m]):
                    self.valid_metrics.update(l_i, l_i_val.item())
                if self.metric_ftns is not None:
                    for met in self.metric_ftns:
                        if met.__name__ == 'bound_eval':
                            self.valid_metrics.update(
                                met.__name__,
                                met([x_recon, mu_q_seq, logvar_q_seq],
                                    [x, mu_p_seq, logvar_p_seq],
                                    mask=x_mask))

        # ---------------------------------------------------
        if self.writer is not None:
            self.writer.set_step(epoch, 'valid')
            for l_i in self.log_loss:
                self.valid_metrics.write_to_logger(l_i)
            if self.metric_ftns is not None:
                for met in self.metric_ftns:
                    if met.__name__ == 'bound_eval':
                        self.valid_metrics.write_to_logger(met.__name__)
        # ---------------------------------------------------

        if epoch % 10 == 0:
            x_recon = torch.sigmoid(x_recon.view(x.size(0), x.size(1), -1))
            fig = create_reconstruction_figure(x, x_recon)
            # debug_fig = create_debug_figure(x, x_reversed_unpack, x_mask)
            # debug_fig_loss = create_debug_loss_figure(kl_raw, nll_raw, kl_fr, nll_fr, kl_m, nll_m, x_mask)
            self.writer.set_step(epoch, 'valid')
            self.writer.add_figure('reconstruction', fig)
            # self.writer.add_figure('debug', debug_fig)
            # self.writer.add_figure('debug_loss', debug_fig_loss)

        return self.valid_metrics.result()

    def _test_epoch(self, epoch):
        self.model.eval()
        self.test_metrics.reset()
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.test_data_loader):
                x, x_reversed, x_mask, x_seq_lengths = batch

                x = x.to(self.device)
                x_reversed = x_reversed.to(self.device)
                x_mask = x_mask.to(self.device)
                x_seq_lengths = x_seq_lengths.to(self.device)

                x_recon, z_q_seq, z_p_seq, mu_q_seq, logvar_q_seq, mu_p_seq, logvar_p_seq = \
                    self.model(x, x_reversed, x_seq_lengths)

                if self.metric_ftns is not None:
                    for met in self.metric_ftns:
                        if met.__name__ == 'bound_eval':
                            self.test_metrics.update(
                                met.__name__,
                                met([x_recon, mu_q_seq, logvar_q_seq],
                                    [x, mu_p_seq, logvar_p_seq],
                                    mask=x_mask))
                        if met.__name__ == 'importance_sample':
                            self.test_metrics.update(
                                met.__name__,
                                met(batch_idx,
                                    self.model,
                                    x,
                                    x_reversed,
                                    x_seq_lengths,
                                    x_mask,
                                    n_sample=500))
        # ---------------------------------------------------
        if self.writer is not None:
            self.writer.set_step(epoch, 'test')
            if self.metric_ftns is not None:
                for met in self.metric_ftns:
                    self.test_metrics.write_to_logger(met.__name__)

            n_sample = 3
            output_seq, z_p_seq, mu_p_seq, logvar_p_seq = self.model.generate(
                n_sample, 100)
            output_seq = torch.sigmoid(output_seq)
            plt.close()
            fig, ax = plt.subplots(n_sample, 1, figsize=(10, n_sample * 10))
            for i in range(n_sample):
                ax[i].imshow(output_seq[i].T.cpu().detach().numpy(),
                             origin='lower')
            self.writer.add_figure('generation', fig)
        # ---------------------------------------------------
        return self.test_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
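determine_annealing_factor is not included in this excerpt. A common linear warm-up implementation consistent with the call signature used above (min factor, number of annealing updates, zero-based epoch, batches per epoch, batch index) would be the following sketch; the project's actual schedule may differ.

def determine_annealing_factor(min_anneal_factor, anneal_update, epoch, n_batch, batch_idx):
    # linearly ramp the KL weight from min_anneal_factor to 1 over anneal_update updates
    n_updates = epoch * n_batch + batch_idx
    if anneal_update > 0 and n_updates < anneal_update:
        return min_anneal_factor + (1.0 - min_anneal_factor) * n_updates / anneal_update
    return 1.0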
Example #12
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 loss_fn_class,
                 loss_fn_domain,
                 metric_ftns,
                 optimizer,
                 config,
                 device,
                 data_loader_source,
                 valid_data_loader_source=None,
                 data_loader_target=None,
                 valid_data_loader_target=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.loss_fn_class = loss_fn_class
        self.loss_fn_domain = loss_fn_domain
        self.data_loader_source = data_loader_source
        self.valid_data_loader_source = valid_data_loader_source
        self.data_loader_target = data_loader_target
        self.valid_data_loader_target = valid_data_loader_target
        self.model.to(self.device)

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = min(len(self.data_loader_source),
                                 len(self.data_loader_target))
        else:
            # FIXME: implement source/target style training or remove this feature
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        # FIXME: handle validation round
        self.valid_data_loader = valid_data_loader_source
        self.do_validation = self.valid_data_loader is not None

        self.lr_scheduler = lr_scheduler
        self.log_step = 64

        self.train_metrics = MetricTracker(
            'loss',
            'class_loss',
            'domain_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            'class_loss',
            'domain_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        # Setting model into train mode, required_grad
        self.model.train()
        # Reset all metric in metric dataframe
        self.train_metrics.reset()
        batch_idx = 0
        for source, target in zip(self.data_loader_source,
                                  self.data_loader_target):

            # source, target = source.to(self.device), target.to(self.device)

            # Calculate training progress and GRL λ
            p = float(batch_idx + (epoch-1) * self.len_epoch) / \
                (self.epochs * self.len_epoch)
            λ = 2. / (1. + np.exp(-10 * p)) - 1

            # === Train on source domain
            X_source, y_source = source
            X_source, y_source = X_source.to(self.device), y_source.to(
                self.device)

            # generate source domain labels: 0
            y_s_domain = torch.zeros(X_source.shape[0], dtype=torch.float32)
            y_s_domain = y_s_domain.to(self.device)

            class_pred_source, domain_pred_source = self.model(X_source, λ)
            # source classification loss
            loss_s_label = self.loss_fn_class(class_pred_source.squeeze(),
                                              y_source)

            # Compress from tensor size batch*1*1*1 => batch
            domain_pred_source = torch.squeeze(domain_pred_source)
            loss_s_domain = self.loss_fn_domain(
                domain_pred_source, y_s_domain)  # source domain loss (via GRL)

            # === Train on target domain
            X_target, _ = target
            # generate target domain labels: 1
            y_t_domain = torch.ones(X_target.shape[0], dtype=torch.float32)
            X_target = X_target.to(self.device)
            y_t_domain = y_t_domain.to(self.device)
            _, domain_pred_target = self.model(X_target, λ)

            domain_pred_target = torch.squeeze(domain_pred_target)
            loss_t_domain = self.loss_fn_domain(
                domain_pred_target, y_t_domain)  # target domain loss (via GRL)

            # === Optimizer ====

            self.optimizer.zero_grad()
            loss_s_label = torch.log(loss_s_label + 1e-9)
            loss = loss_t_domain + loss_s_domain + loss_s_label

            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            self.train_metrics.update('class_loss', loss_s_label.item())
            self.train_metrics.update('domain_loss', loss_s_domain.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__,
                                          met(class_pred_source, y_source))

            if batch_idx % self.log_step == 0:
                self.logger.debug(
                    f'Train Epoch: {epoch} {self._progress(batch_idx)} Loss: {loss.item():.4f} Source class loss: {loss_s_label.item():3f} Source domain loss {loss_s_domain.item():3f}'
                )
                self.writer.add_image(
                    'input', make_grid(X_source.cpu(), nrow=4, normalize=True))

            batch_idx += 1
            if batch_idx == self.len_epoch:
                break
        # Average the accumulated result to log the result
        log = self.train_metrics.result()
        # update lambda value to metric tracker
        log["lambda"] = λ
        # Run validation after each epoch if validation dataloader is available.
        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        # Set model to evaluation mode, required_grad = False
        # disables dropout and has batch norm use the entire population statistics
        self.model.eval()
        # Reset validation metrics in dataframe for a new validation round
        self.valid_metrics.reset()

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                # lambda is fixed to 1 during evaluation
                output, _ = self.model(data, 1)
                loss = self.loss_fn_class(output.squeeze(), target)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(output, target))
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=4, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader_source, 'n_samples'):
            current = batch_idx * self.data_loader_source.batch_size
            total = self.data_loader_source.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
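The call self.model(X_source, λ) implies a gradient reversal layer (GRL) parameterized by λ, as is standard for DANN-style training; the model itself is not part of this excerpt. A minimal sketch, assuming the usual formulation (identity in the forward pass, gradient scaled by −λ in the backward pass), is:

import torch

class GradientReversal(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, lambda_):
        ctx.lambda_ = lambda_
        return x.view_as(x)          # identity in the forward pass

    @staticmethod
    def backward(ctx, grad_output):
        # reverse and scale the gradient flowing back into the feature extractor
        return -ctx.lambda_ * grad_output, None

def grad_reverse(x, lambda_=1.0):
    return GradientReversal.apply(x, lambda_)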
Example #13
class MAMOTrainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 trainable_params,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.losses_num = len(self.criterion)
        self.max_empirical_losses = self._compute_max_empirical_losses()
        copsolver = AnalyticalSolver()
        self.common_descent_vector = MultiObjectiveCDV(
            copsolver=copsolver,
            max_empirical_losses=self.max_empirical_losses,
            normalized=True)
        self.trainable_params = trainable_params
        self.opt_losses = self.config['opt_losses']

        self.train_metrics = MetricTracker(
            'loss',
            'weighted_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            'weighted_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _compute_max_empirical_losses(self):
        max_losses = [0] * self.losses_num
        cnt = 0

        # no gradients are needed to estimate the empirical loss maxima
        with torch.no_grad():
            for batch_idx, (data, target, price) in enumerate(self.data_loader):
                data, target, price = data.to(self.device), target.to(
                    self.device), price.to(self.device)
                cnt += 1

                output = self.model(data)

                for i in range(self.losses_num):
                    l = self._cal_loss(self.criterion[i], output, target, price)
                    # running mean of each loss over the training data
                    max_losses[i] = (cnt - 1) / cnt * \
                        max_losses[i] + 1 / cnt * l.item()

        return max_losses

    def _cal_loss(self, c, output, target, price):
        para_nums = len(inspect.getfullargspec(c).args)
        if para_nums == 2:
            return c(output, target.float())
        elif para_nums == 3:
            return c(output, target.float(), price)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        average_alpha = [0] * self.losses_num
        cnt = 0
        for batch_idx, (data, target, price) in enumerate(self.data_loader):
            cnt += 1
            data, target, price = data.to(self.device), target.to(
                self.device), price.to(self.device)

            losses_computed = []
            if self.opt_losses == 0 or self.opt_losses == 1:
                output = self.model(data)
                for loss in self.criterion:
                    losses_computed.append(
                        self._cal_loss(loss, output, target, price))

                self.optimizer.zero_grad()
                losses_computed[self.opt_losses].backward()
                self.optimizer.step()
            else:
                # calculate the gradients
                gradients = []
                for i, loss in enumerate(self.criterion):
                    # forward pass
                    output = self.model(data)
                    # calculate loss
                    L = self._cal_loss(loss, output, target, price)
                    # zero gradient
                    self.optimizer.zero_grad()
                    # backward pass
                    L.backward()
                    # get gradient for correctness objective
                    gradients.append(self.optimizer.get_gradient())

                # calculate the losses
                # forward pass
                output = self.model(data)

                for i, loss in enumerate(self.criterion):
                    L = self._cal_loss(loss, output, target, price)
                    losses_computed.append(L)

                # get the final loss to compute the common descent vector
                final_loss, alphas = self.common_descent_vector.get_descent_vector(
                    losses_computed, gradients)

                # moving average alpha
                for i, alpha in enumerate(alphas):
                    average_alpha[i] = (cnt - 1) / cnt * \
                        average_alpha[i] + 1 / cnt * alpha

                # zero gradient
                self.optimizer.zero_grad()
                # backward pass
                final_loss.backward()
                # update parameters
                self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', losses_computed[0].item())
            self.train_metrics.update('weighted_loss',
                                      losses_computed[1].item())
            for met in self.metric_ftns:
                para_nums = len(inspect.getfullargspec(met).args)
                if para_nums == 2:
                    self.train_metrics.update(met.__name__,
                                              met(output, target))
                elif para_nums == 3:
                    self.train_metrics.update(met.__name__,
                                              met(output, target, price))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx),
                    losses_computed[0].item()))

            if batch_idx == self.len_epoch:
                break

        if self.opt_losses == 0:
            self.logger.info("Optimize only logloss")
        elif self.opt_losses == 1:
            self.logger.info("Optimize only weighted logloss")
        else:
            self.logger.info("Optimize both logloss and weighted logloss")
            self.logger.info("average alpha: {}".format(average_alpha))
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target,
                            price) in enumerate(self.valid_data_loader):
                data, target, price = data.to(self.device), target.to(
                    self.device), price.to(self.device)

                output = self.model(data)
                loss = self.criterion[0](output, target.float())

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                w_loss = self.criterion[1](output, target.float(), price)
                self.valid_metrics.update('weighted_loss', w_loss.item())
                for met in self.metric_ftns:
                    para_nums = len(inspect.getfullargspec(met).args)
                    if para_nums == 2:
                        self.valid_metrics.update(met.__name__,
                                                  met(output, target))
                    elif para_nums == 3:
                        self.valid_metrics.update(met.__name__,
                                                  met(output, target, price))

        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
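For reference, the two-objective case admits a closed-form solution that an analytical copsolver of this kind typically implements. The following is a minimal sketch under that assumption (the actual AnalyticalSolver and MultiObjectiveCDV APIs are not shown in the example above): the weight alpha on the first gradient minimizes the norm of the convex combination, and the common descent direction is alpha * g1 + (1 - alpha) * g2.

import torch

def two_objective_alpha(g1: torch.Tensor, g2: torch.Tensor) -> float:
    # Hypothetical helper, not part of the example above.
    # Closed-form min-norm weight for two flattened gradients:
    # alpha = clamp(((g2 - g1) . g2) / ||g1 - g2||^2, 0, 1)
    diff = g1 - g2
    denom = diff.dot(diff).clamp_min(1e-12)
    alpha = ((g2 - g1).dot(g2) / denom).clamp(0.0, 1.0)
    return alpha.item()

# Usage sketch: combine two per-loss gradients into a single descent direction.
g1, g2 = torch.randn(10), torch.randn(10)
alpha = two_objective_alpha(g1, g2)
common_descent = alpha * g1 + (1 - alpha) * g2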
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config, device)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (sentences, sentences_mask, strokes, strokes_mask) in enumerate(self.data_loader):

            # Moving input data to device
            sentences, sentences_mask = sentences.to(self.device), sentences_mask.to(self.device)
            strokes, strokes_mask = strokes.to(self.device), strokes_mask.to(self.device)

            # Compute the loss and perform an optimization step
            self.optimizer.zero_grad()

            if str(self.model).startswith('Unconditional'):
                output_network = self.model(sentences, sentences_mask, strokes, strokes_mask)
                gaussian_params = self.model.compute_gaussian_parameters(output_network)
                loss = self.criterion(gaussian_params, strokes, strokes_mask)
                loss.backward()
                # Gradient clipping
                clip_grad_norm_(self.model.rnn_1.parameters(), 10)
                clip_grad_norm_(self.model.rnn_2.parameters(), 10)
                clip_grad_norm_(self.model.rnn_3.parameters(), 10)

            elif str(self.model).startswith('Conditional'):
                output_network = self.model(sentences, sentences_mask, strokes, strokes_mask)
                gaussian_params = self.model.compute_gaussian_parameters(output_network)
                loss = self.criterion(gaussian_params, strokes, strokes_mask)
                loss.backward()
                # Gradient clipping
                clip_grad_norm_(self.model.rnn_1_with_gaussian_attention.lstm_cell.parameters(), 10)
                clip_grad_norm_(self.model.rnn_2.parameters(), 10)
                clip_grad_norm_(self.model.rnn_3.parameters(), 10)

            elif str(self.model).startswith('Seq2Seq'):
                output_network = self.model(sentences, sentences_mask, strokes, strokes_mask)
                loss = self.criterion(output_network, sentences, sentences_mask)
                loss.backward()
                # Gradient clipping
                clip_grad_norm_(self.model.parameters(), 10)

            elif str(self.model).startswith('Graves'):
                output_network = self.model(sentences, sentences_mask, strokes, strokes_mask)
                gaussian_params = self.model.compute_gaussian_parameters(output_network)
                loss = self.criterion(gaussian_params, strokes, strokes_mask)
                loss.backward()
                # Gradient clipping
                clip_grad_norm_(self.model.rnn_1.parameters(), 10)
                clip_grad_norm_(self.model.rnn_2_with_gaussian_attention.lstm_cell.parameters(), 10)
                clip_grad_norm_(self.model.rnn_3.parameters(), 10)

            else:
                raise NotImplementedError("Not a valid model name")

            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss.item()))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_'+k : v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (sentences, sentences_mask, strokes, strokes_mask) in enumerate(self.valid_data_loader):

                # Moving input data to device
                sentences, sentences_mask = sentences.to(self.device), sentences_mask.to(self.device)
                strokes, strokes_mask = strokes.to(self.device), strokes_mask.to(self.device)

                # Compute the loss
                if str(self.model).startswith('Unconditional'):
                    output_network = self.model(sentences, sentences_mask, strokes, strokes_mask)
                    gaussian_params = self.model.compute_gaussian_parameters(output_network)
                    loss = self.criterion(gaussian_params, strokes, strokes_mask)

                elif str(self.model).startswith('Conditional'):
                    output_network = self.model(sentences, sentences_mask, strokes, strokes_mask)
                    gaussian_params = self.model.compute_gaussian_parameters(output_network)
                    loss = self.criterion(gaussian_params, strokes, strokes_mask)

                elif str(self.model).startswith('Seq2Seq'):
                    output_network = self.model(sentences, sentences_mask, strokes, strokes_mask)
                    loss = self.criterion(output_network, sentences, sentences_mask)

                elif str(self.model).startswith('Graves'):
                    output_network = self.model(sentences, sentences_mask, strokes, strokes_mask)
                    gaussian_params = self.model.compute_gaussian_parameters(output_network)
                    loss = self.criterion(gaussian_params, strokes, strokes_mask)

                else:
                    raise NotImplementedError("Not a valid model name")

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 i_fold,
                 data_loader,
                 valid_data_loader=None,
                 test_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.i_fold = i_fold
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader

        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.zero_grad()
        self.train_metrics.reset()
        adv_train = self.config.init_obj('adversarial_training',
                                         module_adversarial,
                                         model=self.model)
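        # K: number of inner attack steps for PGD adversarial training (the FGM branch below uses a single step)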
        K = 3
        for batch_idx, data in enumerate(self.data_loader):
            self.model.train()
            ids, texts, input_ids, attention_masks, text_lengths, labels = data

            if 'cuda' == self.device.type:
                input_ids = input_ids.cuda(self.device)
                attention_masks = attention_masks.cuda(self.device)
                labels = labels.cuda(self.device)

            preds, cls_embedding = self.model(input_ids, attention_masks,
                                              text_lengths)
            loss = self.criterion[0](preds, labels)
            # loss cut: zero out losses that do not exceed the configured threshold
            loss_zeros = torch.zeros_like(loss)
            loss = torch.where(
                loss > float(self.config.config['loss']['loss_cut']), loss,
                loss_zeros)
            loss.backward()
            if self.config.config['trainer'][
                    'is_adversarial_training'] and self.config.config[
                        'adversarial_training']['type'] == 'FGM':  # adversarial training (FGM)
                adv_train.attack()
                adv_preds, adv_cls_embedding = self.model(
                    input_ids, attention_masks, text_lengths)
                adv_loss = self.criterion[0](adv_preds, labels)
                adv_loss.backward()
                adv_train.restore()
            elif self.config.config['trainer'][
                    'is_adversarial_training'] and self.config.config[
                        'adversarial_training']['type'] == 'PGD':
                adv_train.backup_grad()
                # adversarial training (PGD)
                for t in range(K):
                    adv_train.attack(is_first_attack=(
                        t == 0
                    ))  # add an adversarial perturbation to the embeddings; back up param.data on the first attack
                    if t != K - 1:
                        self.model.zero_grad()
                    else:
                        adv_train.restore_grad()
                    adv_preds, adv_cls_embedding = self.model(
                        input_ids, attention_masks, text_lengths)
                    adv_loss = self.criterion[0](adv_preds, labels)
                    adv_loss.backward()  # backward pass: accumulate the adversarial gradients on top of the normal gradients
                adv_train.restore()  # restore the embedding parameters

            if self.config.config['trainer']['clip_grad']:  # gradient clipping
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.config['trainer']['max_grad_norm'])
            self.optimizer.step()
            if self.lr_scheduler is not None:
                self.lr_scheduler.step()
            self.model.zero_grad()
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())

            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(preds, labels))

            if batch_idx % self.log_step == 0:
                self.logger.debug(
                    'Train Epoch: {} {} Loss: {:.3f} lr: {}'.format(
                        epoch, self._progress(batch_idx), loss.item(),
                        self.optimizer.param_groups[0]['lr']))
            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()
        if self.valid_data_loader:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()

        with torch.no_grad():
            for batch_idx, data in enumerate(self.valid_data_loader):
                ids, texts, input_ids, attention_masks, text_lengths, labels = data
                if 'cuda' == self.device.type:
                    input_ids = input_ids.cuda(self.device)
                    attention_masks = attention_masks.cuda(self.device)
                    labels = labels.cuda(self.device)
                preds, cls_embedding = self.model(input_ids, attention_masks,
                                                  text_lengths)

                if self.add_graph:
                    input_model = self.model.module if (len(
                        self.config.config['device_id']) > 1) else self.model
                    self.writer.writer.add_graph(
                        input_model,
                        [input_ids, attention_masks, text_lengths])
                    self.add_graph = False
                loss = self.criterion[0](preds, labels)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())

                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(preds, labels))

        log = self.valid_metrics.result()
        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return log

    def _inference(self):
        """
        Inference after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        checkpoint = torch.load(self.best_path)
        self.logger.info("load best mode {} ...".format(self.best_path))
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()

        ps = []
        ls = []
        with torch.no_grad():
            for batch_idx, data in enumerate(self.valid_data_loader):
                ids, texts, input_ids, attention_masks, text_lengths, labels = data
                if 'cuda' == self.device.type:
                    input_ids = input_ids.cuda(self.device)
                    attention_masks = attention_masks.cuda(self.device)
                    labels = labels.cuda(self.device)
                preds, cls_embedding = self.model(input_ids, attention_masks,
                                                  text_lengths)
                ps.append(preds)
                ls.append(labels)

        ps = torch.cat(ps, dim=0)
        ls = torch.cat(ls, dim=0)
        acc = module_mertric.binary_accuracy(ps, ls)
        self.logger.info('\toverall   acc :{}'.format(acc))

        result_file = self.test_data_loader.dataset.data_dir.parent / 'result' / '{}-{}-{}-{}-{}.jsonl'.format(
            self.config.config['experiment_name'],
            self.test_data_loader.dataset.transformer_model,
            self.config.config['k_fold'], self.i_fold, acc)

        if not result_file.parent.exists():
            result_file.parent.mkdir()

        result_writer = result_file.open('w')

        with torch.no_grad():
            for batch_idx, data in enumerate(self.test_data_loader):
                ids, texts, input_ids, attention_masks, text_lengths, labels = data
                if 'cuda' == self.device.type:
                    input_ids = input_ids.cuda(self.device)
                    attention_masks = attention_masks.cuda(self.device)
                preds, cls_embedding = self.model(input_ids, attention_masks,
                                                  text_lengths)
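                # binarize: round sigmoid probabilities at 0.5 to get hard 0/1 labels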
                preds = torch.round(
                    torch.sigmoid(preds)).cpu().detach().numpy()
                for pred, item_id, text in zip(preds, ids, texts):
                    result_writer.write(
                        json.dumps(
                            {
                                "id": item_id,
                                "text": text,
                                "labels": int(pred)
                            },
                            ensure_ascii=False) + '\n')

            result_writer.close()
            self.logger.info('result saving to {}'.format(result_file))

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
Example #16
0
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.init_lr = config['optimizer']['args']['lr']
        self.warm_up = config['trainer']['warm_up']

        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns])
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns])

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            # Linear Learning Rate Warm-up
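            # lr ramps linearly from 0 up to init_lr over the first warm_up epochs, proportional to the global batch index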
            full_batch_idx = ((epoch - 1) * len(self.data_loader) + batch_idx)
            if epoch - 1 < self.warm_up:
                for params in self.optimizer.param_groups:
                    params['lr'] = self.init_lr / (
                        self.warm_up * len(self.data_loader)) * full_batch_idx
            lr = get_lr(self.optimizer)

            # -------- TRAINING LOOP --------
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            # -------------------------------

            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        log.update({'lr': lr})

        # Add log to WandB
        if not self.config['debug']:
            wandb.log(log)

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(output, target))

        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
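The get_lr helper used in the warm-up block above is not defined in this example; a minimal sketch, consistent with the identically named get_lr method in the example below, would simply read the first parameter group's learning rate:

def get_lr(optimizer):
    # Assumed sketch: report the learning rate of the first parameter group.
    for param_group in optimizer.param_groups:
        return param_group['lr']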
Example #17
0
class TrainerRetrievalAux(BaseTrainerRetrieval):
    """
    Trainer class for retrieval with classification as extra info
    """
    def __init__(self, model, model_text, criterion, criterion_ret,
                 metric_ftns, optimizer, config,
                 data_loader, font_type,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, model_text, criterion, criterion_ret, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        self.font_type = font_type
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.no_tasks = len(_FACTORS_IN_ORDER)
        list_metrics = []
        for m in self.metric_ftns:
            for i in range(0, self.no_tasks):
                metric_task = f"{m.__name__}_{_FACTORS_IN_ORDER[i]}"
                list_metrics.append(metric_task)
        list_losses = []
        for i in range(0, self.no_tasks):
            list_losses.append(f"loss_{_FACTORS_IN_ORDER[i]}")
        self.train_metrics = MetricTracker('loss_classification', 'accuracy_retrieval',
                                           'loss_floor_hue', 'loss_wall_hue', 'loss_object_hue',
                                           'loss_retrieval', 'loss_tot', 'loss_scale', 'loss_shape',
                                           'loss_orientation', 'accuracy_floor_hue',
                                           'accuracy_wall_hue', 'accuracy_object_hue',
                                           'accuracy_scale', 'accuracy_shape',
                                           'accuracy_orientation', 'accuracy',
                                           writer=self.writer)
        self.valid_metrics = MetricTracker('loss_classification', 'accuracy_retrieval',
                                           'loss_floor_hue', 'loss_wall_hue', 'loss_object_hue',
                                           'loss_retrieval', 'loss_tot', 'loss_scale', 'loss_shape',
                                           'loss_orientation', 'accuracy_floor_hue',
                                           'accuracy_wall_hue', 'accuracy_object_hue',
                                           'accuracy_scale', 'accuracy_shape',
                                           'accuracy_orientation', 'accuracy',
                                           writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model_text.train()
        self.model.train()
        self.train_metrics.reset()
        if epoch == 1:
            list_of_counters = []
            for i in range(0, 6):
                list_of_counters.append(Counter())
        for batch_idx, (data, target_ret, target_init) in enumerate(self.data_loader):
            data, target_ret = data.to(self.device), target_ret.to(self.device)
            target_init = target_init.to(self.device)
            self.optimizer.zero_grad()
            text_output = self.model_text(target_ret.float())
            output_ret, output_init = self.model(data)
            loss_ret = self.criterion_ret(output_ret, text_output, 20)
            no_tasks = len(target_init[0])
            loss_classification = 0

            for i in range(0, no_tasks):
                output_task = output_init[i]
                target_task = target_init[:, i]
                if epoch == 1:
                    list_of_counters[i] += Counter(target_task.tolist())
                new_org = add_margin(img_list=data[0:8, :, :],
                                     labels=target_task,
                                     predictions=output_task,
                                     margins=5,
                                     idx2label=self.data_loader.idx2label_init[i],
                                     font=self.font_type,
                                    )
                self.writer.add_image(f"Image_train_marg_{_FACTORS_IN_ORDER[i]}_{epoch}",
                                      torchvision.utils.make_grid(new_org),
                                      epoch)
                loss_task = self.criterion(output_task, target_task)
                loss_classification += loss_task
                loss_title = f"loss_{_FACTORS_IN_ORDER[i]}"
                self.train_metrics.update(loss_title,
                                          loss_task.item())
                for met in self.metric_ftns:
                    metric_title = f"{met.__name__}_{_FACTORS_IN_ORDER[i]}"
                    self.train_metrics.update(metric_title,
                                              met(output_task, target_task))

            loss_tot = loss_ret + loss_classification
            loss_tot.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            try:
                self.train_metrics.update('loss_retrieval', loss_ret.item())
            except AttributeError:
                print("Not enough data")
            self.train_metrics.update('accuracy_retrieval',
                                      accuracy_retrieval(output_ret, text_output))
            self.train_metrics.update('loss_classification', loss_classification.item())
            self.train_metrics.update('loss_tot', loss_tot.item())

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss tot: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss_tot.item()))
                self.logger.debug('Train Epoch: {} {} Loss ret: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss_ret.item()))
                self.logger.debug('Train Epoch: {} {} Loss classification: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss_classification.item()))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break
        #add histograms for data distribution
        if epoch == 1:
            histogram_distribution(list_of_counters, 'train')
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_'+k : v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.model_text.eval()
        self.valid_metrics.reset()
        if epoch == 1:
            list_of_counters = []
            for i in range(0, 6):
                list_of_counters.append(Counter())
        with torch.no_grad():
            for batch_idx, (data, target_ret, target_init) in enumerate(self.valid_data_loader):
                data, target_ret = data.to(self.device), target_ret.to(self.device)
                target_init = target_init.to(self.device)
                text_output = self.model_text(target_ret.float())
                output_ret, output_init = self.model(data)
                no_tasks = len(target_init[0])
                loss_ret = self.criterion_ret(output_ret, text_output, 10)
                loss_classification = 0
                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                for i in range(0, no_tasks):
                    output_task = output_init[i]
                    target_task = target_init[:, i]
                    if epoch == 1:
                        list_of_counters[i] += Counter(target_task.tolist())
                    new_org = add_margin(img_list=data[0:8, :, :],
                                         labels=target_task,
                                         predictions=output_task,
                                         margins=5,
                                         idx2label=self.data_loader.idx2label_init[i],
                                         font=self.font_type,
                                        )
                    self.writer.add_image(f"Image_val_marg_{_FACTORS_IN_ORDER[i]}_{epoch}",
                                          torchvision.utils.make_grid(new_org),
                                          epoch)
                    loss_task = self.criterion(output_task, target_task)
                    loss_classification += loss_task
                    loss_title = f"loss_{_FACTORS_IN_ORDER[i]}"
                    self.valid_metrics.update(loss_title,
                                              loss_task.item())
                    for met in self.metric_ftns:
                        metric_title = f"{met.__name__}_{_FACTORS_IN_ORDER[i]}"
                        self.valid_metrics.update(metric_title,
                                                  met(output_task, target_task))
                self.valid_metrics.update('loss_classification', loss_classification.item())
                loss_tot = loss_ret + loss_classification
                self.valid_metrics.update('loss_tot', loss_tot.item())
                # for met in self.metric_ftns:
                #     self.valid_metrics.update(met.__name__, met(output, target, no_tasks))
                try:
                    self.valid_metrics.update('loss_retrieval', loss_ret.item())
                except AttributeError:
                    print("Not enough data")
                self.valid_metrics.update('accuracy_retrieval',
                                          accuracy_retrieval(output_ret, text_output))

                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        for name, p in self.model_text.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(self.len_epoch /
                            4)  # int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def get_lr(self):
        for param_group in self.optimizer.param_groups:
            return param_group['lr']

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        fp16 = False
        gradient_accumulation_steps = 1

        self.logger.info("Current gradient_accumulation_steps: {}".format(
            gradient_accumulation_steps))
        self.logger.info("Current learning rate: {}".format(self.get_lr()))

        self.model.train()
        self.train_metrics.reset()
        trange = tqdm(enumerate(self.data_loader),
                      total=self.len_epoch,
                      desc="training")
        for batch_idx, batch in trange:
            data = batch["sentence"]
            target = batch["label"]

            if not isinstance(data, list):  # check if type is list
                data = data.to(self.device)
            if not isinstance(target, list):  # check if type is list
                target = target.to(self.device)

            output = self.model(data)

            if isinstance(output, list):
                output = torch.cat(output, dim=0).cuda()
            if isinstance(target, list):
                target = torch.cat(target, dim=0).cuda()

            if fp16:
                print(output, target)
                loss = self.criterion(output, target).half()
                print(loss)
            else:
                loss = self.criterion(output, target)

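            # Gradient accumulation: scale the loss so the gradients summed over the accumulation window match one averaged step.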
            if gradient_accumulation_steps > 1:
                loss = loss / gradient_accumulation_steps

            if fp16:
                with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if (batch_idx + 1) % gradient_accumulation_steps == 0:
                if fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer), 1.0)
                else:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   1.0)

                self.optimizer.step()
                self.optimizer.zero_grad()

                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                self.train_metrics.update('loss', loss.item(), output.size(0))

                predict = (output >= 0.5)
                maxclass = torch.argmax(
                    output, dim=1
                )  # make sure every sentence predicted to at least one class
                for i in range(len(predict)):
                    predict[i][maxclass[i].item()] = 1
                predict = predict.type(torch.LongTensor).to(self.device)

                for met in self.metric_ftns:
                    self.train_metrics.update(met.__name__,
                                              met(predict, target),
                                              predict.size(0))
                '''
                if batch_idx % self.log_step == 0:
                    self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                        epoch,
                        self._progress(batch_idx),
                        loss.item()))
                    self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                '''
                if batch_idx == self.len_epoch:
                    break
                trange.set_postfix(loss=loss.item())
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, batch in enumerate(self.valid_data_loader):
                data = batch["sentence"]
                target = batch["label"]

                if not isinstance(data, list):
                    data = data.to(self.device)
                if not isinstance(target, list):
                    target = target.to(self.device)

                output = self.model(data)

                if isinstance(output, list):
                    output = torch.cat(output, dim=0).to(self.device)
                if isinstance(target, list):
                    target = torch.cat(target, dim=0).to(self.device)

                loss = self.criterion(output, target)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item(), output.size(0))

                predict = (output >= 0.5)
                maxclass = torch.argmax(
                    output, dim=1
                )  # make sure every sentence predicted to at least one class
                for i in range(len(predict)):
                    predict[i][maxclass[i].item()] = 1
                predict = predict.type(torch.LongTensor).to(self.device)

                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(predict, target),
                                              predict.size(0))
                #self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
class QuicknatLIDCTrainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 experiment=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config,
                         experiment)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker('loss', writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.metrics_sample_count = config['trainer']['metrics_sample_count']

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.model.enable_test_dropout()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            # shape data: [B x 1 x H x W]
            # shape target: [B x 4 x H x W]
            data, target = data.to(self.device), target.to(self.device)
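            # randomly pick one of the four ground-truth annotations (channel dim of target) as this batch's target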
            rand_idx = np.random.randint(0, 4)
            target = target[:, rand_idx, ...]

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            # self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.model.enable_test_dropout()
        self.valid_metrics.reset()

        with torch.no_grad():
            for batch_idx, (data,
                            targets) in enumerate(self.valid_data_loader):
                data, targets = data.to(self.device), targets.to(self.device)
                rand_idx = np.random.randint(0, 4)
                target = targets[:, rand_idx, ...]
                targets = targets.unsqueeze(2)

                # self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')

                # Loss
                output = self.model(data)
                loss = self.criterion(output, target)
                self.valid_metrics.update('loss', loss.item())

                # Sampling
                samples = self._sample(
                    self.model,
                    data)  # [BATCH_SIZE x SAMPLE_SIZE x NUM_CHANNELS x H x W]

                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(samples, targets))

                self._visualize_batch(batch_idx, samples, targets)

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _sample(self, model, data):
        num_samples = self.metrics_sample_count
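        # Monte Carlo sampling: with test-time dropout enabled, each forward pass
        # yields a different prediction; collect num_samples hard label maps
        # (argmax over channels) per input.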

        batch_size, num_channels, image_size = data.shape[0], 1, tuple(
            data.shape[2:])
        samples = torch.zeros((batch_size, num_samples, num_channels,
                               *image_size)).to(self.device)
        for i in range(num_samples):
            output = model(data)

            max_val, idx = torch.max(output, 1)
            sample = idx.unsqueeze(dim=1)
            samples[:, i, ...] = sample

        return samples

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _visualize_batch(self, batch_idx, samples, targets):
        gt_titles = [f'GT_{i}' for i in range(targets.shape[1])]
        s_titles = [f'S_{i}' for i in range(self.metrics_sample_count)]
        titles = gt_titles + s_titles

        vis_data = torch.cat((targets, samples), dim=1)
        img_metric_grid = visualization.make_image_metric_grid(
            vis_data, enable_helper_dots=True, titles=titles)

        self.writer.add_image(f'segmentations_batch_idx_{batch_idx}',
                              img_metric_grid.cpu())
class SegmentationTrainer(BaseTrainer):
    def __init__(self, model, criterion, metrics, optimizer, config, lr_scheduler=None):
        super().__init__(model, criterion, metrics, optimizer, config)
        self.lr_scheduler = lr_scheduler
        self.loss_name = 'supervised_loss'

        # Metrics
        # Train
        self.train_loss = MetricTracker(self.loss_name, writer=self.writer)
        self.train_metrics = MetricTracker(*self.metric_names,
                                           writer=self.writer)
        # Validation
        self.valid_loss = MetricTracker(self.loss_name, writer=self.writer)
        self.valid_metrics = MetricTracker(*self.metric_names,
                                           writer=self.writer)
        # Test
        self.test_loss = MetricTracker(self.loss_name, writer=self.writer)
        self.test_metrics = MetricTracker(*self.metric_names,
                                          writer=self.writer)

        if isinstance(self.model, nn.DataParallel):
            self.criterion = nn.DataParallel(self.criterion)

        # Resume checkpoint if path is available in config
        cp_path = self.config['trainer'].get('resume_path')
        if cp_path:
            super()._resume_checkpoint()

    def reset_scheduler(self):
        self.train_loss.reset()
        self.train_metrics.reset()
        self.valid_loss.reset()
        self.valid_metrics.reset()
        self.test_loss.reset()
        self.test_metrics.reset()
        # if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
        #     self.lr_scheduler.reset()

    def prepare_train_epoch(self, epoch):
        self.logger.info('EPOCH: {}'.format(epoch))
        self.reset_scheduler()

    def _train_epoch(self, epoch):
        self.model.train()
        self.prepare_train_epoch(epoch)
        for batch_idx, (data, target, image_name) in enumerate(self.train_data_loader):
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            loss = self.criterion(output, target)
            # For debug model
            if torch.isnan(loss):
                super()._save_checkpoint(epoch)

            self.model.zero_grad()
            loss.backward()
            self.optimizer.step()

            # Update train loss, metrics
            self.train_loss.update(self.loss_name, loss.item())
            for metric in self.metrics:
                self.train_metrics.update(metric.__name__, metric(output, target), n=output.shape[0])

            if batch_idx % self.log_step == 0:
                self.log_for_step(epoch, batch_idx)

            if self.save_for_track and (batch_idx % self.save_for_track == 0):
                save_output(output, image_name, epoch, self.checkpoint_dir)

            if batch_idx == self.len_epoch:
                break

        log = self.train_loss.result()
        log.update(self.train_metrics.result())

        if self.do_validation and (epoch % self.do_validation_interval == 0):
            val_log = self._valid_epoch(epoch)
            log.update(val_log)

        # step lr scheduler
        if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
            self.lr_scheduler.step(self.valid_loss.avg(self.loss_name))

        return log

    @staticmethod
    def get_metric_message(metrics, metric_names):
        metrics_avg = [metrics.avg(name) for name in metric_names]
        message_metrics = ', '.join(['{}: {:.6f}'.format(x, y) for x, y in zip(metric_names, metrics_avg)])
        return message_metrics

    def log_for_step(self, epoch, batch_idx):
        message_loss = 'Train Epoch: {} [{}]/[{}] Dice Loss: {:.6f}'.format(epoch, batch_idx, self.len_epoch,
                                                                            self.train_loss.avg(self.loss_name))

        message_metrics = SegmentationTrainer.get_metric_message(self.train_metrics, self.metric_names)
        self.logger.info(message_loss)
        self.logger.info(message_metrics)

    def _valid_epoch(self, epoch, save_result=False, save_for_visual=False):
        self.model.eval()
        self.valid_loss.reset()
        self.valid_metrics.reset()
        self.logger.info('Validation: ')
        with torch.no_grad():
            for batch_idx, (data, target, image_name) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_loss.update(self.loss_name, loss.item())
                for metric in self.metrics:
                    self.valid_metrics.update(metric.__name__, metric(output, target), n=output.shape[0])

                if save_result:
                    save_output(output, image_name, epoch, os.path.join(self.checkpoint_dir, 'tracker'), percent=1)

                if save_for_visual:
                    save_mask2image(output, image_name, os.path.join(self.checkpoint_dir, 'output'))
                    save_mask2image(target, image_name, os.path.join(self.checkpoint_dir, 'target'))

                if batch_idx % self.log_step == 0:
                    self.logger.debug('{}/{}'.format(batch_idx, len(self.valid_data_loader)))
                    self.logger.debug('{}: {}'.format(self.loss_name, self.valid_loss.avg(self.loss_name)))
                    self.logger.debug(SegmentationTrainer.get_metric_message(self.valid_metrics, self.metric_names))

        log = self.valid_loss.result()
        log.update(self.valid_metrics.result())
        val_log = {'val_{}'.format(k): v for k, v in log.items()}
        return val_log

    def _test_epoch(self, epoch, save_result=False, save_for_visual=False):
        self.model.eval()
        self.test_loss.reset()
        self.test_metrics.reset()
        self.logger.info('Test: ')
        with torch.no_grad():
            for batch_idx, (data, target, image_name) in enumerate(self.test_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step((epoch - 1) * len(self.test_data_loader) + batch_idx, 'test')
                self.test_loss.update(self.loss_name, loss.item())
                for metric in self.metrics:
                    self.test_metrics.update(metric.__name__, metric(output, target), n=output.shape[0])

                if save_result:
                    save_output(output, image_name, epoch, os.path.join(self.checkpoint_dir, 'tracker'), percent=1)

                if save_for_visual:
                    save_mask2image(output, image_name, os.path.join(self.checkpoint_dir, 'output'))
                    save_mask2image(target, image_name, os.path.join(self.checkpoint_dir, 'target'))

                if batch_idx % self.log_step == 0:
                    self.logger.debug('{}/{}'.format(batch_idx, len(self.test_data_loader)))
                    self.logger.debug('{}: {}'.format(self.loss_name, self.test_loss.avg(self.loss_name)))
                    self.logger.debug(SegmentationTrainer.get_metric_message(self.test_metrics, self.metric_names))

        log = self.test_loss.result()
        log.update(self.test_metrics.result())
        test_log = {'test_{}'.format(k): v for k, v in log.items()}
        return test_log
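
The SegmentationTrainer above steps its scheduler only when it is a MyReduceLROnPlateau, feeding it the running average validation loss. MyReduceLROnPlateau is project-specific; below is a minimal sketch of the same pattern with the stock torch.optim.lr_scheduler.ReduceLROnPlateau, assuming the wrapper behaves the same way:

import torch

# Sketch: reduce the learning rate when the validation loss stops improving.
model = torch.nn.Linear(10, 1)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3)

for epoch in range(20):
    val_loss = 1.0 / (epoch + 1)          # stand-in for the real validation loss
    scheduler.step(val_loss)              # LR is halved once val_loss plateaus for `patience` epochs
    current_lr = optimizer.param_groups[0]['lr']
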
class TrainerVd(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))
        self.n_batches = data_loader.n_samples / data_loader.batch_size
        self.n_batches_valid = valid_data_loader.n_samples / valid_data_loader.batch_size

        self.train_metrics = MetricTracker('loss', 'kl_cost', 'pred_cost', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', 'kl_cost', 'pred_cost', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

        self.keys.extend(['kl_cost', 'pred_cost'])
        if self.do_validation:
            keys_val = ['val_' + k for k in self.keys]
            for key in self.keys + keys_val:
                self.log[key] = []

    def _train_epoch(self, epoch, samples=10):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()

            outputs = torch.zeros(data.shape[0], self.model.output_dim, samples).to(self.device)
            if samples == 1:
                out, tkl = self.model(data)
                mlpdw = self._compute_loss(out, target)
                Edkl = tkl / self.n_batches
                outputs[:, :, 0] = out

            elif samples > 1:
                mlpdw_cum = 0
                Edkl_cum = 0

                for i in range(samples):
                    out, tkl = self.model(data, sample=True)
                    mlpdw_i = self._compute_loss(out, target)
                    Edkl_i = tkl / self.n_batches
                    mlpdw_cum = mlpdw_cum + mlpdw_i
                    Edkl_cum = Edkl_cum + Edkl_i

                    outputs[:, :, i] = out

                mlpdw = mlpdw_cum / samples
                Edkl = Edkl_cum / samples

            mean = torch.mean(outputs, dim=2)
            loss = Edkl + mlpdw
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item(), n=len(target))
            self.train_metrics.update('kl_cost', Edkl.item(), n=len(target))
            self.train_metrics.update('pred_cost', mlpdw.item(), n=len(target))

            for met in self.metric_ftns:
                self._compute_metric(self.train_metrics, met, outputs, target)

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch,
                    self._progress(batch_idx),
                    loss.item()))
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_'+k : v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch, samples=100):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                loss = 0
                outputs = torch.zeros(data.shape[0], self.model.output_dim, samples).to(self.device)

                if samples == 1:
                    out, tkl = self.model(data)
                    mlpdw = self._compute_loss(out, target)
                    Edkl = tkl / self.n_batches_valid
                    outputs[:, :, 0] = out

                elif samples > 1:
                    mlpdw_cum = 0
                    Edkl_cum = 0

                    for i in range(samples):
                        out, tkl = self.model(data, sample=True)
                        mlpdw_i = self._compute_loss(out, target)
                        Edkl_i = tkl / self.n_batches_valid
                        mlpdw_cum = mlpdw_cum + mlpdw_i
                        Edkl_cum = Edkl_cum + Edkl_i

                        outputs[:, :, i] = out

                    mlpdw = mlpdw_cum / samples
                    Edkl = Edkl_cum / samples

                mean = torch.mean(outputs, dim=2)
                loss = Edkl + mlpdw

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item(), n=len(target))
                self.valid_metrics.update('kl_cost', Edkl.item(), n=len(target))
                self.valid_metrics.update('pred_cost', mlpdw.item(), n=len(target))

                for met in self.metric_ftns:
                    self._compute_metric(self.valid_metrics, met, outputs, target)
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _compute_loss(self, output, target):
        if self.model.regression_type == 'homo':  # homoscedastic regression
            loss = self.criterion(output, target, self.model.log_noise.exp(), self.model.output_dim)
        elif self.model.regression_type == 'hetero':
            loss = self.criterion(output, target, self.model.output_dim/2)
        else:
            loss = self.criterion(output, target)
        return loss

    def _compute_metric(self, metrics, met, output, target, type="VD"):
        if self.model.regression_type == 'homo':  # homoscedastic regression
            metrics.update(met.__name__, met([output, self.model.log_noise.exp()], target, type))
        else:
            metrics.update(met.__name__, met(output, target, type))
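
TrainerVd above forms its loss from several stochastic forward passes: each pass returns a prediction and a KL term, the data-fit losses (mlpdw) and the KL terms scaled by 1/n_batches (Edkl) are averaged over the samples, and their sum is the minibatch loss. A self-contained sketch of that accumulation, assuming only that the model returns (output, total_kl) when called with sample=True, as it does above:

import torch

def mc_variational_loss(model, data, target, criterion, n_batches, samples=10):
    """Average the data-fit and KL terms over `samples` stochastic forward passes.

    Assumes model(data, sample=True) returns (output, total_kl), as in TrainerVd above.
    """
    pred_cost, kl_cost = 0.0, 0.0
    outputs = []
    for _ in range(samples):
        out, tkl = model(data, sample=True)
        pred_cost = pred_cost + criterion(out, target)
        kl_cost = kl_cost + tkl / n_batches     # KL spread evenly over the batches of one epoch
        outputs.append(out)
    pred_cost = pred_cost / samples
    kl_cost = kl_cost / samples
    loss = pred_cost + kl_cost
    mean_output = torch.stack(outputs, dim=-1).mean(dim=-1)  # predictive mean over the samples
    return loss, pred_cost, kl_cost, mean_output
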
Example #22
0
class OPUSMultitaskTrainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 experiment=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config,
                         experiment)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

        # take the base learning rate from the optimizer's (last) param group
        for param_group in optimizer.param_groups:
            lr = param_group['lr']

        model.cross1ss = torch.nn.Parameter(data=model.cross1ss.to(
            self.device),
                                            requires_grad=True)
        model.cross1sc = torch.nn.Parameter(data=model.cross1sc.to(
            self.device),
                                            requires_grad=True)
        model.cross1cc = torch.nn.Parameter(data=model.cross1cc.to(
            self.device),
                                            requires_grad=True)
        model.cross1cs = torch.nn.Parameter(data=model.cross1cs.to(
            self.device),
                                            requires_grad=True)

        model.cross2ss = torch.nn.Parameter(data=model.cross2ss.to(
            self.device),
                                            requires_grad=True)
        model.cross2sc = torch.nn.Parameter(data=model.cross2sc.to(
            self.device),
                                            requires_grad=True)
        model.cross2cc = torch.nn.Parameter(data=model.cross2cc.to(
            self.device),
                                            requires_grad=True)
        model.cross2cs = torch.nn.Parameter(data=model.cross2cs.to(
            self.device),
                                            requires_grad=True)

        model.cross3ss = torch.nn.Parameter(data=model.cross3ss.to(
            self.device),
                                            requires_grad=True)
        model.cross3sc = torch.nn.Parameter(data=model.cross3sc.to(
            self.device),
                                            requires_grad=True)
        model.cross3cc = torch.nn.Parameter(data=model.cross3cc.to(
            self.device),
                                            requires_grad=True)
        model.cross3cs = torch.nn.Parameter(data=model.cross3cs.to(
            self.device),
                                            requires_grad=True)

        model.crossbss = torch.nn.Parameter(data=model.crossbss.to(
            self.device),
                                            requires_grad=True)
        model.crossbsc = torch.nn.Parameter(data=model.crossbsc.to(
            self.device),
                                            requires_grad=True)
        model.crossbcc = torch.nn.Parameter(data=model.crossbcc.to(
            self.device),
                                            requires_grad=True)
        model.crossbcs = torch.nn.Parameter(data=model.crossbcs.to(
            self.device),
                                            requires_grad=True)

        # Hack: Set a different learning rate for the cross-stitch parameters
        optimizer.add_param_group({
            'params': [
                model.cross1ss, model.cross1sc, model.cross1cs, model.cross1cc,
                model.cross2ss, model.cross2sc, model.cross2cs, model.cross2cc,
                model.cross3ss, model.cross3sc, model.cross3cs, model.cross3cc,
                model.crossbss, model.crossbsc, model.crossbcs, model.crossbcc
            ],
            'lr':
            lr * 250
        })

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target_seg,
                        target_class) in enumerate(self.data_loader):
            data, target_seg, target_class = data.to(
                self.device), target_seg.to(self.device), target_class.to(
                    self.device)

            self.optimizer.zero_grad()
            output_seg, output_class = self.model(data)
            loss = self.criterion((output_seg, output_class), target_seg,
                                  target_class, epoch)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                if met.__name__ == "accuracy":
                    self.train_metrics.update(met.__name__,
                                              met(output_class, target_class))
                else:
                    self.train_metrics.update(met.__name__,
                                              met(output_seg, target_seg))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

                self._visualize_input(data.cpu())

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target_seg,
                            target_class) in enumerate(self.valid_data_loader):
                data, target_seg, target_class = data.to(
                    self.device), target_seg.to(self.device), target_class.to(
                        self.device)

                output_seg, output_class = self.model(data)
                loss = self.criterion((output_seg, output_class), target_seg,
                                      target_class, epoch)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    if met.__name__ == "accuracy":
                        self.valid_metrics.update(
                            met.__name__, met(output_class, target_class))
                    else:
                        self.valid_metrics.update(met.__name__,
                                                  met(output_seg, target_seg))

                data_cpu = data.cpu()
                self._visualize_input(data_cpu)
                self._visualize_prediction(data_cpu, output_seg.cpu(),
                                           target_seg.cpu())

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _visualize_input(self, input):
        """format and display input data on tensorboard"""
        self.writer.add_image(
            'input', make_grid(input[0, 0, :, :], nrow=8, normalize=True))

    def _visualize_prediction(self, input, output, target):
        """format and display output and target data on tensorboard"""
        out_b1 = binary(output)
        out_b1 = impose_labels_on_image(input[0, 0, :, :], target[0, :, :],
                                        out_b1[0, 1, :, :])
        self.writer.add_image('output',
                              make_grid(out_b1, nrow=8, normalize=False))
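
OPUSMultitaskTrainer above re-registers sixteen scalar cross-stitch coefficients on the device and gives them a much larger learning rate (lr * 250) through optimizer.add_param_group. A cross-stitch unit mixes the activations of two task branches (segmentation 's' and classification 'c' here) with learned weights; a minimal sketch of one such unit, with all names and initial values hypothetical:

import torch
import torch.nn as nn

class CrossStitchUnit(nn.Module):
    """Learned 2x2 mixing of two task-specific feature maps (hypothetical sketch)."""
    def __init__(self):
        super().__init__()
        # initialised close to identity so each task mostly keeps its own features
        self.ss = nn.Parameter(torch.tensor(0.9))
        self.sc = nn.Parameter(torch.tensor(0.1))
        self.cs = nn.Parameter(torch.tensor(0.1))
        self.cc = nn.Parameter(torch.tensor(0.9))

    def forward(self, x_s, x_c):
        # each output branch is a learned linear combination of both inputs
        return self.ss * x_s + self.sc * x_c, self.cs * x_s + self.cc * x_c

# giving the stitch coefficients their own (larger) learning rate, as the trainer above does:
# optimizer.add_param_group({'params': unit.parameters(), 'lr': base_lr * 250})
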
Example #23
0
class Trainer(BaseTrainer):
    def __init__(self,
                 model,
                 criterion,
                 metric_fns,
                 optimizer,
                 config,
                 data_loader,
                 feature_index,
                 cell_neighbor_set,
                 drug_neighbor_set,
                 valid_data_loader=None,
                 test_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_fns, optimizer, config)
        self.config = config
        # for data
        self.data_loader = data_loader
        self.cell_neighbor_set = cell_neighbor_set
        self.drug_neighbor_set = drug_neighbor_set
        self.feature_index = feature_index

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch

        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader

        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_fns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_fns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            target = target.to(self.device)
            output, emb_loss = self.model(*self._get_feed_dict(data))
            loss = self.criterion(output, target.squeeze()) + emb_loss

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            with torch.no_grad():
                y_pred = torch.sigmoid(output)
                y_pred = y_pred.cpu().detach().numpy()
                y_true = target.cpu().detach().numpy()
                for met in self.metric_fns:
                    self.train_metrics.update(met.__name__,
                                              met(y_pred, y_true))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()
        log['train'] = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
            log['validation'] = {'val_' + k: v for k, v in val_log.items()}

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                target = target.to(self.device)
                output, emb_loss = self.model(*self._get_feed_dict(data))
                loss = self.criterion(output, target.squeeze()) + emb_loss

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                y_pred = torch.sigmoid(output)
                y_pred = y_pred.cpu().detach().numpy()
                y_true = target.cpu().detach().numpy()
                for met in self.metric_fns:
                    self.valid_metrics.update(met.__name__,
                                              met(y_pred, y_true))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def test(self):
        self.model.eval()
        total_loss = 0.0
        total_metrics = torch.zeros(len(self.metric_fns))
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.test_data_loader):
                target = target.to(self.device)
                output, emb_loss = self.model(*self._get_feed_dict(data))
                loss = self.criterion(output, target.squeeze()) + emb_loss

                batch_size = data.shape[0]
                total_loss += loss.item() * batch_size

                y_pred = torch.sigmoid(output)
                y_pred = y_pred.cpu().detach().numpy()
                y_true = target.cpu().detach().numpy()
                for i, metric in enumerate(self.metric_fns):
                    total_metrics[i] += metric(y_pred, y_true) * batch_size

        test_output = {
            'n_samples': len(self.test_data_loader.sampler),
            'total_loss': total_loss,
            'total_metrics': total_metrics
        }

        return test_output

    def get_save(self, save_files):
        result = dict()
        for key, value in save_files.items():
            if type(value) == dict:
                temp = dict()
                for k, v in value.items():
                    temp[k] = v.cpu().detach().numpy()
            else:
                temp = value.cpu().detach().numpy()
            result[key] = temp
        return result

    def _get_feed_dict(self, data):
        # [batch_size]
        cells = data[:, self.feature_index['cell']]
        drugs1 = data[:, self.feature_index['drug1']]
        drugs2 = data[:, self.feature_index['drug2']]
        cells_neighbors, drugs1_neighbors, drugs2_neighbors = [], [], []
        for hop in range(self.model.n_hop):
            cells_neighbors.append(torch.LongTensor([self.cell_neighbor_set[c][hop] \
                                                       for c in cells.numpy()]).to(self.device))
            drugs1_neighbors.append(torch.LongTensor([self.drug_neighbor_set[d][hop] \
                                                          for d in drugs1.numpy()]).to(self.device))
            drugs2_neighbors.append(torch.LongTensor([self.drug_neighbor_set[d][hop] \
                                                          for d in drugs2.numpy()]).to(self.device))

        return cells.to(self.device), drugs1.to(self.device), drugs2.to(self.device), \
               cells_neighbors, drugs1_neighbors, drugs2_neighbors

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
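
Like several trainers in this listing, the class above switches to iteration-based training by wrapping its DataLoader in inf_loop and fixing len_epoch. inf_loop itself is not shown; a common implementation (for example, the helper shipped with the pytorch-template project) simply cycles the loader forever:

from itertools import repeat

def inf_loop(data_loader):
    """Yield batches from data_loader endlessly, restarting it after each pass."""
    for loader in repeat(data_loader):
        yield from loader
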
Example #24
0
class Learning(object):
    def __init__(self,
            model,
            criterion,
            optimizer,
            scheduler,
            metric_ftns,
            device,
            num_epoch,
            grad_clipping,
            grad_accumulation_steps,
            early_stopping,
            validation_frequency,
            tensorboard,
            checkpoint_dir,
            resume_path):
        self.device, device_ids = self._prepare_device(device)
        # self.model = model.to(self.device)
        
        self.start_epoch = 1
        if len(device_ids) > 1:
            # self.model = torch.nn.DataParallel(model, device_ids=device_ids)
            model = torch.nn.DataParallel(model)
            # cudnn.benchmark = True
        self.model = model.cuda()
        # resume only after self.model exists, since _resume_checkpoint loads into it
        if resume_path is not None:
            self._resume_checkpoint(resume_path)
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.num_epoch = num_epoch 
        self.scheduler = scheduler
        self.grad_clipping = grad_clipping
        self.grad_accumulation_steps = grad_accumulation_steps
        self.early_stopping = early_stopping
        self.validation_frequency = validation_frequency
        self.checkpoint_dir = checkpoint_dir
        self.best_epoch = 1
        self.best_score = 0
        self.writer = TensorboardWriter(os.path.join(checkpoint_dir, 'tensorboard'), tensorboard)
        self.train_metrics = MetricTracker('loss', writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        
    def train(self, train_dataloader):
        score = 0
        for epoch in range(self.start_epoch, self.num_epoch+1):
            print("{} epoch: \t start training....".format(epoch))
            start = time.time()
            train_result  = self._train_epoch(epoch, train_dataloader)
            train_result.update({'time': time.time()-start})
            
            for key, value in train_result.items():
                print('    {:15s}: {}'.format(str(key), value))

            # if (epoch+1) % self.validation_frequency!=0:
            #     print("skip validation....")
            #     continue
            # print('{} epoch: \t start validation....'.format(epoch))
            # start = time.time()
            # valid_result = self._valid_epoch(epoch, valid_dataloader)
            # valid_result.update({'time': time.time() - start})
            
            # for key, value in valid_result.items():
            #     if 'score' in key:
            #         score = value 
            #     print('   {:15s}: {}'.format(str(key), value))
            score += 0.001  # placeholder score while the validation block above is commented out
            self.post_processing(score, epoch)
            if epoch - self.best_epoch > self.early_stopping:
                print('WARNING: EARLY STOPPING')
                break

    def _train_epoch(self, epoch, data_loader):
        self.model.train()
        self.optimizer.zero_grad()
        self.train_metrics.reset()
        for idx, (data, target) in enumerate(data_loader):
            data = Variable(data.cuda())
            target = [ann.to(self.device) for ann in target]
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.writer.set_step((epoch - 1) * len(data_loader) + idx)
            self.train_metrics.update('loss', loss.item())
            if (idx+1) % self.grad_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clipping)
                self.optimizer.step()
                self.optimizer.zero_grad()
            if (idx+1) % int(np.sqrt(len(data_loader))) == 0:
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        return self.train_metrics.result()

    def _valid_epoch(self, epoch, data_loader):
        self.valid_metrics.reset()
        self.model.eval()
        with torch.no_grad():
            for idx, (data, target) in enumerate(data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step((epoch - 1) * len(data_loader) + idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        
        return self.valid_metrics.result()

    def post_processing(self, score, epoch):
        best = False
        if score > self.best_score:
            self.best_score = score 
            self.best_epoch = epoch 
            best = True
            print("best model: {} epoch - {:.5}".format(epoch, score))
        self._save_checkpoint(epoch = epoch, save_best = best)
        
        if self.scheduler.__class__.__name__ == 'ReduceLROnPlateau':
            self.scheduler.step(score)
        else:
            self.scheduler.step()
    
    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints
        :param epoch: current epoch number
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.get_state_dict(self.model),
            'best_score': self.best_score
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint_epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        print("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            print("Saving current best: model_best.pth ...")

    @staticmethod
    def get_state_dict(model):
        if type(model) == torch.nn.DataParallel:
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        return state_dict
    
    def _resume_checkpoint(self, resume_path):
        resume_path = str(resume_path)
        print("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
        self.start_epoch = checkpoint['epoch'] + 1
        self.best_epoch = checkpoint['epoch']
        self.best_score = checkpoint['best_score']
        self.model.load_state_dict(checkpoint['state_dict'])

        print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
    
    @staticmethod
    def _prepare_device(device):
        n_gpu_use = len(device)
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            print("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            print("Warning: The number of GPU\'s configured to use is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        list_ids = device
        device = torch.device('cuda:{}'.format(device[0]) if n_gpu_use > 0 else 'cpu')
        
        return device, list_ids
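
Learning._train_epoch above accumulates gradients: backward() runs on every batch, but gradient clipping and optimizer.step()/zero_grad() happen only every grad_accumulation_steps batches. A minimal standalone sketch of that pattern; unlike the class above, it also divides the loss by the accumulation factor so the effective gradient keeps the same scale:

import torch

def train_with_accumulation(model, loader, criterion, optimizer,
                            accumulation_steps=4, max_norm=1.0):
    model.train()
    optimizer.zero_grad()
    for idx, (data, target) in enumerate(loader):
        loss = criterion(model(data), target) / accumulation_steps  # keep gradient scale comparable
        loss.backward()                                             # gradients accumulate across batches
        if (idx + 1) % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm)
            optimizer.step()
            optimizer.zero_grad()
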
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self, model, criterion, metric_ftns, optimizer, config, data_loader,
                 valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler

        if self.config['log_step'] is not None:
            self.log_step = self.config['log_step']
        else:
            self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()

        for batch_idx, data in enumerate(self.data_loader):
            data = overlap_objects_from_batch(data, self.config['n_objects'])
            target = data  # the model reconstructs/predicts its own input, so the target is the data itself
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data, epoch)
            loss, loss_particles = self.criterion(output, target,
                                                  epoch_iter=(epoch, (epoch + 1)*batch_idx), lambd=self.config["trainer"]["lambd"])
            loss = loss.mean()

            # Note: from space implementation
            # optimizer_fg.zero_grad()
            # optimizer_bg.zero_grad()
            # loss.backward()
            # if cfg.train.clip_norm:
            #     clip_grad_norm_(model.parameters(), cfg.train.clip_norm)

            loss.backward()
            # torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1000)
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                loss_particles_str = " ".join([key + ': {:.6f}, '.format(loss_particles[key].item()) for key in loss_particles])

                self.logger.debug('Train Epoch: {} {} '.format(epoch, self._progress(batch_idx)) + loss_particles_str + 'Loss: {:.6f}'.format(
                    loss.item()))

                self._show(data, output)

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_'+k : v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step(loss)
            # self.lr_scheduler.step() #Note: If it doesn't require argument.
            self.writer.add_scalar('LR', self.optimizer.param_groups[0]['lr'])
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, data in enumerate(self.valid_data_loader):
                data = overlap_objects_from_batch(data, self.config['n_objects'])
                target = data  # the model reconstructs/predicts its own input, so the target is the data itself
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data, epoch=epoch)
                loss, loss_particles = self.criterion(output, target,
                                                      epoch_iter=(epoch, (epoch + 1)*batch_idx),
                                                      lambd=self.config["trainer"]["lambd"])

                self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))

                self._show(data, output, train=False)

        # add histogram of model parameters to the tensorboard
        # for name, p in self.model.named_parameters():
        #     self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _show(self, data, output, train=True):
        g_plot = plot_representation(output[2][:,:output[0].shape[1]].cpu())
        g_plot_pred = plot_representation(output[2][:,output[0].shape[1]:].cpu())
        A_plot = plot_matrix(output[4])
        if output[5] is not None:
            B_plot = plot_matrix(output[5])
            self.writer.add_image('B', make_grid(B_plot, nrow=1, normalize=False))
        if output[6] is not None:
            u_plot = plot_representation(output[6][:, :output[6].shape[1]].cpu())
            self.writer.add_image('u', make_grid(to_tensor(u_plot), nrow=1, normalize=False))
        # if output[10] is not None:  # TODO: re-enable this later
        #     # print(output[10][0].max(), output[-1][0].min())
        #     shape = output[10][0].shape
        #     self.writer.add_image('objects', make_grid(output[10][0].permute(1, 2, 0, 3, 4).reshape(*shape[1:-2], -1, shape[-1]).cpu(), nrow=output[0].shape[1], normalize=True))
        self.writer.add_image('A', make_grid(A_plot, nrow=1, normalize=False))
        self.writer.add_image('g_repr', make_grid(to_tensor(g_plot), nrow=1, normalize=False))
        self.writer.add_image('g_repr_pred', make_grid(to_tensor(g_plot_pred), nrow=1, normalize=False))
        self.writer.add_image('input', make_grid(data[0].cpu(), nrow=data.shape[1], normalize=True))
        self.writer.add_image('output_0rec', make_grid(output[0][0].cpu(), nrow=output[0].shape[1], normalize=True))
        self.writer.add_image('output_1pred', make_grid(output[1][0].cpu(), nrow=output[1].shape[1], normalize=True))
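
The _show method above sends matplotlib-style plots (plot_representation, plot_matrix) to TensorBoard through a to_tensor helper. Those helpers are project-specific; a minimal sketch of one way to turn a matplotlib figure into the CHW image tensor that writer.add_image expects (the function name and PNG round-trip are assumptions, not the project's actual to_tensor):

import io
import numpy as np
import torch
from PIL import Image

def figure_to_tensor(fig):
    """Render a matplotlib figure into a float CHW tensor in [0, 1]."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    img = np.array(Image.open(buf).convert('RGB'))            # HWC uint8
    return torch.from_numpy(img).permute(2, 0, 1).float() / 255.0
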
Example #26
0
class MEmoRTrainer(BaseTrainer):
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = config.init_obj('lr_scheduler',
                                            torch.optim.lr_scheduler,
                                            self.optimizer)
        self.log_step = 200
        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, data in enumerate(self.data_loader):
            target, U_v, U_a, U_t, U_p, M_v, M_a, M_t, target_loc, umask, seg_len, n_c = [
                d.to(self.device) for d in data
            ]

            self.optimizer.zero_grad()
            seq_lengths = [(umask[j] == 1).nonzero().tolist()[-1][0] + 1
                           for j in range(len(umask))]

            output = self.model(U_v, U_a, U_t, U_p, M_v, M_a, M_t, seq_lengths,
                                target_loc, seg_len, n_c)
            assert output.shape[0] == target.shape[0]
            target = target.squeeze(1)
            loss = self.criterion(output, target)
            loss.backward()

            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())

            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug(
                    'Train Epoch: {} {} Loss: {:.6f} Time:{}'.format(
                        epoch, self._progress(batch_idx), loss.item(),
                        datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S')))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        outputs, targets = [], []
        with torch.no_grad():
            for batch_idx, data in enumerate(self.valid_data_loader):
                target, U_v, U_a, U_t, U_p, M_v, M_a, M_t, target_loc, umask, seg_len, n_c = [
                    d.to(self.device) for d in data
                ]
                seq_lengths = [(umask[j] == 1).nonzero().tolist()[-1][0] + 1
                               for j in range(len(umask))]

                output = self.model(U_v, U_a, U_t, U_p, M_v, M_a, M_t,
                                    seq_lengths, target_loc, seg_len, n_c)
                target = target.squeeze(1)
                loss = self.criterion(output, target)

                outputs.append(output.detach())
                targets.append(target.detach())

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())

            outputs = torch.cat(outputs, dim=0)
            targets = torch.cat(targets, dim=0)
            for met in self.metric_ftns:
                self.valid_metrics.update(met.__name__, met(outputs, targets))

        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
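
Unlike the other trainers here, MEmoRTrainer concatenates all validation outputs and targets and computes its metrics once per epoch rather than averaging per-batch values, which matters for metrics such as accuracy or F1 on imbalanced data, where the mean of batch scores is not the epoch score. A minimal sketch of the same aggregation (metric_fn stands in for any of the metric_ftns above):

import torch

@torch.no_grad()
def evaluate_epoch(model, loader, metric_fn, device='cpu'):
    model.eval()
    outputs, targets = [], []
    for data, target in loader:
        outputs.append(model(data.to(device)).cpu())
        targets.append(target.cpu())
    outputs = torch.cat(outputs, dim=0)   # score the whole epoch at once,
    targets = torch.cat(targets, dim=0)   # not an average of per-batch scores
    return metric_fn(outputs, targets)
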
class ResNetTrainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 experiment=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config,
                         experiment)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

        self.best_val_accuracy = 0

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        train_confusion_matrix = torch.zeros(3, 3, dtype=torch.long)
        print('train epoch: ', epoch)
        for batch_idx, (data, label, target_class,
                        idx) in enumerate(self.data_loader):
            print('train batch, item: ', batch_idx, ', ', idx)
            data, target_class = data.to(self.device), target_class.to(
                self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target_class)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__,
                                          met(output, target_class))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))

                self._visualize_input(data.cpu())

            p_cls = torch.argmax(output, dim=1)
            for i, t_cl in enumerate(target_class):
                train_confusion_matrix[p_cls[i], t_cl] += 1

            if batch_idx == self.len_epoch:
                break

        print('train confusion matrix:')
        print(train_confusion_matrix)
        self._visualize_prediction(train_confusion_matrix)
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            val_confusion_matrix = torch.zeros(3, 3, dtype=torch.long)
            print('val epoch: ', epoch)
            for batch_idx, (data, label, target_class,
                            idx) in enumerate(self.valid_data_loader):
                print('val batch, item: ', batch_idx, ', ', idx)
                data, target_class = data.to(self.device), target_class.to(
                    self.device)

                output = self.model(data)
                loss = self.criterion(output, target_class)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(output, target_class))

                self._visualize_input(data.cpu())
                #prediction = torch.argmax(output)
                #self.logger.debug('val class prediction, actual: {}, {}'.format(prediction, target_class))

                p_cls = torch.argmax(output, dim=1)
                for i, t_cl in enumerate(target_class):
                    val_confusion_matrix[p_cls[i], t_cl] += 1

            print('val confusion matrix:')
            print(val_confusion_matrix)
            self._visualize_prediction(val_confusion_matrix)

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        val_log = self.valid_metrics.result()

        # TODO: Super hacky way to display best val dice score. Better way possible?
        self.writer.set_step(
            (epoch - 1) * len(self.valid_data_loader) + batch_idx,
            'best_valid')
        val_scores = {k: v for k, v in val_log.items()}
        current_val_accuracy = val_scores['accuracy']

        if current_val_accuracy > self.best_val_accuracy:
            self.best_val_accuracy = current_val_accuracy
            self.valid_metrics.update('accuracy', self.best_val_accuracy)

        return val_log

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _visualize_input(self, input):
        """format and display input data on tensorboard"""
        self.writer.add_image(
            'input', make_grid(input[0, 0, :, :], nrow=8, normalize=True))

    def _visualize_prediction(self, matrix):
        """format and display output and target data on tensorboard"""
        out = draw_confusion_matrix(matrix)
        self.writer.add_image('output', make_grid(out, nrow=8,
                                                  normalize=False))
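
ResNetTrainer above fills its 3x3 confusion matrix with a Python loop over the batch, indexing it as [predicted, actual]. The same update can be done in one vectorised call; a minimal sketch that keeps that orientation (tensors are assumed to live on the same device):

import torch

def update_confusion_matrix(cm, output, target):
    """Add one batch to a [pred, target] confusion matrix (vectorised sketch)."""
    num_classes = cm.size(0)
    preds = output.argmax(dim=1)
    idx = preds * num_classes + target    # flatten (pred, target) index pairs
    cm += torch.bincount(idx, minlength=num_classes ** 2).reshape(num_classes, num_classes)
    return cm

cm = torch.zeros(3, 3, dtype=torch.long)  # accumulated across batches, as in the trainer above
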
Example #28
0
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        start_epoch = time.time()
        self.model.train()
        self.train_metrics.reset()
        # print("Learning rate:", self.lr_scheduler.get_lr())
        for batch_idx, (inputs, labels) in enumerate(self.data_loader):
            # debugging
            # print('Classes: ', torch.unique(labels))
            face, context = inputs['face'].to(
                self.device), inputs['context'].to(self.device)
            labels = labels.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(face, context)
            loss = self.criterion(output, labels)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, labels))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
                self.writer.add_image(
                    'face', make_grid(face.cpu(), nrow=4, normalize=True))
                self.writer.add_image(
                    'context', make_grid(context.cpu(), nrow=2,
                                         normalize=True))

                for name, p in self.model.named_parameters():
                    if p.requires_grad and p.grad is not None:
                        self.writer.add_histogram('grad_' + name,
                                                  p.grad,
                                                  bins='auto')

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        time_elapsed = time.time() - start_epoch
        self.logger.info('Epoch completed in {:.0f}m {:.0f}s'.format(
            time_elapsed // 60, time_elapsed % 60))

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (inputs,
                            labels) in enumerate(self.valid_data_loader):
                face, context = inputs['face'].to(
                    self.device), inputs['context'].to(self.device)
                labels = labels.to(self.device)

                output = self.model(face, context)
                loss = self.criterion(output, labels)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(output, labels))
                self.writer.add_image(
                    'face', make_grid(face.cpu(), nrow=4, normalize=True))
                self.writer.add_image(
                    'context', make_grid(context.cpu(), nrow=2,
                                         normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
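The trainer above switches to iteration-based training by wrapping its data loader with inf_loop, which is not defined in these snippets. A minimal sketch of what such a helper usually looks like (an assumption modeled on the common pytorch-template convention, not the repository's actual code):

from itertools import repeat


def inf_loop(data_loader):
    # yield batches forever, restarting the loader after each full pass
    for loader in repeat(data_loader):
        yield from loader

With this wrapper the loop in _train_epoch never exhausts the loader, so the explicit break at batch_idx == self.len_epoch is what ends the epoch.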
Beispiel #29
class Trainer(BaseTrainer):
    """
    Trainer class
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_ftns,
                 optimizer,
                 config,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                    epoch, self._progress(batch_idx), loss.item()))
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()
        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(output, target))
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _progress(self, batch_idx):
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
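Every trainer in this collection accumulates its scalars through a MetricTracker whose implementation is not shown. A minimal sketch of a compatible tracker, assuming the pandas-backed design common in pytorch-template-style projects (the real class in the source also exposes extras such as histogram helpers):

import pandas as pd


class MetricTracker:
    # Running averages keyed by metric name, optionally mirrored to a TensorBoard writer.
    def __init__(self, *keys, writer=None):
        self.writer = writer
        self._data = pd.DataFrame(index=keys,
                                  columns=['total', 'counts', 'average'])
        self.reset()

    def reset(self):
        for col in self._data.columns:
            self._data[col].values[:] = 0

    def update(self, key, value, n=1):
        if self.writer is not None:
            self.writer.add_scalar(key, value)
        self._data.total[key] += value * n
        self._data.counts[key] += n
        self._data.average[key] = self._data.total[key] / self._data.counts[key]

    def avg(self, key):
        return self._data.average[key]

    def result(self):
        return dict(self._data.average)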
Beispiel #30
class LayerwiseTrainer(BaseTrainer):
    """
    Trainer
    """
    def __init__(self,
                 model: DepthwiseStudent,
                 criterions,
                 metric_ftns,
                 optimizer,
                 config,
                 train_data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 weight_scheduler=None):
        super().__init__(model, None, metric_ftns, optimizer, config)
        self.config = config
        self.train_data_loader = train_data_loader
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_validation_interval = self.config['trainer'][
            'do_validation_interval']
        self.lr_scheduler = lr_scheduler
        self.weight_scheduler = weight_scheduler
        self.log_step = config['trainer']['log_step']
        if "len_epoch" in self.config['trainer']:
            # iteration-based training
            self.train_data_loader = inf_loop(train_data_loader)
            self.len_epoch = self.config['trainer']['len_epoch']
        else:
            # epoch-based training
            self.len_epoch = len(self.train_data_loader)

        # Metrics
        # Train
        self.train_metrics = MetricTracker(
            'loss',
            'supervised_loss',
            'kd_loss',
            'hint_loss',
            'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.train_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        self.train_teacher_iou_metrics = CityscapesMetricTracker(
            writer=self.writer)
        # Valid
        self.valid_metrics = MetricTracker(
            'loss',
            'supervised_loss',
            'kd_loss',
            'hint_loss',
            'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        self.valid_iou_metrics = CityscapesMetricTracker(writer=self.writer)
        # Test
        self.test_metrics = MetricTracker(
            'loss',
            'supervised_loss',
            'kd_loss',
            'hint_loss',
            'teacher_loss',
            *[m.__name__ for m in self.metric_ftns],
            *['teacher_' + m.__name__ for m in self.metric_ftns],
            writer=self.writer,
        )
        self.test_iou_metrics = CityscapesMetricTracker(writer=self.writer)

        # Tracker for early stopping if val mIoU doesn't increase
        self.val_iou_tracker = EarlyStopTracker('best', 'max', 0.01, 'rel')

        # Use the list of criterions and remove the unused single-criterion attribute
        self.criterions = criterions
        self.criterions = nn.ModuleList(self.criterions).to(self.device)
        if isinstance(self.model, nn.DataParallel):
            self.criterions = nn.DataParallel(self.criterions)
        del self.criterion

        # Resume checkpoint if path is available in config
        if 'resume_path' in self.config['trainer']:
            self.resume(self.config['trainer']['resume_path'])

    def prepare_train_epoch(self, epoch, config=None):
        """
        Prepare before training an epoch, i.e. prune new layers, unfreeze some layers, create a new optimizer, ...
        :param epoch: int - the epoch the trainer is in
        :param config: a config object that contains pruning_plan, hint and unfreeze information
        :return:
        """
        # if config is not given (normal training), use the trainer's current config;
        # if config is given (e.g. when resuming a checkpoint), use the saved config to replace
        #    layers in the student so that its architecture is identical to the saved checkpoint
        if config is None:
            config = self.config
        # reset schedulers, metrics and trackers
        self.reset_scheduler()
        # if no layer is going to be replaced, set as hint or unfrozen, then unfreeze
        # the whole network
        if (epoch == 1) and ((len(config['pruning']['pruning_plan']) +
                              len(config['pruning']['hint']) +
                              len(config['pruning']['unfreeze'])) == 0):
            self.logger.debug(
                'Training a student with an architecture identical to the teacher')
            # unfreeze
            for param in self.model.student.parameters():
                param.requires_grad = True
            # debug
            self.logger.info(self.model.dump_trainable_params())
            # create optimizer for the network
            self.create_new_optimizer()
            # skip the rest of the preparation
            return

        # Check whether any layer needs an update in the current epoch:
        # list of epochs in which the student network gets an update
        epochs = list(
            map(
                lambda x: x['epoch'], config['pruning']['pruning_plan'] +
                config['pruning']['hint'] + config['pruning']['unfreeze']))
        # if there isn't any update then simply return
        if epoch not in epochs:
            self.logger.info('EPOCH: ' + str(epoch))
            self.logger.info('There is no update ...')
            return

        # layers that would be replaced by depthwise separable conv
        replaced_layers = list(
            filter(lambda x: x['epoch'] == epoch,
                   config['pruning']['pruning_plan']))
        # layers whose outputs will be used for the hint loss
        hint_layers = list(
            map(
                lambda x: x['name'],
                filter(lambda x: x['epoch'] == epoch,
                       config['pruning']['hint'])))
        # layers that would be trained in this epoch
        unfreeze_layers = list(
            map(
                lambda x: x['name'],
                filter(lambda x: x['epoch'] == epoch,
                       config['pruning']['unfreeze'])))
        self.logger.info('EPOCH: ' + str(epoch))
        self.logger.info('Replaced layers: ' + str(replaced_layers))
        self.logger.info('Hint layers: ' + str(hint_layers))
        self.logger.info('Unfreeze layers: ' + str(unfreeze_layers))
        # Avoid errors when loading deprecated checkpoints that don't have 'args' in config['pruning']
        if 'args' in config['pruning']:
            kwargs = config['pruning']['args']
        else:
            self.logger.warning('Using deprecated checkpoint...')
            kwargs = config['pruning']['pruner']

        self.model.replace(
            replaced_layers,
            **kwargs)  # replace those layers with depthwise separable conv
        self.model.register_hint_layers(
            hint_layers
        )  # register layers whose outputs will be used for the hint loss
        self.model.unfreeze(unfreeze_layers)  # unfreeze chosen layers

        if epoch == 1:
            self.create_new_optimizer(
            )  # create new optimizer to remove the effect of momentum
        else:
            self.update_optimizer(
                list(
                    filter(lambda x: x['epoch'] == epoch,
                           config['pruning']['unfreeze'])))

        self.logger.info(self.model.dump_trainable_params())
        self.logger.info(self.model.dump_student_teacher_blocks_info())
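    # Illustration (added for clarity, not from the original source): prepare_train_epoch
    # reads config['pruning'], whose layout is only implied by the accesses above. A
    # hypothetical fragment matching those accesses could look like:
    #
    #   pruning:
    #     args: {...}                                # kwargs forwarded to model.replace(...)
    #     pruning_plan:
    #       - {name: layer1.0.conv1, epoch: 2}       # replace with a depthwise separable conv
    #     hint:
    #       - {name: layer2, epoch: 2}               # output compared against the teacher's
    #     unfreeze:
    #       - {name: layer1, epoch: 2, lr: 0.001}    # trained (with optional layer-wise lr) from this epoch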

    def update_optimizer(self, unfreeze_config):
        """
        Update the optimizer's param groups with the layers unfrozen in this epoch
        :param unfreeze_config: list of dicts, each with the following format:
            {'name': 'layer1', 'epoch': 1, 'lr' (optional): 0.01}
        :return:
        """
        if len(unfreeze_config) > 0:
            self.logger.debug('Updating optimizer for new layer')
        for config in unfreeze_config:
            layer_name = config['name']  # layer that will be unfrozen
            self.logger.debug(
                'Add parameters of layer: {} to optimizer'.format(layer_name))

            layer = self.model.get_block(
                layer_name,
                self.model.student)  # actual layer i.e. nn.Module obj
            # copy the default optimizer args so a per-layer 'lr' doesn't leak back into the config
            optimizer_arg = dict(self.config['optimizer']['args'])

            # a layer-wise learning rate can also be specified
            if "lr" in config:
                optimizer_arg['lr'] = config['lr']
            # add the unfrozen layer's parameters to the optimizer
            self.optimizer.add_param_group({
                'params': layer.parameters(),
                **optimizer_arg
            })
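        # Note (illustration, not part of the original code): add_param_group appends a new
        # entry to optimizer.param_groups, so a newly unfrozen layer can carry its own
        # hyperparameters while earlier groups keep theirs, e.g.:
        #   optimizer = torch.optim.SGD(base_params, lr=0.01)        # base_params: hypothetical
        #   optimizer.add_param_group({'params': new_layer.parameters(), 'lr': 0.001})
        #   assert optimizer.param_groups[1]['lr'] == 0.001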

    def create_new_optimizer(self):
        """
        Create a new optimizer (and lr scheduler) over the student parameters that require gradients
        """
        # Create new optimizer
        self.logger.debug('Creating new optimizer ...')
        self.optimizer = self.config.init_obj(
            'optimizer', optim_module,
            list(
                filter(lambda x: x.requires_grad,
                       self.model.student.parameters())))
        self.lr_scheduler = self.config.init_obj('lr_scheduler',
                                                 optim_module.lr_scheduler,
                                                 self.optimizer)

    def reset_scheduler(self):
        """
        Reset all schedulers, metrics, trackers, etc. when a new layer is unfrozen
        :return:
        """
        self.weight_scheduler.reset()  # weights between losses
        self.val_iou_tracker.reset()  # tracks that val IoU keeps increasing
        self.train_metrics.reset()  # loss metrics for the training phase
        self.valid_metrics.reset()  # loss metrics for the validation phase
        self.train_iou_metrics.reset()  # train IoU of the student
        self.valid_iou_metrics.reset()  # val IoU of the student
        self.train_teacher_iou_metrics.reset()  # train IoU of the teacher
        if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
            self.lr_scheduler.reset()

    def _train_epoch(self, epoch):
        """
        Training logic for 1 epoch
        """
        # Prepare the network, i.e. unfreeze new layers, replace layers with depthwise separable convs, ...
        self.prepare_train_epoch(epoch)

        # reset
        # FIXME:
        # as the teacher network contains batchnorm layers and our resources limit us to a
        # small batch size, we ALWAYS keep batchnorm in training mode to avoid the instability
        # that comes with small batches
        # self.model.train()
        self.train_iou_metrics.reset()
        self.train_teacher_iou_metrics.reset()
        self._clean_cache()

        for batch_idx, (data, target) in enumerate(self.train_data_loader):
            data, target = data.to(self.device), target.to(self.device)

            output_st, output_tc = self.model(data)

            supervised_loss = self.criterions[0](
                output_st, target) / self.accumulation_steps
            kd_loss = self.criterions[1](output_st,
                                         output_tc) / self.accumulation_steps
            teacher_loss = self.criterions[0](output_tc,
                                              target)  # for comparison

            hint_loss = reduce(
                lambda acc, elem: acc + self.criterions[2](elem[0], elem[1]),
                zip(self.model.student_hidden_outputs,
                    self.model.teacher_hidden_outputs),
                0) / self.accumulation_steps

            # only the hint loss is back-propagated; the other losses are computed for logging
            loss = hint_loss
            loss.backward()

            if batch_idx % self.accumulation_steps == 0:
                self.optimizer.step()
                self.optimizer.zero_grad()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)

            # update metrics
            self.train_metrics.update('loss',
                                      loss.item() * self.accumulation_steps)
            self.train_metrics.update(
                'supervised_loss',
                supervised_loss.item() * self.accumulation_steps)
            self.train_metrics.update('kd_loss',
                                      kd_loss.item() * self.accumulation_steps)
            self.train_metrics.update(
                'hint_loss',
                hint_loss.item() * self.accumulation_steps)
            self.train_metrics.update('teacher_loss', teacher_loss.item())
            self.train_iou_metrics.update(output_st.detach().cpu(),
                                          target.cpu())
            self.train_teacher_iou_metrics.update(output_tc.cpu(),
                                                  target.cpu())

            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output_st, target))

            if batch_idx % self.log_step == 0:
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                # st_masks = visualize.viz_pred_cityscapes(output_st)
                # tc_masks = visualize.viz_pred_cityscapes(output_tc)
                # self.writer.add_image('st_pred', make_grid(st_masks, nrow=8, normalize=False))
                # self.writer.add_image('tc_pred', make_grid(tc_masks, nrow=8, normalize=False))
                self.logger.info(
                    'Train Epoch: {} [{}]/[{}] Loss: {:.6f} mIoU: {:.6f} Teacher mIoU: {:.6f} Supervised Loss: {:.6f} '
                    'Knowledge Distillation loss: '
                    '{:.6f} Hint Loss: {:.6f} Teacher Loss: {:.6f}'.format(
                        epoch,
                        batch_idx,
                        self.len_epoch,
                        self.train_metrics.avg('loss'),
                        self.train_iou_metrics.get_iou(),
                        self.train_teacher_iou_metrics.get_iou(),
                        self.train_metrics.avg('supervised_loss'),
                        self.train_metrics.avg('kd_loss'),
                        self.train_metrics.avg('hint_loss'),
                        self.train_metrics.avg('teacher_loss'),
                    ))

            if batch_idx == self.len_epoch:
                break

        log = self.train_metrics.result()
        log.update(
            {'train_teacher_mIoU': self.train_teacher_iou_metrics.get_iou()})
        log.update({'train_student_mIoU': self.train_iou_metrics.get_iou()})

        if self.do_validation and (
            (epoch % self.config["trainer"]["do_validation_interval"]) == 0):
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})
            log.update(**{'val_mIoU': self.valid_iou_metrics.get_iou()})
            self.val_iou_tracker.update(self.valid_iou_metrics.get_iou())

        self._teacher_student_iou_gap = self.train_teacher_iou_metrics.get_iou(
        ) - self.train_iou_metrics.get_iou()

        # step lr scheduler
        if (self.lr_scheduler is not None) and (not isinstance(
                self.lr_scheduler, MyOneCycleLR)):
            if isinstance(self.lr_scheduler, MyReduceLROnPlateau):
                self.lr_scheduler.step(self.train_metrics.avg('loss'))
            else:
                self.lr_scheduler.step()
                self.logger.debug('stepped lr')
                for param_group in self.optimizer.param_groups:
                    self.logger.debug(param_group['lr'])

        # anneal weight between losses
        self.weight_scheduler.step()

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation
        """
        self._clean_cache()
        # FIXME:
        # as the teacher network contains batchnorm layers and our resources limit us to a
        # small batch size, we ALWAYS keep batchnorm in training mode to avoid the instability
        # that comes with small batches
        # self.model.eval()
        self.model.save_hidden = False  # stop saving hidden outputs
        self.valid_metrics.reset()
        self.valid_iou_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model.inference(data)
                supervised_loss = self.criterions[0](output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('supervised_loss',
                                          supervised_loss.item())
                self.valid_iou_metrics.update(output.detach().cpu(), target)
                self.logger.debug(
                    str(batch_idx) + " : " +
                    str(self.valid_iou_metrics.get_iou()))

                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__,
                                              met(output, target))
        result = self.valid_metrics.result()
        result['mIoU'] = self.valid_iou_metrics.get_iou()

        return result

    def _test_epoch(self, epoch):
        # cleaning up memory
        self._clean_cache()
        # self.model.eval()
        self.model.save_hidden = False
        self.model.cpu()
        self.model.student.to(self.device)

        # prepare before running submission
        self.test_metrics.reset()
        self.test_iou_metrics.reset()
        args = self.config['test']['args']
        save_4_sm = self.config['submission']['save_output']
        path_output = self.config['submission']['path_output']
        if save_4_sm and not os.path.exists(path_output):
            os.mkdir(path_output)
        n_samples = len(self.valid_data_loader)

        with torch.no_grad():
            for batch_idx, (img_name, data,
                            target) in enumerate(self.valid_data_loader):
                self.logger.info('{}/{}'.format(batch_idx, n_samples))
                data, target = data.to(self.device), target.to(self.device)
                output = self.model.inference_test(data, args)
                if save_4_sm:
                    self.save_for_submission(output, img_name[0])
                supervised_loss = self.criterions[0](output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'test')
                self.test_metrics.update('supervised_loss',
                                         supervised_loss.item())
                self.test_iou_metrics.update(output.detach().cpu(), target)

                for met in self.metric_ftns:
                    self.test_metrics.update(met.__name__, met(output, target))

        result = self.test_metrics.result()
        result['mIoU'] = self.test_iou_metrics.get_iou()

        return result

    def save_for_submission(self, output, image_name, img_type=np.uint8):
        args = self.config['submission']
        path_output = args['path_output']
        image_save = '{}.{}'.format(image_name, args['ext'])
        path_save = os.path.join(path_output, image_save)
        result = torch.argmax(output, dim=1)
        result_mapped = self.re_map_for_submission(result)
        if output.size()[0] == 1:
            result_mapped = result_mapped[0]

        save_image(result_mapped.cpu().numpy().astype(img_type), path_save)
        print('Saved output of test data: {}'.format(image_save))

    def re_map_for_submission(self, output):
        mapping = self.valid_data_loader.dataset.id_to_trainid
        cp_output = torch.zeros(output.size())
        for k, v in mapping.items():
            cp_output[output == v] = k

        return cp_output

    def _clean_cache(self):
        self.model.student_hidden_outputs, self.model.teacher_hidden_outputs = list(
        ), list()
        gc.collect()
        torch.cuda.empty_cache()

    def resume(self, checkpoint_path):
        self.logger.info("Loading checkpoint: {} ...".format(checkpoint_path))
        checkpoint = torch.load(checkpoint_path,
                                map_location=torch.device('cpu'))
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        config = checkpoint['config']  # config of checkpoint
        epoch = checkpoint['epoch']  # stopped epoch

        # load model state from checkpoint
        # first, align the student network by re-applying the depthwise separable replacements
        for i in range(1, epoch + 1):
            self.prepare_train_epoch(i, config)
        # load weight
        forgiving_state_restore(self.model, checkpoint['state_dict'])
        self.logger.info("Loaded model's state dict")

        # load optimizer state from checkpoint only when optimizer type is not changed.
        if checkpoint['config']['optimizer']['type'] != self.config[
                'optimizer']['type']:
            self.logger.warning(
                "Warning: Optimizer type given in config file is different from that of checkpoint. "
                "Optimizer parameters not being resumed.")
        else:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.logger.info("Loaded optimizer state dict")