# Example 1
def validation(args, model, testloader, epoch, writer):
    """Run one validation pass over ``testloader`` and report metrics.

    :param args: namespace providing ``class_dict`` (used here as the number
        of classes despite the name), ``cuda`` and ``batch_size``.
    :param model: network to evaluate (switched to eval mode).
    :param testloader: loader yielding ``(input, target)`` batches.
    :param epoch: current epoch (1-based), used for the tensorboard step.
    :param writer: tensorboard writer handed to the MetricTracker.
    :return: ``(val_metrics, confusion_matrix)`` tuple.
    """
    model.eval()
    criterion = nn.CrossEntropyLoss(reduction='mean')

    metric_ftns = [
        'loss', 'correct', 'total', 'accuracy', 'ppv', 'sensitivity'
    ]
    val_metrics = MetricTracker(*metric_ftns, writer=writer, mode='val')
    val_metrics.reset()
    confusion_matrix = torch.zeros(args.class_dict, args.class_dict)
    # Initialise so the post-loop code cannot raise NameError when the
    # loader is empty (the original left both names unbound in that case).
    batch_idx = 0
    num_samples = 1
    with torch.no_grad():
        for batch_idx, input_tensors in enumerate(testloader):

            input_data, target = input_tensors
            if args.cuda:
                input_data = input_data.cuda()
                target = target.cuda()

            output = model(input_data)

            loss = criterion(output, target)

            correct, total, acc = accuracy(output, target)
            # NOTE(review): "+ 1" counts one sample past the completed
            # batches; kept as-is to preserve the original reporting.
            # (The original computed this twice per iteration.)
            num_samples = batch_idx * args.batch_size + 1
            _, pred = torch.max(output, 1)

            for t, p in zip(target.cpu().view(-1), pred.cpu().view(-1)):
                confusion_matrix[t.long(), p.long()] += 1
            val_metrics.update_all_metrics(
                {
                    'correct': correct,
                    'total': total,
                    'loss': loss.item(),
                    'accuracy': acc
                },
                writer_step=(epoch - 1) * len(testloader) + batch_idx)

    print_summary(args, epoch, num_samples, val_metrics, mode="Validation")
    s = sensitivity(confusion_matrix.numpy())
    ppv = positive_predictive_value(confusion_matrix.numpy())
    print(f" s {s} ,ppv {ppv}")
    # Log the epoch-level metrics at the step of the last processed batch.
    writer_step = (epoch - 1) * len(testloader) + batch_idx
    val_metrics.update('sensitivity', s, writer_step=writer_step)
    val_metrics.update('ppv', ppv, writer_step=writer_step)
    print('Confusion Matrix\n{}'.format(confusion_matrix.cpu().numpy()))
    return val_metrics, confusion_matrix
