Exemplo n.º 1
0
def train(model, args, device, writer, optimizer, data_loader, epoch):
    """Run one training epoch.

    Args:
        model: network to train (already on ``device``).
        args: namespace providing ``classes``, ``log_interval``, ``batch_size``.
        device: torch device that inputs and targets are moved to.
        writer: summary writer consumed by ``MetricTracker``.
        optimizer: optimizer stepped once per batch.
        data_loader: iterable yielding (input, target) batch pairs.
        epoch: 1-based epoch index used to derive the global writer step.

    Returns:
        Tuple ``(metrics, writer_step)`` — the populated tracker and the
        last global step computed.
    """
    # Set train mode
    model.train()

    criterion = nn.CrossEntropyLoss(reduction='mean')
    metric_ftns = ['loss', 'correct', 'total', 'accuracy', 'sens', 'ppv']
    metrics = MetricTracker(*metric_ftns, writer=writer, mode='train')
    metrics.reset()

    # Running confusion matrix, reset every `log_interval` batches.
    cm = torch.zeros(args.classes, args.classes)

    # Pre-initialise so the return below cannot raise UnboundLocalError
    # when data_loader yields no batches.
    writer_step = (epoch - 1) * len(data_loader)

    for batch_idx, input_tensors in enumerate(data_loader):

        input_data = input_tensors[0].to(device)
        target = input_tensors[1].to(device)

        # Forward
        output = model(input_data)
        loss = criterion(output, target)

        correct, total, acc = accuracy(output, target)
        update_confusion_matrix(cm, output, target)
        metrics.update_all_metrics({
            'correct': correct,
            'total': total,
            'loss': loss.item(),
            'accuracy': acc
        })

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Save TB stats every `log_interval` batches.
        writer_step = (epoch - 1) * len(data_loader) + batch_idx
        if (batch_idx + 1) % args.log_interval == 0:

            # Sensitivity / PPV for this bucket of batches, then reset it.
            ppv, sens = update_confusion_calc(cm)
            metrics.update_all_metrics({'sens': sens, 'ppv': ppv})
            cm = torch.zeros(args.classes, args.classes)

            metrics.write_tb(writer_step)

            num_samples = batch_idx * args.batch_size
            print_stats(args, epoch, num_samples, data_loader, metrics)

    return metrics, writer_step