# Example 2
class Trainer(BaseTrainer):
    """
    Trainer class: epoch- or iteration-based training with optional
    validation, tensorboard logging and CSV metric export.
    """

    def __init__(self, model, criterion, metric_ftns, optimizer, config, device,
                 data_loader, valid_data_loader=None, lr_scheduler=None, len_epoch=None):
        super().__init__(model, criterion, metric_ftns, optimizer, config)
        self.config = config
        self.device = device
        self.data_loader = data_loader
        if len_epoch is None:
            # epoch-based training: one epoch == one full pass over the loader
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training: wrap the loader so it never ends
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        self.valid_data_loader = valid_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.lr_scheduler = lr_scheduler
        # Log roughly every sqrt(batch_size) batches.
        self.log_step = int(np.sqrt(data_loader.batch_size))

        # Best validation accuracy seen so far (drives checkpointing).
        self.best_valid = 0
        self.train_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)
        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns], writer=self.writer)

        # CSV file where per-epoch metrics are appended.
        self.file_metrics = str(self.checkpoint_dir / 'metrics.csv')

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        self.model.train()
        self.train_metrics.reset()
        for batch_idx, (data, target) in enumerate(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)

            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()

            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.train_metrics.update(met.__name__, met(output, target))

            if batch_idx % self.log_step == 0:
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=8, normalize=True))

            if batch_idx == self.len_epoch:
                break
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self._valid_epoch(epoch)
            log.update(**{'val_' + k: v for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        # Extra summary pass: re-evaluate the last batch and log its raw
        # loss. no_grad avoids building an unused autograd graph (the
        # original leaked one here); the logged value is unchanged.
        with torch.no_grad():
            output = self.model(data)
            loss = self.criterion(output, target)

        self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
        self.writer.add_scalar('Loss', loss)

        self._save_csv(epoch, log)

        return log

    def _valid_epoch(self, epoch):
        """
        Validate after training an epoch.

        :param epoch: Integer, current training epoch.
        :return: A log that contains information about validation.
        """
        self.model.eval()
        self.valid_metrics.reset()
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.valid_data_loader):
                data, target = data.to(self.device), target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)

                self.writer.set_step(
                    (epoch - 1) * len(self.valid_data_loader) + batch_idx,
                    'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self.writer.add_image(
                    'input', make_grid(data.cpu(), nrow=8, normalize=True))

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')

        # Custom addition: re-evaluate the last batch and log its raw loss
        # (no_grad avoids building an unused autograd graph).
        with torch.no_grad():
            output = self.model(data)
            loss = self.criterion(output, target)
        self.writer.set_step(
            (epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
        self.writer.add_scalar('Loss', loss)

        val_log = self.valid_metrics.result()
        actual_accu = val_log['accuracy']
        # Checkpoint whenever accuracy improves by more than 0.25 points.
        if actual_accu - self.best_valid > 0.0025 and self.save:
            self.best_valid = actual_accu
            if self.tensorboard:  # is true you can use tensorboard
                self._save_checkpoint(epoch, save_best=True)
            filename = str(self.checkpoint_dir / 'checkpoint-best-epoch.pth')
            torch.save(self.model.state_dict(), filename)
            self.logger.info("Saving checkpoint: {} ...".format(filename))

        return val_log

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' progress string."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _save_csv(self, epoch, log):
        """
        Append this epoch's metrics to <checkpoint_dir>/metrics.csv.

        :param epoch: current epoch number
        :param log: dict of metric name -> value for the epoch
        """
        # Context manager guarantees the handle is closed even if a write
        # raises (the original leaked the file object in that case).
        with open(self.file_metrics, "a") as csv_file:
            if epoch == 1:
                # Write the header row once, on the first epoch.
                csv_file.write("epoch,")
                for key in log:
                    csv_file.write(str(key) + ",")
                csv_file.write("\n")

            csv_file.write(str(epoch) + ",")
            for key in log:
                csv_file.write(str(log[key]) + ",")
            csv_file.write("\n")
# Example 3
class Trainer(BaseTrainer):
    """
    Trainer class: training / validation / test loops with gradient
    accumulation, confusion-matrix tracking and loss-based checkpointing.
    """
    def __init__(self,
                 config,
                 model,
                 optimizer,
                 data_loader,
                 writer,
                 checkpoint_dir,
                 logger,
                 class_dict,
                 valid_data_loader=None,
                 test_data_loader=None,
                 lr_scheduler=None,
                 metric_ftns=None):
        super(Trainer, self).__init__(config,
                                      data_loader,
                                      writer,
                                      checkpoint_dir,
                                      logger,
                                      valid_data_loader=valid_data_loader,
                                      test_data_loader=test_data_loader,
                                      metric_ftns=metric_ftns)

        if self.config.cuda:
            use_cuda = torch.cuda.is_available()
            self.device = torch.device("cuda" if use_cuda else "cpu")
        else:
            self.device = torch.device("cpu")
        self.start_epoch = 1
        self.train_data_loader = data_loader

        # Epoch length in samples (not batches).
        self.len_epoch = self.config.dataloader.train.batch_size * len(
            self.train_data_loader)
        self.epochs = self.config.epochs
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None
        self.lr_scheduler = lr_scheduler
        self.log_step = self.config.log_interval
        self.model = model
        self.num_classes = len(class_dict)
        self.optimizer = optimizer

        # Best (lowest) validation loss seen so far.
        self.mnt_best = np.inf
        if self.config.dataset.type == 'multi_target':
            self.criterion = torch.nn.BCEWithLogitsLoss(reduction='mean')
        else:
            self.criterion = torch.nn.CrossEntropyLoss(reduction='mean')
        self.checkpoint_dir = checkpoint_dir
        self.gradient_accumulation = config.gradient_accumulation
        self.writer = writer
        # Tracked metric names (the original assigned this list twice).
        self.metric_ftns = ['loss', 'acc']
        self.train_metrics = MetricTracker(*self.metric_ftns,
                                           writer=self.writer,
                                           mode='train')
        self.valid_metrics = MetricTracker(*self.metric_ftns,
                                           writer=self.writer,
                                           mode='validation')
        self.logger = logger

        self.confusion_matrix = torch.zeros(self.num_classes, self.num_classes)

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch

        Args:
            epoch (int): current training epoch.
        """

        self.model.train()
        # Start each epoch with a clean confusion matrix.
        self.confusion_matrix = 0 * self.confusion_matrix
        self.train_metrics.reset()
        gradient_accumulation = self.gradient_accumulation
        for batch_idx, (data, target) in enumerate(self.train_data_loader):

            data = data.to(self.device)
            target = target.to(self.device)

            output = self.model(data)
            loss = self.criterion(output, target)
            loss = loss.mean()

            # Scale so gradients accumulated over several batches average out.
            (loss / gradient_accumulation).backward()
            if batch_idx % gradient_accumulation == 0:
                self.optimizer.step()  # Now we can do an optimizer step
                self.optimizer.zero_grad()  # Reset gradients tensors

            _, preds = torch.max(output, 1)

            writer_step = (epoch - 1) * self.len_epoch + batch_idx

            self.train_metrics.update(key='loss',
                                      value=loss.item(),
                                      n=1,
                                      writer_step=writer_step)
            # 'acc' is updated with the raw number of correct predictions;
            # n=batch size lets the tracker compute the running average.
            self.train_metrics.update(
                key='acc',
                value=np.sum(preds.cpu().numpy() == target.squeeze(
                    -1).cpu().numpy()),
                n=target.size(0),
                writer_step=writer_step)
            for t, p in zip(target.cpu().view(-1), preds.cpu().view(-1)):
                self.confusion_matrix[t.long(), p.long()] += 1
            self._progress(batch_idx,
                           epoch,
                           metrics=self.train_metrics,
                           mode='train')

        self._progress(batch_idx,
                       epoch,
                       metrics=self.train_metrics,
                       mode='train',
                       print_summary=True)

    def _valid_epoch(self, epoch, mode, loader):
        """
        Evaluate the model on a held-out loader.

        Args:
            epoch (int): current epoch
            mode (string): 'validation' or 'test'
            loader (dataloader):

        Returns: validation loss

        """
        self.model.eval()
        self.valid_sentences = []
        self.valid_metrics.reset()
        self.confusion_matrix = 0 * self.confusion_matrix
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(loader):
                data = data.to(self.device)
                target = target.to(self.device)

                output = self.model(data)
                loss = self.criterion(output, target)
                loss = loss.mean()
                writer_step = (epoch - 1) * len(loader) + batch_idx

                _, preds = torch.max(output, 1)

                self.valid_metrics.update(key='loss',
                                          value=loss.item(),
                                          n=1,
                                          writer_step=writer_step)
                self.valid_metrics.update(
                    key='acc',
                    value=np.sum(preds.cpu().numpy() == target.squeeze(
                        -1).cpu().numpy()),
                    n=target.size(0),
                    writer_step=writer_step)
                for t, p in zip(target.cpu().view(-1), preds.cpu().view(-1)):
                    self.confusion_matrix[t.long(), p.long()] += 1
        self._progress(batch_idx,
                       epoch,
                       metrics=self.valid_metrics,
                       mode=mode,
                       print_summary=True)

        s = sensitivity(self.confusion_matrix.numpy())
        ppv = positive_predictive_value(self.confusion_matrix.numpy())
        print(f" s {s} ,ppv {ppv}")
        val_loss = self.valid_metrics.avg('loss')

        return val_loss

    def train(self):
        """
        Train the model
        """
        # NOTE(review): range(start_epoch, epochs) stops at epochs-1, so
        # with start_epoch=1 it runs epochs-1 iterations — confirm whether
        # `self.epochs + 1` was intended; kept as-is.
        for epoch in range(self.start_epoch, self.epochs):
            torch.manual_seed(self.config.seed)
            self._train_epoch(epoch)

            self.logger.info(f"{'!' * 10}    VALIDATION   , {'!' * 10}")
            validation_loss = self._valid_epoch(epoch, 'validation',
                                                self.valid_data_loader)
            make_dirs(self.checkpoint_dir)

            self.checkpointer(epoch, validation_loss)
            self.lr_scheduler.step(validation_loss)
            if self.do_test:
                # Fixed banner: this branch runs inference on the TEST set
                # (the original printed "VALIDATION" here as well).
                self.logger.info(f"{'!' * 10}    TEST   , {'!' * 10}")
                self.predict(epoch)

    def predict(self, epoch):
        """
        Run inference on the test loader and dump "target,prediction"
        rows to a CSV file in the checkpoint directory.

        Args:
            epoch (int): epoch number used in the output filename.

        Returns:
            list of "target,prediction" strings, one per batch.
        """
        self.model.eval()

        predictions = []
        with torch.no_grad():
            for batch_idx, (data, target) in enumerate(self.test_data_loader):
                data = data.to(self.device)

                logits = self.model(data, None)

                _, prediction = torch.max(
                    logits, 1)  # get the index of the max log-probability
                predictions.append(
                    f"{target[0]},{prediction.cpu().numpy()[0]}")

        pred_name = os.path.join(
            self.checkpoint_dir,
            f'validation_predictions_epoch_{epoch:d}_.csv')
        write_csv(predictions, pred_name)
        return predictions

    def checkpointer(self, epoch, metric):
        """Save the last model, and additionally the best model whenever
        `metric` (validation loss) improves on the best seen so far."""
        is_best = metric < self.mnt_best
        if is_best:
            self.mnt_best = metric

            self.logger.info(f"Best val loss {self.mnt_best} so far ")

            save_model(self.checkpoint_dir, self.model, self.optimizer,
                       self.valid_metrics.avg('loss'), epoch, '_model_best')
        save_model(self.checkpoint_dir, self.model, self.optimizer,
                   self.valid_metrics.avg('loss'), epoch, '_model_last')

    def _progress(self,
                  batch_idx,
                  epoch,
                  metrics,
                  mode='',
                  print_summary=False):
        """Log per-batch progress every `log_step` samples, or an
        end-of-epoch summary when `print_summary` is set."""
        metrics_string = metrics.calc_all_metrics()
        sample = batch_idx * self.config.dataloader.train.batch_size
        if sample % self.log_step == 0:
            if metrics_string is None:
                self.logger.warning(" No metrics")
            else:
                self.logger.info(
                    f"{mode} Epoch: [{epoch:2d}/{self.epochs:2d}]\t Sample [{sample:5d}/{self.len_epoch:5d}]\t {metrics_string}"
                )
        elif print_summary:
            self.logger.info(
                f'{mode} summary  Epoch: [{epoch}/{self.epochs}]\t {metrics_string}'
            )
# Example 4
class MNISTTrainer(BaseTrainer):
    """Epoch-driven trainer that delegates validation to an MNISTTester."""

    def __init__(self,
                 model,
                 criterion,
                 metric_fns,
                 optimizer,
                 config,
                 device,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None):
        super().__init__(model, criterion, metric_fns, optimizer, config,
                         device, data_loader, valid_data_loader, lr_scheduler)

        if len_epoch is not None:
            # iteration-based training: recycle the loader indefinitely
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch
        else:
            # epoch-based training: one full pass per epoch
            self.len_epoch = len(self.data_loader)
        self.do_validation = self.valid_data_loader is not None
        self.log_step = int(np.sqrt(data_loader.batch_size))

        metric_names = [m.__name__ for m in self.metric_ftns]
        self.train_metrics = MetricTracker('loss', *metric_names,
                                           writer=self.writer)
        # Validation is handled by a dedicated evaluator object.
        self.evaluator = MNISTTester(self.model, self.criterion,
                                     self.metric_ftns, self.config,
                                     self.device, self.valid_data_loader, True)

    def _train_epoch(self, epoch):
        """Run one training epoch.

        :param epoch: Integer, current training epoch.
        :return: dict of averaged training (and validation) metrics.
        """
        self.model.train()
        self.train_metrics.reset()
        with tqdm(total=self.data_loader.n_samples) as progress:
            for step, (inputs, labels) in enumerate(self.data_loader):
                inputs = inputs.to(self.device)
                labels = labels.to(self.device)

                # Standard optimisation step.
                self.optimizer.zero_grad()
                outputs = self.model(inputs)
                loss = self.criterion(outputs, labels)
                loss.backward()
                self.optimizer.step()

                # Record loss and all configured metrics at this global step.
                self.writer.set_step((epoch - 1) * self.len_epoch + step)
                self.train_metrics.update('loss', loss.item())
                for metric_fn in self.metric_ftns:
                    self.train_metrics.update(metric_fn.__name__,
                                              metric_fn(outputs, labels))

                if step % self.log_step == 0:
                    self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                        epoch, self._progress(step), loss.item()))
                    self.writer.add_image(
                        'input',
                        make_grid(inputs.cpu(), nrow=8, normalize=True))

                if step == self.len_epoch:
                    break
                progress.update(self.data_loader.init_kwargs['batch_size'])
                progress.set_postfix(epoch=f'{epoch}/{self.epochs}',
                                     NLL=loss.item())
        log = self.train_metrics.result()

        if self.do_validation:
            val_log = self.evaluator.test()
            log.update(**{'val_' + name: round(value, 5)
                          for name, value in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' progress string."""
        if hasattr(self.data_loader, 'n_samples'):
            done = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            done, total = batch_idx, self.len_epoch
        return '[{}/{} ({:.0f}%)]'.format(done, total, 100.0 * done / total)
class Seq2SeqSimpleTrainer(BaseTrainer):
    """
    Trainer for a simple seq2seq model.
    """
    def __init__(self,
                 model,
                 criterion,
                 train_metric_ftns,
                 eval_metric_fns,
                 optimizer,
                 config,
                 device,
                 data_loader,
                 valid_data_loader=None,
                 lr_scheduler=None,
                 len_epoch=None,
                 validate_only=False):
        """

        :param model: The model to train.
        :param criterion: loss callable stored on the trainer.
            NOTE(review): the original doc claimed this value is ignored and
            overwritten, but the code below stores it unchanged — confirm intent.
        :param train_metric_ftns: The metric function names to use for training.
        :param eval_metric_fns: The metric function names to use for evaluating.
        :param optimizer: The optimizer to use.
        :param config: The configuration file for the run.
        :param device: The device to train on.
        :param data_loader: The training data loader to use.
        :param valid_data_loader: The validation data loader to use.
        :param lr_scheduler: scheduler for the learning rate.
        :param len_epoch: The amount of examples in an epoch.
        :param validate_only: use if resumed, only run validation on the last resumed checkpoint.
        """
        self.vocab = model.vocab
        # Vocabulary id of the padding token, used when building masks.
        self.pad_idx = self.vocab['<pad>']

        self.criterion = criterion
        super().__init__(model, self.criterion, train_metric_ftns,
                         eval_metric_fns, optimizer, config, device,
                         data_loader, valid_data_loader, lr_scheduler)

        # Fixed padding lengths for questions, QDMR programs and lexicon ids.
        self.question_pad_length = config['data_loader']['question_pad_length']
        self.qdmr_pad_length = config['data_loader']['qdmr_pad_length']
        self.lexicon_pad_length = config['data_loader']['lexicon_pad_length']
        if len_epoch is None:
            # epoch-based training
            self.len_epoch = len(self.data_loader)
        else:
            # iteration-based training
            self.data_loader = inf_loop(data_loader)
            self.len_epoch = len_epoch

        self.do_validation = self.valid_data_loader is not None
        # Log roughly every sqrt(batch_size) batches.
        self.log_step = int(np.sqrt(data_loader.batch_size))

        self.train_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.train_metric_ftns],
            writer=self.writer)

        # Define evaluator.
        self.evaluator = Seq2SeqSimpleTester(self.model, self.criterion,
                                             self.eval_metric_ftns,
                                             self.config, self.device,
                                             self.valid_data_loader, True)

        # Run validation and exit.
        if validate_only:
            val_log = self.evaluator.test()
            log = {'val_' + k: round(v, 5) for k, v in val_log.items()}
            print(log)
            exit()

    def _train_epoch(self, epoch):
        """
        Training logic for an epoch: tokenize each batch, run the model,
        back-propagate the masked loss, and track metrics.

        :param epoch: Integer, current training epoch.
        :return: A log that contains average loss and metric in this epoch.
        """
        # Sets the model to training mode.
        self.model.train()
        self.train_metrics.reset()
        convert_to_program = self.data_loader.gold_type_is_qdmr()

        with tqdm(total=len(self.data_loader)) as progbar:
            for batch_idx, (_, data, target,
                            lexicon_str) in enumerate(self.data_loader):
                # Convert raw strings to padded id tensors plus masks.
                data, mask_data = batch_to_tensor(self.vocab, data,
                                                  self.question_pad_length,
                                                  self.device)
                target, mask_target = batch_to_tensor(self.vocab, target,
                                                      self.qdmr_pad_length,
                                                      self.device)
                lexicon_ids, mask_lexicon = tokenize_lexicon_str(
                    self.vocab, lexicon_str, self.qdmr_pad_length, self.device)
                # Run the model on the batch
                self.optimizer.zero_grad()
                # out shape is (batch_size, seq_len, output_size)

                output, mask_output = self.model(data, target, lexicon_ids)

                # CEloss expects (minibatch, classes, seq_len)
                # out after transpose is (batch_size, output_size, seq_len)
                # output = torch.transpose(output, 1, 2)

                # Calculate the loss and perform optimization step.
                # TODO test properly use of masks
                # output dims should be (batch_size, num_decoding_steps, num_classes)
                loss = self.criterion(output, mask_output, target, mask_target)
                loss.backward()
                self.optimizer.step()

                self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)

                # pred is computed (without grad) for the commented-out
                # metric updates below; currently unused otherwise.
                with torch.no_grad():
                    pred = torch.argmax(output, dim=1)
                    # data_str = batch_to_str(self.vocab, data, mask_data, convert_to_program=False)
                    # target_str = batch_to_str(self.vocab, target, mask_target, convert_to_program=convert_to_program)
                    # pred_str = pred_batch_to_str(self.vocab, pred, convert_to_program=convert_to_program)

                # Update metrics
                self.train_metrics.update('loss', loss.item())
                # for met in self.metric_ftns:
                #     self.train_metrics.update(met.__name__, met(pred_str, target_str, data_str))

                # Log progress
                if batch_idx % self.log_step == 0:
                    self.logger.debug('Train Epoch: {} {} Loss: {:.6f}'.format(
                        epoch, self._progress(batch_idx), loss.item()))
                    # TODO set this to write the text examples or remove
                    # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))

                if batch_idx == self.len_epoch:
                    break

                # Update the progress bar.
                progbar.update(1)
                epoch_part = str(epoch) + '/' + str(self.epochs)
                progbar.set_postfix(
                    epoch=epoch_part,
                    LOSS=loss.item(),
                    batch_size=self.data_loader.init_kwargs['batch_size'],
                    samples=self.data_loader.n_samples)

        # Save the calculated metrics for that epoch.
        log = self.train_metrics.result()

        # If validation split exists, evaluate on validation set as well.
        if self.do_validation:
            # TODO print epoch stuff and add epoch to writer
            # TODO self.writer.set_step((epoch - 1) * len(self.valid_data_loader) + batch_idx, 'valid')
            val_log = self.evaluator.test()
            log.update(**{'val_' + k: round(v, 5) for k, v in val_log.items()})

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        return log

    def _progress(self, batch_idx):
        """Format a '[current/total (pct%)]' progress string."""
        base = '[{}/{} ({:.0f}%)]'
        if hasattr(self.data_loader, 'n_samples'):
            current = batch_idx * self.data_loader.batch_size
            total = self.data_loader.n_samples
        else:
            current = batch_idx
            total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)
# Example 6
class Trainer():
    """Segmentation training driver.

    Runs the epoch loop over dict-style batches ({'image', 'mask'}), tracks
    loss/metrics through MetricTracker + TensorboardWriter, validates with a
    binary IoU tracker, checkpoints periodically (keeping the five most recent
    files plus a 'model_best.pth'), and supports metric-monitored early
    stopping.
    """
    def __init__(self,
                 model,
                 criterion,
                 metrics_name,
                 optimizer,
                 train_loader,
                 logger,
                 log_dir,
                 nb_epochs,
                 save_dir,
                 device="cuda:0",
                 log_step=10,
                 start_epoch=0,
                 enable_tensorboard=True,
                 valid_loader=None,
                 lr_scheduler=None,
                 monitor="min val_loss",
                 early_stop=10,
                 save_epoch_period=1,
                 resume=""):
        """
        Args:
            model: network to train; moved to *device* at the end of __init__.
            criterion: loss callable applied as criterion(output, target).
            metrics_name: iterable of metric-function names; each is looked up
                on the module-level `metrics` object during training.
            optimizer: torch optimizer driving `step()` each batch.
            train_loader: dataloader yielding {'image': ..., 'mask': ...}.
            logger: logging.Logger-like object (info/debug/warning used).
            log_dir: directory for tensorboard event files.
            nb_epochs: last epoch number to run (inclusive).
            save_dir: checkpoint directory; must be a pathlib.Path-like object
                (the `/` operator and `.glob` are used on it).
            device: torch device string, e.g. "cuda:0" or "cpu".
            log_step: number of batches between progress log lines.
            start_epoch: last completed epoch; training resumes at +1.
            enable_tensorboard: toggles the TensorboardWriter backend.
            valid_loader: optional validation dataloader; enables validation.
            lr_scheduler: callable invoked per batch as
                lr_scheduler(optimizer, batch_idx, epoch) returning the current
                learning rate. NOTE(review): despite the None default,
                _train_epoch calls it unconditionally — a None scheduler would
                crash; confirm all callers pass one.
            monitor: "min <metric>" / "max <metric>" to track a best value in
                the epoch log, or 'off' to disable monitoring entirely.
            early_stop: epochs without improvement before training stops
                (<= 0 disables early stopping).
            save_epoch_period: save a checkpoint every N epochs.
            resume: path of a checkpoint to resume from ("" = fresh run).
        """
        self.model = model
        self.criterion = criterion
        self.metrics_name = metrics_name
        self.optimizer = optimizer
        self.train_loader = train_loader
        self.valid_loader = valid_loader

        # One "epoch" is one full pass over the training loader.
        self.len_epoch = len(self.train_loader)
        self.do_validation = (self.valid_loader is not None)
        self.lr_scheduler = lr_scheduler
        self.log_step = log_step
        self.epochs = nb_epochs
        # start_epoch is the last finished epoch; resume from the next one.
        self.start_epoch = start_epoch + 1

        self.logger = logger
        self.device = device
        self.save_period = save_epoch_period

        self.writer = TensorboardWriter(log_dir, self.logger,
                                        enable_tensorboard)
        # Separate trackers so train/validation averages never mix.
        self.train_metrics = MetricTracker('loss',
                                           *self.metrics_name,
                                           writer=self.writer)
        self.valid_metrics = MetricTracker('loss',
                                           *self.metrics_name,
                                           writer=self.writer)
        self.checkpoint_dir = save_dir
        if monitor == 'off':
            self.mnt_mode = 'off'
            self.mnt_best = 0
        else:
            # monitor is e.g. "min val_loss": direction + metric key.
            self.mnt_mode, self.mnt_metric = monitor.split()
            assert self.mnt_mode in ['min', 'max']
            # `inf` is presumably imported from math at module level — any
            # first real value beats it. TODO confirm the import.
            self.mnt_best = inf if self.mnt_mode == 'min' else -inf
            self.early_stop = early_stop
        if resume != "":
            self._resume_checkpoint(resume_path=resume)
        self.model.to(self.device)

    def train(self):
        """Run the full training loop with monitoring and early stopping."""
        not_improved_count = 0

        for epoch in range(self.start_epoch, self.epochs + 1):
            result = self._train_epoch(epoch)
            log = {'epoch': epoch}
            log.update(result)
            self.logger.info('    {:15s}: {}'.format(str("mnt best"),
                                                     self.mnt_best))
            for key, value in log.items():
                self.logger.info('    {:15s}: {}'.format(str(key), value))
            best = False
            if self.mnt_mode != 'off':
                try:
                    # check whether model performance improved or not, according to specified metric(mnt_metric)
                    improved = (self.mnt_mode == 'min' and log[self.mnt_metric] < self.mnt_best) or \
                               (self.mnt_mode == 'max' and log[self.mnt_metric] > self.mnt_best)
                except KeyError:
                    # Monitored metric missing from the log: disable
                    # monitoring for the rest of the run rather than crash.
                    self.logger.warning(
                        "Warning: Metric '{}' is not found. "
                        "Model performance monitoring is disabled.".format(
                            self.mnt_metric))
                    self.mnt_mode = 'off'
                    improved = False
                if improved:
                    self.mnt_best = log[self.mnt_metric]
                    not_improved_count = 0
                    best = True
                else:
                    not_improved_count += 1
                # early_stop <= 0 means "never stop early".
                if (not_improved_count > self.early_stop) and (self.early_stop
                                                               > 0):
                    self.logger.info(
                        "Validation performance didn\'t improve for {} epochs. "
                        "Training stops.".format(self.early_stop))
                    break

            if epoch % self.save_period == 0:
                self._save_checkpoint(epoch, best)

    def _train_epoch(self, epoch):
        """Train for one epoch; returns a dict of averaged train (and, when a
        validation loader exists, 'val_'-prefixed validation) metrics."""
        self.model.train()
        self.train_metrics.reset()
        start_time = time.time()

        for batch_idx, sample in enumerate(self.train_loader):
            data = sample['image']
            target = sample['mask']
            data, target = data.to(self.device), target.to(self.device)
            # Per-batch LR schedule; returns the LR used for logging.
            current_lr = self.lr_scheduler(self.optimizer, batch_idx, epoch)
            self.optimizer.zero_grad()
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.optimizer.step()
            # Global step for tensorboard: epochs are 1-based.
            self.writer.set_step((epoch - 1) * self.len_epoch + batch_idx)
            self.train_metrics.update('loss', loss.item())
            for met_name in self.metrics_name:
                self.train_metrics.update(
                    met_name,
                    getattr(metrics, met_name)(output, target))
            if batch_idx % self.log_step == 0:
                # Throughput is measured over the last log_step batches.
                time_to_run = time.time() - start_time
                start_time = time.time()
                speed = self.log_step / time_to_run
                self.logger.debug('Train Epoch: {} {} Loss: {:.6f} LR: {:.6f}  Speed: {:.4f}iters/s' \
                                  .format(epoch, self._progress(batch_idx), loss.item(), current_lr, speed))
                for met_name in self.metrics_name:
                    self.writer.add_scalar(met_name,
                                           self.train_metrics.avg(met_name))
                self.writer.add_scalar('loss', self.train_metrics.avg('loss'))
                self.writer.add_scalar("lr", current_lr)
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
            # Sanity check: the loader must not yield more batches than
            # len(train_loader) reported at construction time.
            assert batch_idx <= self.len_epoch
        log = self.train_metrics.result()
        if self.do_validation:
            print("Start validation")
            val_log, iou_classes = self._valid_epoch(epoch)

            log.update(**{'val_' + k: v for k, v in val_log.items()})
            for key, value in iou_classes.items():
                log.update({key: value})
        return log

    def _valid_epoch(self, epoch):
        """Evaluate on the validation loader.

        Returns a tuple of (averaged validation metrics dict, per-class IoU
        dict from a 2-class IoU tracker).
        """
        self.model.eval()
        self.valid_metrics.reset()
        # Two classes: presumably background/foreground binary segmentation.
        iou_tracker = metrics.IoU(2)
        with torch.no_grad():
            for batch_idx, sample in enumerate(self.valid_loader):
                data = sample['image']
                target = sample['mask']
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step(
                    (epoch - 1) * len(self.valid_loader) + batch_idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                # self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
                target = target.cpu().numpy()
                # Channel 0 of the output is thresholded at 0.5 to produce the
                # binary mask. NOTE(review): this assumes the model output is
                # already a probability (post-sigmoid); thresholding raw
                # logits at 0.5 would behave differently — confirm the model.
                output = output[:, 0]
                output = output.data.cpu().numpy()
                pred = np.zeros_like(output)
                pred[output > 0.5] = 1
                pred = pred.astype(np.int64)
                for i in range(len(target)):
                    iou_tracker.add_batch(target[i], pred[i])
        iou_classes = iou_tracker.get_iou()
        for key, value in iou_classes.items():
            self.writer.add_scalar(key, value)
        self.writer.add_scalar('val_loss', self.valid_metrics.avg('loss'))

        for met_name in self.metrics_name:
            self.writer.add_scalar(met_name, self.valid_metrics.avg(met_name))

        # for name, p in self.model.named_parameters():
        #     print(name, p)
        #     self.writer.add_histogram(name, p.cpu().data.numpy(), bins='auto')
        #
        return self.valid_metrics.result(), iou_classes

    def _progress(self, batch_idx):
        """Return a '[current/total (pct%)]' progress string in batch units."""
        base = '[{}/{} ({:.0f}%)]'
        current = batch_idx
        total = self.len_epoch
        return base.format(current, total, 100.0 * current / total)

    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints

        :param epoch: current epoch number
        :param log: logging information of the epoch
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.model.state_dict(),
            'optimizer': self.optimizer.state_dict(),
            'monitor_best': self.mnt_best,
            # 'config': self.config
        }
        # Zero-padded epoch keeps filenames lexicographically sortable,
        # which delete_checkpoint relies on.
        filename = str(self.checkpoint_dir /
                       'checkpoint-epoch{:06d}.pth'.format(epoch))
        torch.save(state, filename)
        self.delete_checkpoint()
        self.logger.info("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = str(self.checkpoint_dir / 'model_best.pth')
            torch.save(state, best_path)
            self.logger.info("Saving current best: model_best.pth ...")

    def delete_checkpoint(self):
        """Delete all but the five most recent epoch checkpoints.

        Relies on lexicographic order of the zero-padded filenames matching
        epoch order; 'model_best.pth' is never touched.
        """
        checkpoints_file = list(
            self.checkpoint_dir.glob("checkpoint-epoch*.pth"))
        checkpoints_file.sort()
        for checkpoint_file in checkpoints_file[:-5]:
            os.remove(str(checkpoint_file.absolute()))

    def _resume_checkpoint(self, resume_path):
        """Restore model/optimizer state and epoch counters from a checkpoint
        produced by _save_checkpoint."""
        self.logger.info("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path)
        self.start_epoch = checkpoint['epoch'] + 1
        self.mnt_best = checkpoint['monitor_best']

        self.model.load_state_dict(checkpoint['state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer'])

        self.logger.info(
            "Checkpoint loaded. Resume training from epoch {}".format(
                self.start_epoch))
Esempio n. 7
0
def main(config, use_transformers):
    """Evaluate a trained text classifier on the test corpus.

    Args:
        config: project configuration object exposing get_logger / init_obj /
            resume and dict-style access.
        use_transformers: when True, use the BERT-style collate function and a
            model without an external word-embedding matrix; otherwise build a
            plain network that needs dataset.word_embedding.

    Returns:
        The binary AUC computed over the test set.
    """
    logger = config.get_logger('test')

    device = torch.device('cuda:{}'.format(config.config['device_id'])
                          if config.config['n_gpu'] > 0 else 'cpu')

    # Test-set corpus.
    test_dataset = config.init_obj('test_dataset',
                                   module_data_process,
                                   device=device)

    if use_transformers:
        test_dataloader = config.init_obj(
            'test_data_loader',
            module_dataloader,
            dataset=test_dataset.data_set,
            collate_fn=test_dataset.bert_collate_fn_4_inference)

        model = config.init_obj('model_arch', module_arch, word_embedding=None)

    else:
        # Non-transformer path: this dataset is only needed for its
        # word_embedding matrix; the dataloader still feeds from test_dataset.
        # NOTE(review): confirm `dataset=test_dataset.data_set` (not
        # `dataset.data_set`) is intended here.
        dataset = config.init_obj('dataset',
                                  module_data_process,
                                  device=device)

        test_dataloader = config.init_obj(
            'test_data_loader',
            module_dataloader,
            dataset=test_dataset.data_set,
            collate_fn=test_dataset.collate_fn_4_inference)

        model = config.init_obj('model_arch',
                                module_arch,
                                word_embedding=dataset.word_embedding)

    if config['n_gpu'] > 1:
        device_ids = list(
            map(lambda x: int(x), config.config['device_id'].split(',')))
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    logger.info('Loading checkpoint: {} ...'.format(config.resume))
    checkpoint = torch.load(config.resume)

    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)

    # Bug fix: the original called model.cuda() unconditionally, which crashes
    # on CPU-only hosts; use the device resolved from the config instead.
    model = model.to(device)
    model.eval()

    metric_ftns = [getattr(module_metric, met) for met in config['metrics']]
    test_metrics = MetricTracker(*[m.__name__ for m in metric_ftns])

    # Metrics are computed over windows of 8 consecutive batches; any trailing
    # partial window (< 8 batches) is deliberately dropped, matching the
    # original behaviour (e.g. 6222 % 128 = 78 leftover samples ignored).
    buffered_outputs = []
    buffered_labels = []
    with torch.no_grad():
        for batch_data in test_dataloader:
            # One batch of reviews.
            input_token_ids, _, seq_lens, class_labels, texts = batch_data

            output = model(input_token_ids, _, seq_lens).squeeze(1)

            buffered_outputs.append(output.clone())
            buffered_labels.append(class_labels.clone())
            if len(buffered_outputs) == 8:
                pred_tensor = torch.cat(buffered_outputs, 0)
                label_tensor = torch.cat(buffered_labels, 0)
                for met in metric_ftns:
                    test_metrics.update(met.__name__,
                                        met(pred_tensor, label_tensor))
                buffered_outputs = []
                buffered_labels = []

    test_log = test_metrics.result()

    for k, v in test_log.items():
        logger.info('    {:25s}: {}'.format(str(k), v))

    print(test_log['binary_auc'])

    return test_log['binary_auc']
Esempio n. 8
0
class Tester(BaseTrainer):
    """
    Evaluation driver: runs validation/test passes over a loader, tracks
    loss/accuracy, and writes test-set predictions to a CSV file.
    """

    def __init__(self, config, model, data_loader, writer, checkpoint_dir, logger,
                 valid_data_loader=None, test_data_loader=None, metric_ftns=None):
        """
        Args:
            config: run configuration (cuda flag, epochs, log_interval, ...).
            model: network to evaluate; expected to return (output, loss)
                when called with (data, target) and logits with (data, None).
            data_loader: training loader (forwarded to BaseTrainer).
            writer: tensorboard-style writer.
            checkpoint_dir: directory where predictions.csv is written.
            logger: logging.Logger-like object.
            valid_data_loader / test_data_loader: optional eval loaders.
            metric_ftns: forwarded to BaseTrainer (local tracking uses
                'loss' and 'acc' only).
        """
        super(Tester, self).__init__(config, data_loader, writer, checkpoint_dir, logger,
                                     valid_data_loader=valid_data_loader,
                                     test_data_loader=test_data_loader, metric_ftns=metric_ftns)

        # Resolve the device; fall back to CPU when CUDA is unavailable.
        if (self.config.cuda):
            use_cuda = torch.cuda.is_available()
            self.device = torch.device("cuda" if use_cuda else "cpu")
        else:
            self.device = torch.device("cpu")
        self.start_epoch = 1

        self.epochs = self.config.epochs
        self.valid_data_loader = valid_data_loader
        self.test_data_loader = test_data_loader
        self.do_validation = self.valid_data_loader is not None
        self.do_test = self.test_data_loader is not None

        self.log_step = self.config.log_interval
        self.model = model

        self.mnt_best = np.inf

        self.checkpoint_dir = checkpoint_dir
        self.gradient_accumulation = config.gradient_accumulation
        self.metric_ftns = ['loss', 'acc']
        self.valid_metrics = MetricTracker(*self.metric_ftns, writer=self.writer, mode='validation')
        self.logger = logger

    def _valid_epoch(self, epoch, mode, loader):
        """
        Run one evaluation pass over *loader* without gradient tracking.

        Args:
            epoch (int): current epoch
            mode (string): 'validation' or 'test'
            loader (dataloader): data source to evaluate

        Returns: average validation loss over the pass
        """
        self.model.eval()
        self.valid_sentences = []
        self.valid_metrics.reset()

        with torch.no_grad():
            # NOTE(review): batch_idx is referenced after the loop, so an
            # empty loader raises NameError — confirm loaders are non-empty.
            for batch_idx, (data, target) in enumerate(loader):
                data = data.to(self.device)

                target = target.long().to(self.device)

                # Model returns (output, loss); mean() collapses per-device
                # losses under DataParallel.
                output, loss = self.model(data, target)
                loss = loss.mean()
                writer_step = (epoch - 1) * len(loader) + batch_idx

                prediction = torch.max(output, 1)
                # Count of correct predictions in this batch; tracked with
                # n=batch-size so the tracker averages across uneven batches.
                correct = np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())

                self.valid_metrics.update(key='loss', value=loss.item(), n=1, writer_step=writer_step)
                self.valid_metrics.update(key='acc', value=correct, n=target.size(0), writer_step=writer_step)

        self._progress(batch_idx, epoch, metrics=self.valid_metrics, mode=mode, print_summary=True)

        val_loss = self.valid_metrics.avg('loss')

        return val_loss

    def predict(self):
        """
        Run inference over the test loader and write "<target>,<prediction>"
        rows to predictions.csv in the checkpoint directory.

        Returns:
            list[str]: one "target,prediction" line per batch.
        """
        self.model.eval()

        predictions = []
        with torch.no_grad():
            for data, target in self.test_data_loader:
                data = data.to(self.device)

                logits = self.model(data, None)

                maxes, prediction = torch.max(logits, 1)  # get the index of the max log-probability

                # Only element 0 of each batch is recorded — presumably
                # batch_size == 1 at inference time; verify against the loader.
                predictions.append(f"{target[0]},{prediction.cpu().numpy()[0]}")
        self.logger.info('Inference done')
        pred_name = os.path.join(self.checkpoint_dir, 'predictions.csv')
        write_csv(predictions, pred_name)
        return predictions

    def _progress(self, batch_idx, epoch, metrics, mode='', print_summary=False):
        """Log a per-step progress line, or an epoch summary when requested."""
        metrics_string = metrics.calc_all_metrics()
        if ((batch_idx * self.config.dataloader.train.batch_size) % self.log_step == 0):
            if metrics_string is None:
                self.logger.warning(" No metrics")
            else:
                self.logger.info(
                    f"{mode} Epoch: [{epoch:2d}/{self.epochs:2d}]\t Video [{batch_idx * self.config.dataloader.train.batch_size:5d}/{self.len_epoch:5d}]\t {metrics_string}")
        elif print_summary:
            self.logger.info(
                f'{mode} summary  Epoch: [{epoch}/{self.epochs}]\t {metrics_string}')