Exemplo n.º 2
0
def validation(args, model, testloader, epoch, writer):
    """Evaluate ``model`` on ``testloader`` for one epoch (no gradients).

    Args:
        args: namespace providing ``class_dict``, ``cuda``, ``batch_size``.
        model: network to evaluate.
        testloader: iterable yielding (input, target) batch pairs.
        epoch: 1-based epoch index used to derive the writer step.
        writer: summary writer consumed by ``MetricTracker``.

    Returns:
        Tuple ``(val_metrics, confusion_matrix)``.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='mean')

    metric_ftns = [
        'loss', 'correct', 'total', 'accuracy', 'ppv', 'sensitivity'
    ]
    val_metrics = MetricTracker(*metric_ftns, writer=writer, mode='val')
    val_metrics.reset()
    confusion_matrix = torch.zeros(args.class_dict, args.class_dict)

    # Guards so the post-loop reporting cannot raise UnboundLocalError
    # when testloader yields no batches.
    batch_idx = 0
    num_samples = 0

    with torch.no_grad():
        for batch_idx, input_tensors in enumerate(testloader):

            input_data, target = input_tensors
            if args.cuda:
                input_data = input_data.cuda()
                target = target.cuda()

            output = model(input_data)

            loss = criterion(output, target)

            correct, total, acc = accuracy(output, target)
            num_samples = batch_idx * args.batch_size + 1
            _, pred = torch.max(output, 1)

            # Accumulate the epoch-level confusion matrix.
            for t, p in zip(target.cpu().view(-1), pred.cpu().view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            val_metrics.update_all_metrics(
                {
                    'correct': correct,
                    'total': total,
                    'loss': loss.item(),
                    'accuracy': acc
                },
                writer_step=(epoch - 1) * len(testloader) + batch_idx)

    print_summary(args, epoch, num_samples, val_metrics, mode="Validation")
    s = sensitivity(confusion_matrix.numpy())
    ppv = positive_predictive_value(confusion_matrix.numpy())
    print(f" s {s} ,ppv {ppv}")
    # Log epoch-level sensitivity / PPV at the last batch's writer step.
    writer_step = (epoch - 1) * len(testloader) + batch_idx
    val_metrics.update('sensitivity', s, writer_step=writer_step)
    val_metrics.update('ppv', ppv, writer_step=writer_step)
    print('Confusion Matrix\n{}'.format(confusion_matrix.cpu().numpy()))
    return val_metrics, confusion_matrix
Exemplo n.º 3
0
def train(args, model, trainloader, optimizer, epoch):
    """Train for one epoch using a class-weighted loss.

    Args:
        args: namespace providing ``cuda`` and ``batch_size``.
        model: network to train.
        trainloader: iterable yielding (input, target) batch pairs.
        optimizer: optimizer stepped once per batch.
        epoch: 1-based epoch index used to derive the writer step.

    Returns:
        The populated ``train_metrics`` tracker.
    """
    start_time = time.time()
    model.train()

    train_metrics = MetricTracker(*METRICS_TRACKED, mode='train')
    # Per-class loss weights; the third class is up-weighted.
    # NOTE(review): assumes exactly 3 classes — confirm against the dataset.
    w2 = torch.Tensor([1.0, 1.0, 1.5])

    if args.cuda:
        model.cuda()
        w2 = w2.cuda()

    train_metrics.reset()
    # Debug aid: running count returned by weighted_loss per batch.
    counter_covid = 0

    # Guard so the epoch summary cannot raise UnboundLocalError when
    # trainloader yields no batches.
    num_samples = 0

    for batch_idx, input_tensors in enumerate(trainloader):
        optimizer.zero_grad()
        input_data, target = input_tensors

        if args.cuda:
            input_data = input_data.cuda()
            target = target.cuda()

        output = model(input_data)

        loss, counter = weighted_loss(output, target, w2)
        counter_covid += counter
        loss.backward()

        optimizer.step()
        correct, total, acc = accuracy(output, target)
        precision_mean, recall_mean = precision_score(output, target)

        num_samples = batch_idx * args.batch_size + 1
        train_metrics.update_all_metrics(
            {
                'correct': correct,
                'total': total,
                'loss': loss.item(),
                'accuracy': acc,
                'precision_mean': precision_mean,
                'recall_mean': recall_mean
            },
            writer_step=(epoch - 1) * len(trainloader) + batch_idx)
        print_stats(args, epoch, num_samples, trainloader, train_metrics)
    print("--- %s seconds ---" % (time.time() - start_time))
    print_summary(args, epoch, num_samples, train_metrics, mode="Training")
    return train_metrics
Exemplo n.º 4
0
def validation(args, model, testloader, epoch):
    """Run a no-grad evaluation pass over ``testloader``.

    Returns the metric tracker and the accumulated confusion matrix.
    """
    model.eval()

    val_metrics = MetricTracker(*[m for m in METRICS_TRACKED], mode='val')
    val_metrics.reset()

    # Class weights for the weighted loss (third class up-weighted).
    class_weights = torch.Tensor([1.0, 1.0, 1.5])
    if args.cuda:
        class_weights = class_weights.cuda()

    confusion_matrix = torch.zeros(args.classes, args.classes)

    with torch.no_grad():
        for step, batch in enumerate(testloader):
            inputs, labels = batch

            if args.cuda:
                inputs = inputs.cuda()
                labels = labels.cuda()

            logits = model(inputs)

            loss, counter = weighted_loss(logits, labels, class_weights)
            correct, total, acc = accuracy(logits, labels)
            precision_mean, recall_mean = precision_score(logits, labels)

            num_samples = step * args.batch_size + 1
            predictions = torch.max(logits, 1)[1]

            # Tally (true, predicted) label pairs.
            for true_lbl, pred_lbl in zip(labels.cpu().view(-1),
                                          predictions.cpu().view(-1)):
                confusion_matrix[true_lbl.long(), pred_lbl.long()] += 1

            batch_record = {
                'correct': correct,
                'total': total,
                'loss': loss.item(),
                'accuracy': acc,
                'precision_mean': precision_mean,
                'recall_mean': recall_mean
            }
            val_metrics.update_all_metrics(
                batch_record,
                writer_step=(epoch - 1) * len(testloader) + step)

    print_summary(args, epoch, num_samples, val_metrics, mode="Validation")
    print('Confusion Matrix\n {}'.format(confusion_matrix.cpu().numpy()))

    return val_metrics, confusion_matrix
Exemplo n.º 5
0
def train(args, model, trainloader, optimizer, epoch, writer, log):
    """Train for one epoch with mean-reduced cross-entropy loss.

    Args:
        args: namespace providing ``class_dict``, ``cuda``, ``batch_size``.
        model: network to train.
        trainloader: iterable yielding (input, target) batch pairs.
        optimizer: optimizer stepped once per batch.
        epoch: 1-based epoch index used to derive the writer step.
        writer: summary writer consumed by ``MetricTracker``.
        log: accepted for interface compatibility; unused in this body.

    Returns:
        The populated ``train_metrics`` tracker.
    """
    model.train()
    criterion = nn.CrossEntropyLoss(reduction='mean')

    metric_ftns = [
        'loss', 'correct', 'total', 'accuracy', 'ppv', 'sensitivity'
    ]
    train_metrics = MetricTracker(*metric_ftns, writer=writer, mode='train')
    train_metrics.reset()
    confusion_matrix = torch.zeros(args.class_dict, args.class_dict)

    # Guard so the summary below cannot raise UnboundLocalError when
    # trainloader yields no batches.
    num_samples = 0

    for batch_idx, input_tensors in enumerate(trainloader):
        optimizer.zero_grad()
        input_data, target = input_tensors
        if args.cuda:
            input_data = input_data.cuda()
            target = target.cuda()

        output = model(input_data)

        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        correct, total, acc = accuracy(output, target)
        pred = torch.argmax(output, dim=1)

        num_samples = batch_idx * args.batch_size + 1
        train_metrics.update_all_metrics(
            {
                'correct': correct,
                'total': total,
                'loss': loss.item(),
                'accuracy': acc
            },
            writer_step=(epoch - 1) * len(trainloader) + batch_idx)
        print_stats(args, epoch, num_samples, trainloader, train_metrics)
        # Accumulate the epoch-level confusion matrix.
        for t, p in zip(target.cpu().view(-1), pred.cpu().view(-1)):
            confusion_matrix[t.long(), p.long()] += 1

    # Sensitivity / PPV are printed for inspection only; they are not
    # pushed into the tracker (the 'sensitivity'/'ppv' slots stay empty).
    s = sensitivity(confusion_matrix.numpy())
    ppv = positive_predictive_value(confusion_matrix.numpy())
    print(f" s {s} ,ppv {ppv}")
    print_summary(args, epoch, num_samples, train_metrics, mode="Training")
    return train_metrics
Exemplo n.º 6
0
def val(args, model, data_loader, epoch, writer, device):
    """Evaluate ``model`` for one epoch without computing gradients.

    Returns the metric tracker and the accumulated confusion matrix.
    """
    model.eval()

    criterion = nn.CrossEntropyLoss(reduction='mean')
    tracked = ['loss', 'correct', 'total', 'accuracy', 'ppv', 'sens']
    metrics = MetricTracker(*[name for name in tracked],
                            writer=writer,
                            mode='val')
    metrics.reset()

    cm = torch.zeros(args.classes, args.classes)

    with torch.no_grad():
        for step, batch in enumerate(data_loader):
            # Release cached GPU memory between batches.
            torch.cuda.empty_cache()
            inputs = batch[0].to(device)
            targets = batch[1].to(device)

            # Forward pass only.
            logits = model(inputs)
            loss = criterion(logits, targets)

            correct, total, acc = accuracy(logits, targets)
            update_confusion_matrix(cm, logits, targets)

            # Record per-batch metrics.
            batch_record = {
                'correct': correct,
                'total': total,
                'loss': loss.item(),
                'accuracy': acc
            }
            metrics.update_all_metrics(batch_record)

        # Epoch-level sensitivity / positive predictive value.
        ppv, sens = update_confusion_calc(cm)
        metrics.update_all_metrics({'sens': sens, 'ppv': ppv})

    return metrics, cm
Exemplo n.º 7
0
def train(args, model, trainloader, optimizer, epoch, writer):
    """Train for one epoch with mean-reduced cross-entropy loss.

    Args:
        args: namespace providing ``cuda`` and ``batch_size``.
        model: network to train.
        trainloader: iterable yielding (input, target) batch pairs.
        optimizer: optimizer stepped once per batch.
        epoch: 1-based epoch index used to derive the writer step.
        writer: summary writer consumed by ``MetricTracker``.

    Returns:
        The populated ``train_metrics`` tracker.
    """
    model.train()
    criterion = nn.CrossEntropyLoss(reduction='mean')

    metric_ftns = ['loss', 'correct', 'total', 'accuracy']
    train_metrics = MetricTracker(*metric_ftns, writer=writer, mode='train')
    train_metrics.reset()

    # Guard so the summary below cannot raise UnboundLocalError when
    # trainloader yields no batches.
    num_samples = 0

    for batch_idx, input_tensors in enumerate(trainloader):
        optimizer.zero_grad()
        input_data, target = input_tensors
        if args.cuda:
            input_data = input_data.cuda()
            target = target.cuda()

        output = model(input_data)

        loss = criterion(output, target)
        loss.backward()

        optimizer.step()
        correct, total, acc = accuracy(output, target)

        num_samples = batch_idx * args.batch_size + 1
        train_metrics.update_all_metrics(
            {
                'correct': correct,
                'total': total,
                'loss': loss.item(),
                'accuracy': acc
            },
            writer_step=(epoch - 1) * len(trainloader) + batch_idx)
        print_stats(args, epoch, num_samples, trainloader, train_metrics)

    print_summary(args, epoch, num_samples, train_metrics, mode="Training")
    return train_metrics
Exemplo n.º 8
0
class Trainer(BaseTrainer):
	"""
	Trainer class: epoch-based (or iteration-based via ``len_epoch``) training
	with optional validation, LR scheduling, TensorBoard logging and CSV export.
	"""
	def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
				 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
		super().__init__(model, criterion, metric_ftns, optimizer, config)
		self.config = config
		self.device = device
		self.data_loader = data_loader
		if len_epoch is None:
			# epoch-based training
			self.len_epoch = len(self.data_loader)
		else:
			# iteration-based training: wrap the loader so it cycles forever
			self.data_loader = inf_loop(data_loader)
			self.len_epoch = len_epoch
		self.valid_data_loader = valid_data_loader
		self.do_validation = self.valid_data_loader is not None
		self.lr_scheduler = lr_scheduler
		# Log roughly every sqrt(batch_size) batches.
		self.log_step = int(np.sqrt(data_loader.batch_size))
		
		# Best validation accuracy seen so far (used for checkpointing).
		self.best_valid = 0
		self.train_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
		self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

		# CSV file where per-epoch metrics are appended by _save_csv.
		self.file_metrics = str(self.checkpoint_dir / 'metrics.csv')
		


	def _train_epoch(self, epoch):
		"""
		Training logic for an epoch
		:param epoch: Integer, current training epoch.
		:return: A log that contains average loss and metric in this epoch.
		"""
		
		self.model.train()
		self.train_metrics.reset()
		for batch_idx, (data, target) in enumerate(self.data_loader):
			data, target = data.to(self.device), target.to(self.device)

			self.optimizer.zero_grad()
			output = self.model(data)
			loss = self.criterion(output, target)
			loss.backward()
			self.optimizer.step()

			self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
			self.train_metrics.update('loss', loss.item())
			for met in self.metric_ftns:
				self.train_metrics.update(met.__name__, met(output, target))

			if batch_idx % self.log_step == 0:
				"""self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
					epoch,
					self._progress(batch_idx),
					loss.item()))
				"""
				self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

			if batch_idx == self.len_epoch:
				break
		log = self.train_metrics.result()

		if self.do_validation:
			val_log = self._valid_epoch(epoch)
			log.update(**{'val_'+k : v for k, v in val_log.items()})

		if self.lr_scheduler is not None:
			self.lr_scheduler.step()

		# Extra forward pass on the LAST batch only, to write its loss as a
		# scalar after the epoch.  NOTE(review): this re-runs the model on
		# stale `data`/`target` from the loop above — confirm this is intended.
		output = self.model(data)
		loss = self.criterion(output, target)	

		self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
		self.writer.add_scalar('Loss',  loss)

		self._save_csv(epoch, log)

		return log

	def _valid_epoch(self, epoch):
		"""
		Validate after training an epoch
		:param epoch: Integer, current training epoch.
		:return: A log that contains information about validation
		"""
		self.model.eval()
		self.valid_metrics.reset()
		with torch.no_grad():
			for batch_idx, (data, target) in enumerate(self.valid_data_loader):
				data, target = data.to(self.device), target.to(self.device)

				output = self.model(data)
				loss = self.criterion(output, target)

				self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
				self.valid_metrics.update('loss', loss.item())
				for met in self.metric_ftns:
					self.valid_metrics.update(met.__name__, met(output, target))
				self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

		# add histogram of model parameters to the tensorboard
		for name, p in self.model.named_parameters():
			self.writer.add_histogram(name, p, bins='auto')

		# Custom addition: re-run the last validation batch and log its loss.
		output = self.model(data)
		loss = self.criterion(output, target)
		self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
		self.writer.add_scalar('Loss',  loss)

		val_log = self.valid_metrics.result()
		actual_accu = val_log['accuracy']
		# Save a new best checkpoint when accuracy improves by > 0.25 points.
		# NOTE(review): self.save and self.tensorboard are not set in this
		# __init__ — presumably provided by BaseTrainer; confirm.
		if(actual_accu - self.best_valid > 0.0025 and self.save):
			self.best_valid = actual_accu
			if self.tensorboard: # is true you can use tensorboard
				self._save_checkpoint(epoch, save_best=True)
			filename = str(self.checkpoint_dir / 'checkpoint-best-epoch.pth')
			torch.save(self.model.state_dict(), filename)
			self.logger.info("Saving checkpoint: {} ...".format(filename))

		return val_log

	def _progress(self, batch_idx):
		# Format a "[current/total (pct%)]" progress string.
		base = '[{}/{} ({:.0f}%)]'
		if hasattr(self.data_loader, 'n_samples'):
			current = batch_idx * self.data_loader.batch_size
			total = self.data_loader.n_samples
		else:
			current = batch_idx
			total = self.len_epoch
		return base.format(current, total, 100.0 * current / total)


	def _save_csv(self, epoch ,log):
		"""
			Append this epoch's metrics to the CSV file.
			A header row is written only on the first epoch.
			:param epoch: current epoch number
			:param log: logging information of the epoch
		"""

		fichier = open(self.file_metrics, "a")

		if epoch == 1:
			fichier.write("epoch,")
			for key in log:
				fichier.write(str(key) +",")
			fichier.write("\n")

		fichier.write(str(epoch) +",")
		for key in log:
			fichier.write(str(log[key]) + ",")
		fichier.write("\n")
		fichier.close()
Exemplo n.º 9
0
class Trainer(BaseTrainer):
    """
    Trainer class: gradient-accumulated training with validation, optional
    test-set prediction export, and loss-based best-model checkpointing.
    """
    def __init__(self,
                 config,
                 model,
                 optimizer,
                 data_loader,
                 writer,
                 checkpoint_dir,
                 logger,
                 class_dict,
                 valid_data_loader=None,
                 test_data_loader=None,
                 lr_scheduler=None,
                 metric_ftns=None):
        super(Trainer, self).__init__(config,
                                      data_loader,
                                      writer,
                                      checkpoint_dir,
                                      logger,
                                      valid_data_loader=valid_data_loader,
                                      test_data_loader=test_data_loader,
                                      metric_ftns=metric_ftns)

        # Use the GPU only when requested AND actually available.
        if (self.config.cuda):
            use_cuda = torch.cuda.is_available()
            self.device = torch.device("cuda" if use_cuda else "cpu")
        else:
            self.device = torch.device("cpu")
        self.start_epoch = 1
        self.train_data_loader = data_loader

        # NOTE(review): this is samples per epoch (batch_size * num_batches),
        # yet _train_epoch multiplies it per-epoch for writer_step — confirm
        # the intended TensorBoard step scale.
        self.len_epoch = self.config.dataloader.train.batch_size * len(
            self.train_data_loader)
        self.epochs = self.config.epochs
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = self.config.log_interval
        self.model = model
        self.num_classes = len(class_dict)
        self.optimizer = optimizer

        # Best monitored value so far: validation loss, lower is better.
        self.mnt_best = np.inf
        # Multi-target datasets use BCE-with-logits, otherwise cross-entropy.
        if self.config.dataset.type == 'multi_target':
            self.criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
        else:
            self.criterion = torch.nn.CrossEntropyLoss(reduction='mean')
        self.checkpoint_dir = checkpoint_dir
        self.gradient_accumulation = config.gradient_accumulation
        self.writer = writer
        self.metric_ftns = ['loss', 'acc']
        self.train_metrics = MetricTracker(*[m for m in self.metric_ftns],
                                           writer=self.writer,
                                           mode='train')
        # NOTE(review): duplicate assignment — identical to the one above.
        self.metric_ftns = ['loss', 'acc']
        self.valid_metrics = MetricTracker(*[m for m in self.metric_ftns],
                                           writer=self.writer,
                                           mode='validation')
        self.logger = logger

        # Shared confusion matrix, zeroed at the start of each epoch/phase.
        self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        Args:
            epoch (int): current training epoch.
        """

        self.model.train()
        # Zero the confusion matrix in place for this epoch.
        self.confusion_matrix = 0 * self.confusion_matrix
        self.train_metrics.reset()
        gradient_accumulation = self.gradient_accumulation
        for batch_idx, (data, target) in enumerate(self.train_data_loader):

            data = data.to(self.device)

            target = target.to(self.device)

            output = self.model(data)
            loss = self.criterion(output, target)
            loss = loss.mean()

            # Scale the loss so gradients average over the accumulation window.
            (loss / gradient_accumulation).backward()
            if (batch_idx % gradient_accumulation == 0):
                self.optimizer.step()  # Now we can do an optimizer step
                self.optimizer.zero_grad()  # Reset gradients tensors

            # torch.max returns (values, indices); indices are the predictions.
            prediction = torch.max(output, 1)

            writer_step = (epoch - 1) * self.len_epoch + batch_idx

            self.train_metrics.update(key='loss',
                                      value=loss.item(),
                                      n=1,
                                      writer_step=writer_step)
            # 'acc' is updated with the raw correct count; n carries batch size.
            self.train_metrics.update(
                key='acc',
                value=np.sum(prediction[1].cpu().numpy() == target.squeeze(
                    -1).cpu().numpy()),
                n=target.size(0),
                writer_step=writer_step)
            for t, p in zip(target.cpu().view(-1),
                            prediction[1].cpu().view(-1)):
                self.confusion_matrix[t.long(), p.long()] += 1
            self._progress(batch_idx,
                           epoch,
                           metrics=self.train_metrics,
                           mode='train')

        # Final epoch summary line.
        self._progress(batch_idx,
                       epoch,
                       metrics=self.train_metrics,
                       mode='train',
                       print_summary=True)

    def _valid_epoch(self, epoch, mode, loader):
        """

        Args:
            epoch (int): current epoch
            mode (string): 'validation' or 'test'
            loader (dataloader):

        Returns: validation loss

        """
        self.model.eval()
        self.valid_sentences = []
        self.valid_metrics.reset()
        # Zero the shared confusion matrix for this evaluation pass.
        self.confusion_matrix = 0 * self.confusion_matrix
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loader):
                data = data.to(self.device)

                target = target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)
                loss = loss.mean()
                writer_step = (epoch - 1) * len(loader) + batch_idx

                prediction = torch.max(output, 1)
                # Per-batch accuracy (computed but only the raw counts are
                # pushed into the tracker below).
                acc = np.sum(prediction[1].cpu().numpy() == target.squeeze(
                    -1).cpu().numpy()) / target.size(0)

                self.valid_metrics.update(key='loss',
                                          value=loss.item(),
                                          n=1,
                                          writer_step=writer_step)
                self.valid_metrics.update(
                    key='acc',
                    value=np.sum(prediction[1].cpu().numpy() == target.squeeze(
                        -1).cpu().numpy()),
                    n=target.size(0),
                    writer_step=writer_step)
                for t, p in zip(target.cpu().view(-1),
                                prediction[1].cpu().view(-1)):
                    self.confusion_matrix[t.long(), p.long()] += 1
        self._progress(batch_idx,
                       epoch,
                       metrics=self.valid_metrics,
                       mode=mode,
                       print_summary=True)

        # Epoch-level sensitivity / PPV, printed for inspection only.
        s = sensitivity(self.confusion_matrix.numpy())
        ppv = positive_predictive_value(self.confusion_matrix.numpy())
        print(f" s {s} ,ppv {ppv}")
        val_loss = self.valid_metrics.avg('loss')

        return val_loss

    def train(self):
        """
        Train the model
        """
        for epoch in range(self.start_epoch, self.epochs):
            # Re-seed each epoch for reproducible shuffling/augmentation.
            torch.manual_seed(self.config.seed)
            self._train_epoch(epoch)

            self.logger.info(f"{'!' * 10}    VALIDATION   , {'!' * 10}")
            validation_loss = self._valid_epoch(epoch, 'validation',
                                                self.valid_data_loader)
            make_dirs(self.checkpoint_dir)

            # Checkpoint, then let the scheduler react to validation loss.
            self.checkpointer(epoch, validation_loss)
            self.lr_scheduler.step(validation_loss)
            if self.do_test:
                self.logger.info(f"{'!' * 10}    VALIDATION   , {'!' * 10}")
                self.predict(epoch)

    def predict(self, epoch):
        """
        Inference
        Args:
            epoch ():

        Returns:

        """
        self.model.eval()

        predictions = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.test_data_loader):
                data = data.to(self.device)

                # NOTE(review): model is called with a second positional
                # argument here (None), unlike _train_epoch — confirm the
                # model's forward signature.
                logits = self.model(data, None)

                maxes, prediction = torch.max(
                    logits, 1)  # get the index of the max log-probability
                # log.info()
                # Only the first element of each batch is recorded —
                # presumably the test loader uses batch_size == 1; verify.
                predictions.append(
                    f"{target[0]},{prediction.cpu().numpy()[0]}")

        pred_name = os.path.join(
            self.checkpoint_dir,
            f'validation_predictions_epoch_{epoch:d}_.csv')
        write_csv(predictions, pred_name)
        return predictions

    def checkpointer(self, epoch, metric):
        # Save '_model_best' when the monitored metric (validation loss)
        # improves; always save '_model_last'.
        is_best = metric < self.mnt_best
        if (is_best):
            self.mnt_best = metric

            self.logger.info(f"Best val loss {self.mnt_best} so far ")
            # else:
            #     self.gradient_accumulation = self.gradient_accumulation // 2
            #     if self.gradient_accumulation < 4:
            #         self.gradient_accumulation = 4

            save_model(self.checkpoint_dir, self.model, self.optimizer,
                       self.valid_metrics.avg('loss'), epoch, f'_model_best')
        save_model(self.checkpoint_dir, self.model, self.optimizer,
                   self.valid_metrics.avg('loss'), epoch, f'_model_last')

    def _progress(self,
                  batch_idx,
                  epoch,
                  metrics,
                  mode='',
                  print_summary=False):
        # Log a per-sample progress line every `log_step` samples, or a
        # one-line summary when print_summary is set.
        metrics_string = metrics.calc_all_metrics()
        if ((batch_idx * self.config.dataloader.train.batch_size) %
                self.log_step == 0):

            if metrics_string == None:
                self.logger.warning(f" No metrics")
            else:
                self.logger.info(
                    f"{mode} Epoch: [{epoch:2d}/{self.epochs:2d}]\t Sample [{batch_idx * self.config.dataloader.train.batch_size:5d}/{self.len_epoch:5d}]\t {metrics_string}"
                )
        elif print_summary:
            self.logger.info(
                f'{mode} summary  Epoch: [{epoch}/{self.epochs}]\t {metrics_string}'
            )
Exemplo n.º 10
0
class MNISTTrainer(BaseTrainer):
    """
    Trainer class for MNIST: epoch- or iteration-based training with a tqdm
    progress bar and validation delegated to an MNISTTester evaluator.
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_fns,
                 optimizer,
                 config,
                 device,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_fns, optimizer, config,
                         device, data_loader, valid_data_loader, lr_scheduler)

        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training: wrap the loader so it cycles forever
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.do_validation = self.valid_data_loader is not None
        # Log roughly every sqrt(batch_size) batches.
        self.log_step = int(np.sqrt(data_loader.batch_size))

        # NOTE(review): reads self.metric_ftns although the parameter is
        # named metric_fns — presumably set by BaseTrainer; confirm.
        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)
        # Define evaluator
        self.evaluator = MNISTTester(self.model, self.criterion,
                                     self.metric_ftns, self.config,
                                     self.device, self.valid_data_loader, True)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """

        self.model.train()
        self.train_metrics.reset()
        # Progress bar sized in samples, advanced one batch at a time.
        with tqdm(total=self.data_loader.n_samples) as progbar:
            for batch_idx, (data, target) in enumerate(self.data_loader):
                data, target = data.to(self.device), target.to(self.device)

                self.optimizer.zero_grad()
                output = self.model(data)
                loss = self.criterion(output, target)
                loss.backward()
                self.optimizer.step()

                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
                self.train_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.train_metrics.update(met.__name__,
                                              met(output, target))

                if batch_idx % self.log_step == 0:
                    self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                        epoch, self._progress(batch_idx), loss.item()))
                    self.writer.add_image(
                        'input', make_grid(data.cpu(), nrow=8, normalize=True))

                # Stop early for iteration-based training.
                if batch_idx == self.len_epoch:
                    break
                progbar.update(self.data_loader.init_kwargs['batch_size'])
                epoch_part = str(epoch) + '/' + str(self.epochs)
                progbar.set_postfix(epoch=epoch_part, NLL=loss.item())
        log = self.train_metrics.result()

        if self.do_validation:
            # Merge the evaluator's metrics under 'val_' keys, rounded.
            val_log = self.evaluator.test()
            log.update(**{'val_' + k: round(v, 5) for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _progress(self, batch_idx):
        # Format a "[current/total (pct%)]" progress string.
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
Exemplo n.º 11
0
class Seq2SeqSimpleTrainer(BaseTrainer):
    """
    Trainer for a simple seq2seq mode.

    Converts raw question/QDMR/lexicon strings into padded id tensors via
    the model's vocabulary before every forward pass, then runs the
    standard train/validate loop inherited from BaseTrainer.
    """
    def __init__(self,
                 model,
                 criterion,
                 train_metric_ftns,
                 eval_metric_fns,
                 optimizer,
                 config,
                 device,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 validate_only=False):
        """
        :param model: The model to train (must expose a ``vocab`` mapping).
        :param criterion: loss callable invoked as
            ``criterion(output, mask_output, target, mask_target)``.
            NOTE(review): an earlier doc said this value is ignored and
            overwritten, but the code stores and uses it unchanged.
        :param train_metric_ftns: The metric functions to use for training.
        :param eval_metric_fns: The metric functions to use for evaluating.
        :param optimizer: The optimizer to use.
        :param config: The configuration file for the run.
        :param device: The device to train on.
        :param data_loader: The training data loader to use.
        :param valid_data_loader: The validation data loader to use.
        :param lr_scheduler: scheduler for the learning rate.
        :param len_epoch: The amount of examples in an epoch; when given,
            switches to iteration-based training over an infinite loader.
        :param validate_only: use if resumed, only run validation on the last resumed checkpoint.
        """
        self.vocab = model.vocab
        # Token id used for padding, looked up from the model vocabulary.
        self.pad_idx = self.vocab['<pad>']

        self.criterion = criterion
        super().__init__(model, self.criterion, train_metric_ftns,
                         eval_metric_fns, optimizer, config, device,
                         data_loader, valid_data_loader, lr_scheduler)

        # Fixed pad lengths for questions, QDMR targets, and lexicon ids.
        self.question_pad_length = config['data_loader']['question_pad_length']
        self.qdmr_pad_length = config['data_loader']['qdmr_pad_length']
        self.lexicon_pad_length = config['data_loader']['lexicon_pad_length']
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training: wrap the loader in an infinite one.
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch

        self.do_validation = self.valid_data_loader is not None
        # Log roughly once every sqrt(batch_size) batches.
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.train_metric_ftns],
            writer=self.writer)

        # Define evaluator.
        self.evaluator = Seq2SeqSimpleTester(self.model, self.criterion,
                                             self.eval_metric_ftns,
                                             self.config, self.device,
                                             self.valid_data_loader, True)

        # Run validation and exit.
        # NOTE(review): exit() terminates the whole process here, so the
        # caller never receives an instance when validate_only=True.
        if validate_only:
            val_log = self.evaluator.test()
            log = {'val_' + k: round(v, 5) for k, v in val_log.items()}
            print(log)
            exit()

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch,
            with 'val_'-prefixed entries when validation is enabled.
        """
        # Sets the model to training mode.
        self.model.train()
        self.train_metrics.reset()
        # Whether gold targets are QDMR strings to be converted to program
        # form (currently only relevant to the commented-out metrics below).
        convert_to_program = self.data_loader.gold_type_is_qdmr()

        with tqdm(total=len(self.data_loader)) as progbar:
            for batch_idx, (_, data, target,
                            lexicon_str) in enumerate(self.data_loader):
                # Tokenize + pad question, target QDMR, and lexicon strings.
                data, mask_data = batch_to_tensor(self.vocab, data,
                                                  self.question_pad_length,
                                                  self.device)
                target, mask_target = batch_to_tensor(self.vocab, target,
                                                      self.qdmr_pad_length,
                                                      self.device)
                lexicon_ids, mask_lexicon = tokenize_lexicon_str(
                    self.vocab, lexicon_str, self.qdmr_pad_length, self.device)
                # Run the model on the batch
                self.optimizer.zero_grad()
                # out shape is (batch_size, seq_len, output_size)

                output, mask_output = self.model(data, target, lexicon_ids)

                # CEloss expects (minibatch, classes, seq_len)
                # out after transpose is (batch_size, output_size, seq_len)
                # output = torch.transpose(output, 1, 2)

                # Calculate the loss and perform optimization step.
                # TODO test properly use of masks
                # output dims should be (batch_size, num_decoding_steps, num_classes)
                loss = self.criterion(output, mask_output, target, mask_target)
                loss.backward()
                self.optimizer.step()

                # Global TensorBoard step: total batches since epoch 1.
                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)

                with torch.no_grad():
                    # NOTE(review): argmax over dim=1 selects along the
                    # sequence axis if output really is (batch, seq_len,
                    # classes) as the comment above says — confirm the axis.
                    pred = torch.argmax(output, dim=1)
                    # data_str = batch_to_str(self.vocab, data, mask_data, convert_to_program=False)
                    # target_str = batch_to_str(self.vocab, target, mask_target, convert_to_program=convert_to_program)
                    # pred_str = pred_batch_to_str(self.vocab, pred, convert_to_program=convert_to_program)

                # Update metrics
                self.train_metrics.update('loss', loss.item())
                # for met in self.metric_ftns:
                #     self.train_metrics.update(met.__name__, met(pred_str, target_str, data_str))

                # Log progress
                if batch_idx % self.log_step == 0:
                    self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                        epoch, self._progress(batch_idx), loss.item()))
                    # TODO set this to write the text examples or remove
                    # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                # Stop early for iteration-based training.
                if batch_idx == self.len_epoch:
                    break

                # Update the progress bar.
                progbar.update(1)
                epoch_part = str(epoch) + '/' + str(self.epochs)
                progbar.set_postfix(
                    epoch=epoch_part,
                    LOSS=loss.item(),
                    batch_size=self.data_loader.init_kwargs['batch_size'],
                    samples=self.data_loader.n_samples)

        # Save the calculated metrics for that epoch.
        log = self.train_metrics.result()

        # If validation split exists, evaluate on validation set as well.
        if self.do_validation:
            # TODO print epoch stuff and add epoch to writer
            # TODO self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
            val_log = self.evaluator.test()
            log.update(**{'val_' + k: round(v, 5) for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' progress string for logging."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
Exemplo n.º 12
0
class Trainer():
    """Generic training driver with checkpointing, early stopping,
    TensorBoard logging, and optional validation.

    The monitored quantity is given as ``"<min|max> <metric_name>"``
    (e.g. ``"min val_loss"``), or ``"off"`` to disable monitoring.
    Checkpoints are written every ``save_epoch_period`` epochs and only
    the five most recent are kept (plus ``model_best.pth``).
    """
    def __init__(self,
                 model,
                 criterion,
                 metrics_name,
                 optimizer,
                 train_loader,
                 logger,
                 log_dir,
                 nb_epochs,
                 save_dir,
                 device="cuda:0",
                 log_step=10,
                 start_epoch=0,
                 enable_tensorboard=True,
                 valid_loader=None,
                 lr_scheduler=None,
                 monitor="min val_loss",
                 early_stop=10,
                 save_epoch_period=1,
                 resume=""):
        """
        :param model: network to train (moved to ``device`` at the end).
        :param criterion: loss callable ``criterion(output, target)``.
        :param metrics_name: metric-function names resolved on the
            ``metrics`` module via ``getattr``.
        :param optimizer: optimizer updating ``model``'s parameters.
        :param train_loader: loader yielding dicts with 'image' and 'mask'.
        :param logger: logger for progress/info messages.
        :param log_dir: directory for TensorBoard event files.
        :param nb_epochs: final epoch number (inclusive).
        :param save_dir: pathlib.Path-like directory for checkpoints.
        :param device: torch device string.
        :param log_step: log every ``log_step`` batches.
        :param start_epoch: last completed epoch (training resumes at +1).
        :param enable_tensorboard: toggle TensorBoard writing.
        :param valid_loader: optional validation loader.
        :param lr_scheduler: optional callable
            ``lr_scheduler(optimizer, batch_idx, epoch) -> lr``.
        :param monitor: "off" or "<min|max> <metric>" to track for
            early stopping / best-checkpoint selection.
        :param early_stop: epochs without improvement before stopping.
        :param save_epoch_period: checkpoint every N epochs.
        :param resume: checkpoint path to resume from ("" = fresh run).
        """
        self.model = model
        self.criterion = criterion
        self.metrics_name = metrics_name
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        self.len_epoch = len(self.train_loader)
        self.do_validation = (self.valid_loader is not None)
        self.lr_scheduler = lr_scheduler
        self.log_step = log_step
        self.epochs = nb_epochs
        self.start_epoch = start_epoch + 1

        self.logger = logger
        self.device = device
        self.save_period = save_epoch_period

        self.writer = TensorboardWriter(log_dir, self.logger,
                                        enable_tensorboard)
        self.train_metrics = MetricTracker('loss',
                                           *self.metrics_name,
                                           writer=self.writer)
        self.valid_metrics = MetricTracker('loss',
                                           *self.metrics_name,
                                           writer=self.writer)
        self.checkpoint_dir = save_dir
        if monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            self.mnt_mode, self.mnt_metric = monitor.split()
            assert self.mnt_mode in ['min', 'max']
            # Start from the worst possible value for the chosen direction.
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = early_stop
        if resume != "":
            self._resume_checkpoint(resume_path=resume)
        self.model.to(self.device)

    def train(self):
        """Run the full loop over epochs ``[start_epoch, epochs]``.

        Handles metric monitoring, early stopping, and periodic
        checkpoint saving.
        """
        not_improved_count = 0

        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)
            log = {'epoch': epoch}
            log.update(result)
            self.logger.info('    {:15s}: {}'.format(str("mnt best"),
                                                     self.mnt_best))
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] < self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    # The monitored metric never appeared in the log:
                    # disable monitoring rather than crash every epoch.
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False
                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1
                if (not_improved_count > self.early_stop) and (self.early_stop
                                                               > 0):
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, best)

    def _train_epoch(self, epoch):
        """Train for one epoch.

        :param epoch: current 1-based epoch number.
        :return: dict of averaged training (and 'val_'-prefixed
            validation) metrics for the epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        start_time = time.time()

        for batch_idx, sample in enumerate(self.train_loader):
            data = sample['image']
            target = sample['mask']
            data, target = data.to(self.device), target.to(self.device)
            # BUGFIX: the constructor accepts lr_scheduler=None, but the
            # old code called it unconditionally and crashed with a
            # TypeError. Fall back to the optimizer's configured learning
            # rate so LR logging below still works without a scheduler.
            if self.lr_scheduler is not None:
                current_lr = self.lr_scheduler(self.optimizer, batch_idx,
                                               epoch)
            else:
                current_lr = self.optimizer.param_groups[0]['lr']
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            # Global TensorBoard step: total batches since epoch 1.
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met_name in self.metrics_name:
                self.train_metrics.update(
                    met_name,
                    getattr(metrics, met_name)(output, target))
            if batch_idx % self.log_step == 0:
                # Report throughput since the last log line.
                time_to_run = time.time() - start_time
                start_time = time.time()
                speed = self.log_step / time_to_run
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f} LR: {:.6f}  Speed: {:.4f}iters/s' \
                                  .format(epoch, self._progress(batch_idx), loss.item(), current_lr, speed))
                for met_name in self.metrics_name:
                    self.writer.add_scalar(met_name,
                                           self.train_metrics.avg(met_name))
                self.writer.add_scalar('loss', self.train_metrics.avg('loss'))
                self.writer.add_scalar("lr", current_lr)
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
            assert batch_idx <= self.len_epoch
        log = self.train_metrics.result()
        if self.do_validation:
            print("Start validation")
            val_log, iou_classes = self._valid_epoch(epoch)

            log.update(**{'val_' + k: v for k, v in val_log.items()})
            for key, value in iou_classes.items():
                log.update({key: value})
        return log

    def _valid_epoch(self, epoch):
        """Evaluate on the validation set.

        Thresholds the first output channel at 0.5 to get a binary mask
        and accumulates a two-class IoU over all samples.

        :param epoch: current 1-based epoch number.
        :return: (averaged metric dict, per-class IoU dict).
        """
        self.model.eval()
        self.valid_metrics.reset()
        iou_tracker = metrics.IoU(2)
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.valid_loader):
                data = sample['image']
                target = sample['mask']
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                target = target.cpu().numpy()
                # Binarize channel 0 of the prediction at 0.5.
                output = output[:, 0]
                output = output.data.cpu().numpy()
                pred = np.zeros_like(output)
                pred[output > 0.5] = 1
                pred = pred.astype(np.int64)
                for i in range(len(target)):
                    iou_tracker.add_batch(target[i], pred[i])
        iou_classes = iou_tracker.get_iou()
        for key, value in iou_classes.items():
            self.writer.add_scalar(key, value)
        self.writer.add_scalar('val_loss', self.valid_metrics.avg('loss'))

        for met_name in self.metrics_name:
            self.writer.add_scalar(met_name, self.valid_metrics.avg(met_name))

        return self.valid_metrics.result(), iou_classes

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' batch-progress string."""
        base = '[{}/{} ({:.0f}%)]'
        current = batch_idx
        total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param save_best: if True, also save a copy as 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            # 'config': self.config
        }
        filename = str(self.checkpoint_dir /
                       'checkpoint-epoch{:06d}.pth'.format(epoch))
        torch.save(state, filename)
        # Prune older checkpoints, keeping the five most recent.
        self.delete_checkpoint()
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def delete_checkpoint(self):
        """Delete all but the five newest epoch checkpoints.

        Lexicographic sort is correct here because epoch numbers are
        zero-padded to six digits in the filename.
        """
        checkpoints_file = list(
            self.checkpoint_dir.glob("checkpoint-epoch*.pth"))
        checkpoints_file.sort()
        for checkpoint_file in checkpoints_file[:-5]:
            os.remove(str(checkpoint_file.absolute()))

    def _resume_checkpoint(self, resume_path):
        """Restore model/optimizer state and resume at the next epoch.

        :param resume_path: path of the checkpoint file to load.
        """
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
Exemplo n.º 13
0
class Tester(BaseTrainer):
    """
    Tester built on BaseTrainer: runs validation/test epochs and batch
    inference, tracking loss and accuracy.
    """

    def __init__(self, config, model, data_loader, writer, checkpoint_dir, logger,
                 valid_data_loader=None, test_data_loader=None, metric_ftns=None):
        """
        :param config: run configuration; ``config.cuda`` selects the device.
        :param model: model to evaluate.
        :param data_loader: loader forwarded to BaseTrainer.
        :param writer: writer used by the metric tracker.
        :param checkpoint_dir: output directory (e.g. for predictions.csv).
        :param logger: logger for progress messages.
        :param valid_data_loader: optional validation loader.
        :param test_data_loader: optional test loader.
        :param metric_ftns: NOTE(review): accepted and forwarded to the
            base class, but then shadowed by the hard-coded
            ['loss', 'acc'] below — confirm this is intended.
        """
        super(Tester, self).__init__(config, data_loader, writer, checkpoint_dir, logger,
                                     valid_data_loader=valid_data_loader,
                                     test_data_loader=test_data_loader, metric_ftns=metric_ftns)

        # Use CUDA only when requested in the config AND actually available.
        if (self.config.cuda):
            use_cuda = torch.cuda.is_available()
            self.device = torch.device("cuda" if use_cuda else "cpu")
        else:
            self.device = torch.device("cpu")
        self.start_epoch = 1

        self.epochs = self.config.epochs
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None

        self.log_step = self.config.log_interval
        self.model = model

        # Best monitored value so far; lower is better.
        self.mnt_best = np.inf

        self.checkpoint_dir = checkpoint_dir
        self.gradient_accumulation = config.gradient_accumulation
        self.metric_ftns = ['loss', 'acc']
        self.valid_metrics = MetricTracker(*[m for m in self.metric_ftns], writer=self.writer, mode='validation')
        self.logger = logger
    def _valid_epoch(self, epoch, mode, loader):
        """
        Run one evaluation pass over ``loader``.

        Args:
            epoch (int): current epoch
            mode (string): 'validation' or 'test'
            loader (dataloader): data source yielding (data, target)

        Returns: validation loss (average over the tracker)

        """
        self.model.eval()
        self.valid_sentences = []
        self.valid_metrics.reset()

        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loader):
                data = data.to(self.device)

                target = target.long().to(self.device)

                # The model returns both logits and a per-sample loss.
                output, loss = self.model(data, target)
                loss = loss.mean()
                writer_step = (epoch - 1) * len(loader) + batch_idx

                prediction = torch.max(output, 1)
                # NOTE(review): 'acc' below is computed but never used;
                # the tracker update recomputes the same correct-count.
                acc = np.sum(prediction[1].cpu().numpy() == target.cpu().numpy()) / target.size(0)

                self.valid_metrics.update(key='loss',value=loss.item(),n=1,writer_step=writer_step)
                self.valid_metrics.update(key='acc', value=np.sum(prediction[1].cpu().numpy() == target.cpu().numpy()), n=target.size(0), writer_step=writer_step)

        # NOTE(review): batch_idx is read after the loop, which raises
        # NameError if the loader is empty — confirm loaders are non-empty.
        self._progress(batch_idx, epoch, metrics=self.valid_metrics, mode=mode, print_summary=True)


        val_loss = self.valid_metrics.avg('loss')


        return val_loss


    def predict(self):
        """
        Run inference over the test loader and write predictions.csv.

        Returns:
            list[str]: "<target>,<prediction>" rows, one per batch.
            NOTE(review): only element [0] of each batch is recorded —
            presumably the test loader uses batch_size 1; confirm.
        """
        self.model.eval()

        predictions = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.test_data_loader):
                data = data.to(self.device)

                # Target is None: forward pass produces logits only.
                logits = self.model(data, None)

                maxes, prediction = torch.max(logits, 1)  # get the index of the max log-probability

                predictions.append(f"{target[0]},{prediction.cpu().numpy()[0]}")
        self.logger.info('Inference done')
        pred_name = os.path.join(self.checkpoint_dir, f'predictions.csv')
        write_csv(predictions, pred_name)
        return predictions

    def _progress(self, batch_idx, epoch, metrics, mode='', print_summary=False):
        """Log a metrics line every log_step samples, or a summary line
        when ``print_summary`` is set and the step condition misses."""
        metrics_string = metrics.calc_all_metrics()
        if ((batch_idx * self.config.dataloader.train.batch_size) % self.log_step == 0):

            if metrics_string == None:
                self.logger.warning(f" No metrics")
            else:
                self.logger.info(
                    f"{mode} Epoch: [{epoch:2d}/{self.epochs:2d}]\t Video [{batch_idx * self.config.dataloader.train.batch_size:5d}/{self.len_epoch:5d}]\t {metrics_string}")
        elif print_summary:
            self.logger.info(
                f'{mode} summary  Epoch: [{epoch}/{self.epochs}]\t {metrics_string}')