class MNISTTester(BaseTester):
    """
    Tester for a simple classifier: evaluates the model over a gold-labelled
    data loader and aggregates loss and metrics in a MetricTracker.
    """
    def __init__(self,
                 model,
                 criterion,
                 metric_fns,
                 config,
                 device,
                 data_loader,
                 evaluation=True):
        """
        :param model: the model to evaluate
        :param criterion: loss callable applied as criterion(output, target)
        :param metric_fns: metric functions tracked during evaluation
        :param config: experiment configuration (forwarded to BaseTester)
        :param device: torch device for model input
        :param data_loader: loader yielding (data, target) batches
        :param evaluation: forwarded to BaseTester
        """

        super().__init__(model, criterion, metric_fns, config, device,
                         data_loader, evaluation)

        self.valid_metrics = MetricTracker(
            'loss',
            *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _evaluate(self):
        """
        Evaluate over the full data loader with gold targets.

        Fixes vs. the original: the inner metric loop reused the outer loop
        index `i` (variable shadowing), and `total_loss`/`total_metrics` were
        accumulated but never used — the dead accumulators (and the duplicate
        per-batch metric computation they caused) are removed.

        NOTE(review): the model is never put in eval mode here and no
        torch.no_grad() guard is used — confirm BaseTester handles both.

        :return: dict of averaged metrics from the tracker
        """
        self.valid_metrics.reset()

        for data, target in tqdm(self.data_loader):
            data, target = data.to(self.device), target.to(self.device)
            output = self.model(data)
            # Compute loss and metrics on the test set.
            loss = self.criterion(output, target)

            self.valid_metrics.update('loss', loss.item())
            for met in self.metric_ftns:
                self.valid_metrics.update(met.__name__, met(output, target))
            self.writer.add_image(
                'input', make_grid(data.cpu(), nrow=8, normalize=True))

        # Add histogram of model parameters to the tensorboard.
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _predict_without_target(self):
        """Without gold targets there is nothing extra to do; reuse _evaluate."""
        return self._evaluate()
Esempio n. 10
0
def main(config):
    """Evaluate a trained binary classifier checkpoint on the test fold.

    Loads the checkpoint named in config['trainer']['resume_path'], runs the
    test loader, writes per-sample (label, prediction) pairs to prediction.csv
    next to the checkpoint, prints aggregated metrics, and runs the timing /
    BN-reset utilities.

    Fixes vs. the original: the CSV file is now managed by a `with` block (it
    previously leaked on any exception in the loop), the side-effect list
    comprehension `_ = [values.append(...)]` is gone, and the unused locals
    (`total_loss`, `total_metrics`, `image_shape`, `f_name`, dead per-batch
    loss computation) are removed.
    """
    # Set up the test-split data loader.
    data_loader = getattr(module_data, config['data_loader']['type'])(
        mode="test",
        data_root="/root/userfolder/Dataset/ImagesAnnotations_aug/",
        fold=0,
        num_workers=4,
        batch_size=96)

    # Build the model architecture.
    model = config.init_obj('arch', module_arch)
    params = compute_params(model)
    print(model)
    print('the params of model is: ', params)

    # Loss and metric function handles.
    loss_fn = nn.BCEWithLogitsLoss()
    metric_fns = [getattr(module_metric, met) for met in config['metrics']]

    resume_path = os.path.join(config['project_root'],
                               config['trainer']['resume_path'])
    checkpoint = torch.load(resume_path, map_location=torch.device('cpu'))
    print('Loading checkpoint: {} ...'.format(resume_path))
    state_dict = checkpoint['state_dict']
    gpus = config['gpu_device']
    # DataParallel must wrap the model before load_state_dict so the
    # 'module.'-prefixed keys in a multi-GPU checkpoint line up.
    if len(gpus) > 1:
        model = torch.nn.DataParallel(model)
    model.load_state_dict(state_dict)

    # Prepare model for testing on the first configured GPU (or CPU).
    device = torch.device(
        'cuda:{}'.format(gpus[0]) if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()

    outputs = []
    targets = []
    test_metrics = MetricTracker('loss',
                                 'time',
                                 *[m.__name__ for m in metric_fns],
                                 writer=None)
    f_dir, _ = os.path.split(resume_path)
    csv_path = os.path.join(f_dir, 'prediction.csv')
    values = []
    with open(csv_path, 'w') as f:
        csv_writer = csv.writer(f)
        csv_writer.writerow(['label', 'pred'])
        with torch.no_grad():
            for data, target in tqdm(data_loader.test_loader):
                data, target = data.to(device), target.to(device).float()
                # Synchronize around the forward pass so timing is accurate.
                torch.cuda.synchronize(device)
                start = time.time()
                output = model(data)
                torch.cuda.synchronize(device)
                end = time.time()

                # NOTE(review): .squeeze().tolist() yields a scalar (not a
                # list) when a batch has a single sample, which would break
                # len()/indexing below — confirm the final batch size.
                pred_list = torch.sigmoid(output.clone()).squeeze().tolist()
                label_list = target.clone().squeeze().tolist()
                values.extend(
                    [label_list[index], pred_list[index]]
                    for index in range(len(pred_list)))

                # Accumulate [batch, 1, 1, 1] outputs / [batch, 1] targets for
                # the metric computation after the loop.
                outputs.append(output.unsqueeze(dim=2).unsqueeze(dim=3).clone())
                targets.append(target.unsqueeze(dim=2).clone())
                test_metrics.update('time', end - start)
        csv_writer.writerows(values)

    outputs = torch.cat(outputs, dim=0)  # [steps*batch, 1, 1, 1]
    targets = torch.cat(targets, dim=0)

    for met in metric_fns:
        test_metrics.update(met.__name__, met(outputs, targets))
    log = test_metrics.result()

    print(log)
    time_results = compute_precise_time(model, [496, 384], 96, loss_fn, device)
    print(time_results)
    reset_bn_stats(model)
    return
Esempio n. 11
0
class Seq2SeqSimpleTester(BaseTester):
    """
    Tester for a simple seq2seq model.

    Runs the model over a dataloader either against gold targets
    (``_evaluate``) or without them, dumping predictions to CSV
    (``_predict_without_target``).
    """

    def __init__(self, model, criterion, metric_ftns, config, device,
                 data_loader, evaluation=True):
        """
        :param model: A model to test.
        :param criterion: Loss callable invoked as
            ``criterion(output, mask_output, target, mask_target)``.
        :param metric_ftns: The metric functions to use; each is called as
            ``met(pred_str, target_str, data_str)``.
        :param config: The configuration.
        :param device: The device to use for the testing.
        :param data_loader: The dataloader to use for loading the testing data.
        :param evaluation: Forwarded to :class:`BaseTester`.
        """
        self.vocab = model.vocab
        self.question_pad_length = config['data_loader']['question_pad_length']
        self.qdmr_pad_length = config['data_loader']['qdmr_pad_length']
        self.lexicon_pad_length = config['data_loader']['lexicon_pad_length']
        self.pad_idx = self.vocab['<pad>']

        # The criterion is used as supplied; an earlier CrossEntropyLoss
        # override was deliberately removed.
        self.criterion = criterion
        super().__init__(model, self.criterion, metric_ftns, config, device,
                         data_loader, evaluation)

        self.valid_metrics = MetricTracker(
            'loss', *[m.__name__ for m in self.metric_ftns],
            writer=self.writer)

    def _run_batch(self, data, target, lexicon_str, convert_to_program):
        """Tokenize one raw batch, run the model, and decode predictions.

        Shared by :meth:`_evaluate` and :meth:`_predict_without_target`
        (previously duplicated in both).

        :return: ``(loss, data_str, target_str, pred_str)`` where the
            ``*_str`` values are lists of decoded strings, one per example.
        """
        data, mask_data = batch_to_tensor(self.vocab, data,
                                          self.question_pad_length,
                                          self.device)
        target, mask_target = batch_to_tensor(self.vocab, target,
                                              self.qdmr_pad_length,
                                              self.device)
        # Lexicon mask is not needed downstream, only the token ids.
        lexicon_ids, _ = tokenize_lexicon_str(self.vocab, lexicon_str,
                                              self.qdmr_pad_length,
                                              self.device)

        output, mask_output = self.model(data, target, lexicon_ids,
                                         evaluation_mode=True)
        loss = self.criterion(output, mask_output, target, mask_target)

        # Transpose so argmax over dim=1 picks the most likely vocabulary
        # token at every sequence position.
        output = torch.transpose(output, 1, 2)
        pred = torch.argmax(output, dim=1)

        # Convert the predictions/targets/questions from tensors of
        # token ids to lists of strings.
        # TODO do we need to convert here or can we use the originals?
        data_str = batch_to_str(self.vocab, data, mask_data,
                                convert_to_program=False)
        target_str = batch_to_str(self.vocab, target, mask_target,
                                  convert_to_program=convert_to_program)
        pred_str = pred_batch_to_str(self.vocab, pred,
                                     convert_to_program=convert_to_program)
        return loss, data_str, target_str, pred_str

    def _evaluate(self):
        """
        Validate after training an epoch, against the gold target.

        :return: A log that contains information about validation.
        """
        # Choose 2 random examples from the dev set to print afterwards.
        # random.randint is inclusive on both ends, so this covers every
        # batch; the previous extra "- 1" made the range [-1, len-2],
        # which could never match the last batch and silently skipped
        # printing whenever -1 was drawn.
        batch_index1 = random.randint(0, len(self.data_loader) - 1)
        example_index1 = random.randint(0, self.data_loader.batch_size - 1)
        batch_index2 = random.randint(0, len(self.data_loader) - 1)
        example_index2 = random.randint(0, self.data_loader.batch_size - 1)
        questions = []
        decompositions = []
        targets = []
        convert_to_program = self.data_loader.gold_type_is_qdmr()

        self.valid_metrics.reset()
        with tqdm(total=len(self.data_loader)) as progbar:
            for batch_idx, (_, data, target, lexicon_str) in enumerate(self.data_loader):
                loss, data_str, target_str, pred_str = self._run_batch(
                    data, target, lexicon_str, convert_to_program)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(
                        met.__name__, met(pred_str, target_str, data_str))

                # Record the sampled examples; guard the index because a
                # final partial batch may be smaller than batch_size.
                for b_idx, e_idx in ((batch_index1, example_index1),
                                     (batch_index2, example_index2)):
                    if batch_idx == b_idx and e_idx < len(pred_str):
                        questions.append(data_str[e_idx])
                        decompositions.append(pred_str[e_idx])
                        targets.append(target_str[e_idx])

                # Update the progress bar.
                progbar.update(1)
                progbar.set_postfix(LOSS=loss.item(),
                                    batch_size=self.data_loader.init_kwargs['batch_size'],
                                    samples=self.data_loader.n_samples)

        # Print example predictions.
        for question, decomposition, target in zip(questions, decompositions, targets):
            print('\ndecomposition example:')
            print('question:\t\t', question)
            print('decomposition:\t', decomposition)
            print('target:\t\t\t', target)
            print()

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()

    def _predict_without_target(self):
        """
        Get model predictions for testing, without gold targets, and write
        them to ``self.predictions_file_name`` as CSV.

        :return: A log that contains information about predictions.
        """
        qid_col = []
        pred_col = []
        question_col = []

        convert_to_program = self.data_loader.gold_type_is_qdmr()

        self.valid_metrics.reset()
        with tqdm(total=len(self.data_loader)) as progbar:
            for batch_idx, (question_ids, data, target, lexicon_str) in enumerate(self.data_loader):
                loss, data_str, target_str, pred_str = self._run_batch(
                    data, target, lexicon_str, convert_to_program)

                for i, question_id in enumerate(question_ids):
                    self.logger.info('{}:{}'.format(question_id, data_str[i]))
                qid_col.extend(question_ids)
                pred_col.extend(pred_str)
                question_col.extend(data_str)

                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(
                        met.__name__, met(pred_str, target_str, data_str))

                # Update the progress bar.
                progbar.update(1)
                progbar.set_postfix(LOSS=loss.item(),
                                    batch_size=self.data_loader.init_kwargs['batch_size'],
                                    samples=self.data_loader.n_samples)

        d = {'question_id': qid_col,
             'question_text': question_col,
             'decomposition': pred_col}
        programs_df = pd.DataFrame(data=d)
        programs_df.to_csv(self.predictions_file_name, index=False,
                           encoding='utf-8')

        # add histogram of model parameters to the tensorboard
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        return self.valid_metrics.result()