Exemplo n.º 14
0
class MNISTTester(BaseTester):
    """
    Tester that evaluates a classifier against a labelled data loader,
    tracking loss plus every configured metric function.
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_fns,
                 config,
                 device,
                 data_loader,
                 evaluation=True):
        """
        :param model: trained model to evaluate.
        :param criterion: loss callable ``criterion(output, target)``.
        :param metric_fns: metric callables ``m(output, target)``,
            tracked under their ``__name__``.
        :param config: run configuration (consumed by BaseTester).
        :param device: device to run evaluation on.
        :param data_loader: loader yielding (data, target) batches.
        :param evaluation: forwarded to BaseTester.
        """

        super().__init__(model, criterion, metric_fns, config, device,
                         data_loader, evaluation)

        # Track the loss plus each metric function under its name.
        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _evaluate(self):
        """
        Run a full pass over the data loader with gold targets.

        CLEANUP: removed the dead ``total_loss``/``total_metrics``
        accumulators (computed but never read or returned) and the
        shadowed inner loop variable ``i``.

        NOTE(review): model.eval()/torch.no_grad() are not applied here —
        presumably BaseTester handles eval mode; confirm.

        :return: dict of averaged loss/metric values from the tracker.
        """
        self.valid_metrics.reset()

        for data, target in tqdm(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            # Compute loss and metrics on this batch.
            loss = self.criterion(output, target)

            self.valid_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.valid_metrics.update(met.__name__, met(output, target))
            self.writer.add_image(
                'input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _predict_without_target(self):
        """Target-free path is identical here: the loader still yields
        (data, target) pairs, so evaluation is reused as-is."""
        return self._evaluate()
Exemplo n.º 15
0
class Seq2SeqSimpleTester(BaseTester):
    """
    Trainer for a simple seq2seq mode.
    """

    def __init__(self, model, criterion, metric_ftns, config, device,
                 data_loader, evaluation=True):
        """
        :param model: A model to test; must expose a ``vocab`` mapping.
        :param criterion: Loss function used during evaluation (used as-is).
        :param metric_ftns: The metric functions to use.
        :param config: The configuration.
        :param device: The device to use for the testing.
        :param data_loader: The dataloader to use for loading the testing data.
        :param evaluation: Whether the tester runs in evaluation mode.
        """

        self.vocab = model.vocab
        # Padding lengths for questions, QDMR targets and the lexicon come
        # from the data-loader section of the config.
        self.question_pad_length = config['data_loader']['question_pad_length']
        self.qdmr_pad_length = config['data_loader']['qdmr_pad_length']
        self.lexicon_pad_length = config['data_loader']['lexicon_pad_length']
        self.pad_idx = self.vocab['<pad>']

        # NOTE(review): an earlier revision overrode the criterion with
        # CrossEntropyLoss(ignore_index=self.pad_idx); the caller-supplied
        # criterion is now used unchanged.
        self.criterion = criterion
        super().__init__(model, self.criterion, metric_ftns, config, device,
                         data_loader, evaluation)

        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _evaluate(self):
        """
        Evaluate the model on the dev set with gold targets.

        Runs the model over every batch, records the loss and the
        string-based metrics in ``self.valid_metrics``, prints two randomly
        chosen example decompositions, and logs parameter histograms.

        :return: dict of averaged metric values from ``self.valid_metrics``.
        """
        # Choose 2 random examples from the dev set and print their prediction.
        # randint is inclusive on both ends, so every batch/example can be hit.
        batch_index1 = random.randint(0, len(self.data_loader) - 1)
        example_index1 = random.randint(0, self.data_loader.batch_size - 1)
        batch_index2 = random.randint(0, len(self.data_loader) - 1)
        example_index2 = random.randint(0, self.data_loader.batch_size - 1)
        questions = []
        decompositions = []
        targets = []
        convert_to_program = self.data_loader.gold_type_is_qdmr()

        self.valid_metrics.reset()
        with tqdm(total=len(self.data_loader)) as progbar:
            for batch_idx, (_, data, target, lexicon_str) in enumerate(self.data_loader):
                data, mask_data = batch_to_tensor(self.vocab, data, self.question_pad_length, self.device)
                target, mask_target = batch_to_tensor(self.vocab, target, self.qdmr_pad_length, self.device)
                lexicon_ids, _ = tokenize_lexicon_str(self.vocab, lexicon_str, self.qdmr_pad_length, self.device)

                # Run the model on the batch and calculate the loss.
                output, mask_output = self.model(data, target, lexicon_ids, evaluation_mode=True)
                loss = self.criterion(output, mask_output, target, mask_target)
                output = torch.transpose(output, 1, 2)
                pred = torch.argmax(output, dim=1)

                # Convert the predictions/targets/questions from tensors of
                # token ids to lists of strings for the string-based metrics.
                # TODO do we need to convert here or can we use the originals? (for data and target)
                data_str = batch_to_str(self.vocab, data, mask_data, convert_to_program=False)
                target_str = batch_to_str(self.vocab, target, mask_target, convert_to_program=convert_to_program)
                pred_str = pred_batch_to_str(self.vocab, pred, convert_to_program=convert_to_program)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(pred_str, target_str, data_str))

                # Collect the randomly chosen examples for printing below.
                if batch_idx == batch_index1:
                    questions.append(data_str[example_index1])
                    decompositions.append(pred_str[example_index1])
                    targets.append(target_str[example_index1])

                if batch_idx == batch_index2:
                    questions.append(data_str[example_index2])
                    decompositions.append(pred_str[example_index2])
                    targets.append(target_str[example_index2])

                # Update the progress bar.
                progbar.update(1)
                progbar.set_postfix(LOSS=loss.item(),
                                    batch_size=self.data_loader.init_kwargs['batch_size'],
                                    samples=self.data_loader.n_samples)

        # Print example predictions.
        for question, decomposition, target in zip(questions, decompositions, targets):
            print('\ndecomposition example:')
            print('question:\t\t', question)
            print('decomposition:\t', decomposition)
            print('target:\t\t\t', target)
            print()

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()


    def _predict_without_target(self):
        """
        get model predictions for testing.
        Used without targets

        Writes the predicted decompositions to ``self.predictions_file_name``
        as CSV (columns: question_id, question_text, decomposition).

        :return: A log that contains information about predictions
        """
        qid_col = []
        pred_col = []
        question_col = []

        convert_to_program = self.data_loader.gold_type_is_qdmr()

        self.valid_metrics.reset()
        with tqdm(total=len(self.data_loader)) as progbar:
            for batch_idx, (question_ids, data, target, lexicon_str) in enumerate(self.data_loader):
                data, mask_data = batch_to_tensor(self.vocab, data, self.question_pad_length, self.device)
                # NOTE(review): despite the method's name, the loader must
                # still yield a target here — the loss below consumes it.
                # Confirm the test loader supplies placeholder targets.
                target, mask_target = batch_to_tensor(self.vocab, target, self.qdmr_pad_length, self.device)
                lexicon_ids, _ = tokenize_lexicon_str(self.vocab, lexicon_str, self.qdmr_pad_length, self.device)

                # Run the model on the batch and calculate the loss.
                output, mask_output = self.model(data, target, lexicon_ids, evaluation_mode=True)
                loss = self.criterion(output, mask_output, target, mask_target)
                output = torch.transpose(output, 1, 2)
                pred = torch.argmax(output, dim=1)

                # Convert the predictions/targets/questions from tensors of
                # token ids to lists of strings.
                # TODO do we need to convert here or can we use the originals? (for data and target)
                data_str = batch_to_str(self.vocab, data, mask_data, convert_to_program=False)
                target_str = batch_to_str(self.vocab, target, mask_target, convert_to_program=convert_to_program)
                pred_str = pred_batch_to_str(self.vocab, pred, convert_to_program=convert_to_program)

                for i, question_id in enumerate(question_ids):
                    self.logger.info('{}:{}'.format(question_id, data_str[i]))
                qid_col.extend(question_ids)
                pred_col.extend(pred_str)
                question_col.extend(data_str)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(pred_str, target_str, data_str))

                # Update the progress bar.
                progbar.update(1)
                progbar.set_postfix(LOSS=loss.item(),
                                    batch_size=self.data_loader.init_kwargs['batch_size'],
                                    samples=self.data_loader.n_samples)

        # Persist predictions as CSV for downstream evaluation.
        d = {'question_id': qid_col, 'question_text': question_col, 'decomposition': pred_col}
        programs_df = pd.DataFrame(data=d)
        programs_df.to_csv(self.predictions_file_name, index=False, encoding='utf-8')

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()