Example #1
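# Imports this snippet appears to rely on. TensorboardWriter and MetricTracker are
# helpers from the surrounding project, so the commented import path is an assumption.
import os
import time

import numpy as np
import torch
from torchvision.utils import make_grid

# from utils import TensorboardWriter, MetricTracker  # project-specific (assumed path)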
class Learning(object):
    def __init__(self,
            model,
            criterion,
            optimizer,
            scheduler,
            metric_ftns,
            device,
            num_epoch,
            grad_clipping,
            grad_accumulation_steps,
            early_stopping,
            validation_frequency,
            tensorboard,
            checkpoint_dir,
            resume_path):
        self.device, device_ids = self._prepare_device(device)
        self.model = model.to(self.device)

        self.start_epoch = 1
        if resume_path is not None:
            self._resume_checkpoint(resume_path)
        if len(device_ids) > 1:
            self.model = torch.nn.DataParallel(self.model, device_ids=device_ids)
        self.criterion = criterion
        self.metric_ftns = metric_ftns
        self.optimizer = optimizer
        self.num_epoch = num_epoch 
        self.scheduler = scheduler
        self.grad_clipping = grad_clipping
        self.grad_accumulation_steps = grad_accumulation_steps
        self.early_stopping = early_stopping
        self.validation_frequency = validation_frequency
        self.checkpoint_dir = checkpoint_dir
        self.best_epoch = 1
        self.best_score = 0
        self.writer = TensorboardWriter(os.path.join(checkpoint_dir, 'tensorboard'), tensorboard)
        self.train_metrics = MetricTracker('loss', writer = self.writer)
        self.valid_metrics = MetricTracker('loss', *[m.__name__ for m in self.metric_ftns], writer = self.writer)
        
    def train(self, train_dataloader):
        score = 0
        for epoch in range(self.start_epoch, self.num_epoch+1):
            print("{} epoch: \t start training....".format(epoch))
            start = time.time()
            train_result  = self._train_epoch(epoch, train_dataloader)
            train_result.update({'time': time.time()-start})
            
            for key, value in train_result.items():
                print('    {:15s}: {}'.format(str(key), value))

            # if (epoch+1) % self.validation_frequency!=0:
            #     print("skip validation....")
            #     continue
            # print('{} epoch: \t start validation....'.format(epoch))
            # start = time.time()
            # valid_result = self._valid_epoch(epoch, valid_dataloader)
            # valid_result.update({'time': time.time() - start})
            
            # for key, value in valid_result.items():
            #     if 'score' in key:
            #         score = value 
            #     print('   {:15s}: {}'.format(str(key), value))
            # validation above is commented out, so bump a placeholder score each epoch
            # to keep the checkpointing / early-stopping path exercised
            score += 0.001
            self.post_processing(score, epoch)
            if epoch - self.best_epoch > self.early_stopping:
                print('WARNING: EARLY STOPPING')
                break
    def _train_epoch(self, epoch, data_loader):
        self.model.train()
        self.optimizer.zero_grad()
        self.train_metrics.reset()
        for idx, (data, target) in enumerate(data_loader):
            data = data.to(self.device)
            target = [ann.to(self.device) for ann in target]
            output = self.model(data)
            loss = self.criterion(output, target)
            loss.backward()
            self.writer.set_step((epoch - 1) * len(data_loader) + idx)
            self.train_metrics.update('loss', loss.item())
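            # gradient accumulation: gradients from grad_accumulation_steps batches are
            # summed before an optimizer step, so the effective batch size is the loader
            # batch size times grad_accumulation_steps (some setups also divide the loss
            # by grad_accumulation_steps to keep its scale comparable)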
            if (idx+1) % self.grad_accumulation_steps == 0:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_clipping)
                self.optimizer.step()
                self.optimizer.zero_grad()
            if (idx+1) % int(np.sqrt(len(data_loader))) == 0:
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        return self.train_metrics.result()
    def _valid_epoch(self, epoch, data_loader):
        self.valid_metrics.reset()
        self.model.eval()
        with torch.no_grad():
            for idx, (data, target) in enumerate(data_loader):
                data, target = data.to(self.device), target.to(self.device)
                output = self.model(data)
                loss = self.criterion(output, target)
                self.writer.set_step((epoch - 1) * len(data_loader) + idx, 'valid')
                self.valid_metrics.update('loss', loss.item())
                for met in self.metric_ftns:
                    self.valid_metrics.update(met.__name__, met(output, target))
                self.writer.add_image('input', make_grid(data.cpu(), nrow=8, normalize=True))
        
        for name, p in self.model.named_parameters():
            self.writer.add_histogram(name, p, bins='auto')
        
        return self.valid_metrics.result()
    def post_processing(self, score, epoch):
        best = False
        if score > self.best_score:
            self.best_score = score 
            self.best_epoch = epoch 
            best = True
            print("best model: {} epoch - {:.5}".format(epoch, score))
        self._save_checkpoint(epoch = epoch, save_best = best)
        
        if self.scheduler.__class__.__name__ == 'ReduceLROnPlateau':
            self.scheduler.step(score)
        else:
            self.scheduler.step()
    
    def _save_checkpoint(self, epoch, save_best=False):
        """
        Saving checkpoints
        :param epoch: current epoch number
        :param save_best: if True, rename the saved checkpoint to 'model_best.pth'
        """
        arch = type(self.model).__name__
        state = {
            'arch': arch,
            'epoch': epoch,
            'state_dict': self.get_state_dict(self.model),
            'best_score': self.best_score
        }
        filename = os.path.join(self.checkpoint_dir, 'checkpoint_epoch{}.pth'.format(epoch))
        torch.save(state, filename)
        print("Saving checkpoint: {} ...".format(filename))
        if save_best:
            best_path = os.path.join(self.checkpoint_dir, 'model_best.pth')
            torch.save(state, best_path)
            print("Saving current best: model_best.pth ...")
    @staticmethod
    def get_state_dict(model):
        if isinstance(model, torch.nn.DataParallel):
            state_dict = model.module.state_dict()
        else:
            state_dict = model.state_dict()
        return state_dict
    
    def _resume_checkpoint(self, resume_path):
        resume_path = str(resume_path)
        print("Loading checkpoint: {} ...".format(resume_path))
        checkpoint = torch.load(resume_path, map_location=lambda storage, loc: storage)
        self.start_epoch = checkpoint['epoch'] + 1
        self.best_epoch = checkpoint['epoch']
        self.best_score = checkpoint['best_score']
        self.model.load_state_dict(checkpoint['state_dict'])

        print("Checkpoint loaded. Resume training from epoch {}".format(self.start_epoch))
    
    @staticmethod
    def _prepare_device(device):
        n_gpu_use = len(device)
        n_gpu = torch.cuda.device_count()
        if n_gpu_use > 0 and n_gpu == 0:
            print("Warning: There\'s no GPU available on this machine, training will be performed on CPU.")
            n_gpu_use = 0
        if n_gpu_use > n_gpu:
            print("Warning: The number of GPU\'s configured to use is {}, but only {} are available on this machine.".format(n_gpu_use, n_gpu))
            n_gpu_use = n_gpu
        list_ids = device
        device = torch.device('cuda:{}'.format(device[0]) if n_gpu_use > 0 else 'cpu')
        
        return device, list_ids
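A minimal sketch of how this Learning class might be wired up. The stand-in model, loss function and parameter values below are illustrative assumptions, not part of the original project:

import torch

model = torch.nn.Conv2d(3, 2, kernel_size=3, padding=1)   # stand-in for a real network
criterion = torch.nn.BCEWithLogitsLoss()                   # stand-in loss
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max')

learning = Learning(
    model=model,
    criterion=criterion,
    optimizer=optimizer,
    scheduler=scheduler,
    metric_ftns=[],                 # callables taking (output, target) and returning a score
    device=[0],                     # list of GPU ids; _prepare_device falls back to CPU if none exist
    num_epoch=50,
    grad_clipping=1.0,
    grad_accumulation_steps=4,      # optimizer steps every 4 batches
    early_stopping=10,              # stop after 10 epochs without a new best score
    validation_frequency=1,
    tensorboard=True,
    checkpoint_dir='checkpoints',
    resume_path=None,
)
# learning.train(train_dataloader)  # train_dataloader: a DataLoader yielding (data, target) pairs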
Example #2
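# Imports this snippet appears to rely on. AverageMeter, TensorboardWriter,
# clip_gradient, adjust_learning_rate and save_checkpoint are helpers from the
# surrounding project, so the commented import path is an assumption.
import time
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

# from utils import AverageMeter, TensorboardWriter, clip_gradient, adjust_learning_rate, save_checkpoint  # assumed

# train() below references a module-level `device`, presumably defined like this:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')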
class Trainer:
    """
    Training pipeline

    Parameters
    ----------
    num_epochs : int
        We should train the model for __ epochs

    start_epoch : int
        We should start training the model from __th epoch

    train_loader : DataLoader
        DataLoader for training data

    model : nn.Module
        Model

    model_name : str
        Name of the model

    loss_function : nn.Module
        Loss function (cross entropy)

    optimizer : optim.Optimizer
        Optimizer (Adam)

    lr_decay : float
        A factor in interval (0, 1) to multiply the learning rate with

    dataset_name : str
        Name of the dataset

    word_map : Dict[str, int]
        Word2id map

    grad_clip : float, optional
        Gradient clipping threshold

    print_freq : int
        Print training status every __ batches

    checkpoint_path : str, optional
        Path to the folder to save checkpoints

    checkpoint_basename : str, optional, default='checkpoint'
        Basename of the checkpoint

    tensorboard : bool, optional, default=False
        Enable tensorboard or not?

    log_dir : str, optional
        Path to the folder to save logs for tensorboard
    """
    def __init__(self,
                 num_epochs: int,
                 start_epoch: int,
                 train_loader: DataLoader,
                 model: nn.Module,
                 model_name: str,
                 loss_function: nn.Module,
                 optimizer,
                 lr_decay: float,
                 dataset_name: str,
                 word_map: Dict[str, int],
                 grad_clip: Optional[float] = None,
                 print_freq: int = 100,
                 checkpoint_path: Optional[str] = None,
                 checkpoint_basename: str = 'checkpoint',
                 tensorboard: bool = False,
                 log_dir: Optional[str] = None) -> None:
        self.num_epochs = num_epochs
        self.start_epoch = start_epoch
        self.train_loader = train_loader

        self.model = model
        self.model_name = model_name
        self.loss_function = loss_function
        self.optimizer = optimizer
        self.lr_decay = lr_decay

        self.dataset_name = dataset_name
        self.word_map = word_map
        self.print_freq = print_freq
        self.grad_clip = grad_clip

        self.checkpoint_path = checkpoint_path
        self.checkpoint_basename = checkpoint_basename

        # setup visualization writer instance
        self.writer = TensorboardWriter(log_dir, tensorboard)
        self.len_epoch = len(self.train_loader)

    def train(self, epoch: int) -> None:
        """
        Train an epoch

        Parameters
        ----------
        epoch : int
            Current number of epoch
        """
        self.model.train()  # training mode enables dropout

        batch_time = AverageMeter()  # forward prop. + back prop. time per batch
        data_time = AverageMeter()  # data loading time per batch
        losses = AverageMeter(tag='loss', writer=self.writer)  # cross entropy loss
        accs = AverageMeter(tag='acc', writer=self.writer)  # accuracies

        start = time.time()

        # batches
        for i, batch in enumerate(self.train_loader):
            data_time.update(time.time() - start)

            if self.model_name in ['han']:
                documents, sentences_per_document, words_per_sentence, labels = batch

                documents = documents.to(device)  # (batch_size, sentence_limit, word_limit)
                sentences_per_document = sentences_per_document.squeeze(1).to(device)  # (batch_size)
                words_per_sentence = words_per_sentence.to(device)  # (batch_size, sentence_limit)
                labels = labels.squeeze(1).to(device)  # (batch_size)

                # forward
                scores, _, _ = self.model(
                    documents, sentences_per_document, words_per_sentence
                )  # (n_documents, n_classes), (n_documents, max_doc_len_in_batch, max_sent_len_in_batch), (n_documents, max_doc_len_in_batch)

            else:
                sentences, words_per_sentence, labels = batch

                sentences = sentences.to(device)  # (batch_size, word_limit)
                words_per_sentence = words_per_sentence.squeeze(1).to(device)  # (batch_size)
                labels = labels.squeeze(1).to(device)  # (batch_size)

                # for torchtext
                # sentences = batch.text[0].to(device)  # (batch_size, word_limit)
                # words_per_sentence = batch.text[1].to(device)  # (batch_size)
                # labels = batch.label.to(device)  # (batch_size)

                scores = self.model(
                    sentences, words_per_sentence)  # (batch_size, n_classes)

            # calc loss
            loss = self.loss_function(scores, labels)  # scalar

            # backward
            self.optimizer.zero_grad()
            loss.backward()

            # clip gradients
            if self.grad_clip is not None:
                clip_gradient(self.optimizer, self.grad_clip)

            # update weights
            self.optimizer.step()

            # find accuracy
            _, predictions = scores.max(dim=1)  # (n_documents)
            correct_predictions = torch.eq(predictions, labels).sum().item()
            accuracy = correct_predictions / labels.size(0)

            # set step for tensorboard
            step = (epoch - 1) * self.len_epoch + i
            self.writer.set_step(step=step, mode='train')

            # keep track of metrics
            batch_time.update(time.time() - start)
            losses.update(loss.item(), labels.size(0))
            accs.update(accuracy, labels.size(0))

            start = time.time()

            # print training status
            if i % self.print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Accuracy {acc.val:.3f} ({acc.avg:.3f})'.format(
                        epoch,
                        i,
                        len(self.train_loader),
                        batch_time=batch_time,
                        data_time=data_time,
                        loss=losses,
                        acc=accs))

    def run_train(self):
        start = time.time()

        # epochs
        for epoch in range(self.start_epoch, self.num_epochs):
            # train an epoch
            self.train(epoch=epoch)

            # time per epoch
            epoch_time = time.time() - start
            print('Epoch: [{0}] finished, time consumed: {epoch_time:.3f}'.
                  format(epoch, epoch_time=epoch_time))

            # decay learning rate every epoch
            adjust_learning_rate(self.optimizer, self.lr_decay)

            # save checkpoint
            if self.checkpoint_path is not None:
                save_checkpoint(epoch=epoch,
                                model=self.model,
                                model_name=self.model_name,
                                optimizer=self.optimizer,
                                dataset_name=self.dataset_name,
                                word_map=self.word_map,
                                checkpoint_path=self.checkpoint_path,
                                checkpoint_basename=self.checkpoint_basename)

            start = time.time()
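A hedged sketch of how this Trainer might be instantiated. The model, word map and DataLoader below are toy stand-ins for the objects the surrounding project would actually build:

model = nn.Linear(16, 4)              # stand-in for a real text classifier
word_map = {'<pad>': 0, '<unk>': 1}   # toy word2id map
train_loader = DataLoader([])         # stand-in; real loader yields (sentences, words_per_sentence, labels)

trainer = Trainer(
    num_epochs=10,
    start_epoch=0,
    train_loader=train_loader,
    model=model,
    model_name='textcnn',             # any name other than 'han' takes the sentence-level branch
    loss_function=nn.CrossEntropyLoss(),
    optimizer=optim.Adam(model.parameters(), lr=1e-3),
    lr_decay=0.9,
    dataset_name='ag_news',
    word_map=word_map,
    grad_clip=5.0,
    print_freq=100,
    checkpoint_path='checkpoints',
    checkpoint_basename='checkpoint',
    tensorboard=True,
    log_dir='logs',
)
# trainer.run_train()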
Example #3
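# Imports this snippet appears to rely on. AverageMeter, TensorboardWriter, Metrics,
# accuracy, clip_gradient, adjust_learning_rate and save_checkpoint are helpers from
# the surrounding project, so the commented import paths are assumptions.
import time
from typing import Dict, Optional

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.nn.utils.rnn import pack_padded_sequence

# from utils import AverageMeter, TensorboardWriter, accuracy, clip_gradient, adjust_learning_rate, save_checkpoint  # assumed
# from metrics import Metrics  # assumed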
class Trainer:
    """
    Encoder-decoder pipeline. Teacher forcing is used during training and validation.

    Parameters
    ----------
    caption_model : str
        Type of the caption model

    epochs : int
        We should train the model for __ epochs

    device : torch.device
        Device to train on (GPU or CPU)

    word_map : Dict[str, int]
        Word2id map

    rev_word_map : Dict[int, str]
        Id2word map

    start_epoch : int
        We should start training the model from __th epoch

    epochs_since_improvement : int
        Number of epochs since last improvement in BLEU-4 score

    best_bleu4 : float
        Best BLEU-4 score until now

    train_loader : DataLoader
        DataLoader for training data

    val_loader : DataLoader
        DataLoader for validation data

    encoder : nn.Module
        Encoder (based on CNN)

    decoder : nn.Module
        Decoder (based on LSTM)

    encoder_optimizer : optim.Optimizer
        Optimizer for encoder (Adam) (only if fine-tuning the encoder)

    decoder_optimizer : optim.Optimizer
        Optimizer for decoder (Adam)

    loss_function : nn.Module
        Loss function (cross entropy)

    grad_clip : float
        Gradient clipping threshold

    tau : float
        Penalty term τ for doubly stochastic attention (from the paper Show, Attend and Tell)

    fine_tune_encoder : bool
        Fine-tune encoder or not

    tensorboard : bool, optional, default=False
        Enable tensorboard or not?

    log_dir : str, optional
        Path to the folder to save logs for tensorboard
    """
    def __init__(
        self,
        caption_model: str,
        epochs: int,
        device: torch.device,
        word_map: Dict[str, int],
        rev_word_map: Dict[int, str],
        start_epoch: int,
        epochs_since_improvement: int,
        best_bleu4: float,
        train_loader: DataLoader,
        val_loader: DataLoader,
        encoder: nn.Module,
        decoder: nn.Module,
        encoder_optimizer: optim.Optimizer,
        decoder_optimizer: optim.Optimizer,
        loss_function: nn.Module,
        grad_clip: float,
        tau: float,
        fine_tune_encoder: bool,
        tensorboard: bool = False,
        log_dir: Optional[str] = None
    ) -> None:
        self.device = device  # GPU / CPU

        self.caption_model = caption_model
        self.epochs = epochs
        self.word_map = word_map
        self.rev_word_map = rev_word_map

        self.start_epoch = start_epoch
        self.epochs_since_improvement = epochs_since_improvement
        self.best_bleu4 = best_bleu4

        self.train_loader = train_loader
        self.val_loader = val_loader
        self.encoder = encoder
        self.decoder = decoder
        self.encoder_optimizer = encoder_optimizer
        self.decoder_optimizer = decoder_optimizer
        self.loss_function = loss_function

        self.tau = tau
        self.grad_clip = grad_clip
        self.fine_tune_encoder = fine_tune_encoder

        self.print_freq = 100  # print training/validation stats every __ batches
        # setup visualization writer instance
        self.writer = TensorboardWriter(log_dir, tensorboard)
        self.len_epoch = len(self.train_loader)

    def train(self, epoch: int) -> None:
        """
        Train an epoch

        Parameters
        ----------
        epoch : int
            Current number of epoch
        """
        self.decoder.train()  # train mode (dropout and batchnorm are used)
        self.encoder.train()

        batch_time = AverageMeter()  # forward prop. + back prop. time
        data_time = AverageMeter()  # data loading time
        losses = AverageMeter(tag='loss', writer=self.writer)  # loss (per word decoded)
        top5accs = AverageMeter(tag='top5acc', writer=self.writer)  # top5 accuracy

        start = time.time()

        # batches
        for i, (imgs, caps, caplens) in enumerate(self.train_loader):
            data_time.update(time.time() - start)

            # Move to GPU, if available
            imgs = imgs.to(self.device)
            caps = caps.to(self.device)
            caplens = caplens.to(self.device)

            # forward encoder
            imgs = self.encoder(imgs)

            # forward decoder
            if self.caption_model == 'att2all':
                scores, caps_sorted, decode_lengths, alphas, sort_ind = self.decoder(imgs, caps, caplens)
            else:
                scores, caps_sorted, decode_lengths, sort_ind = self.decoder(imgs, caps, caplens)

            # since we decoded starting with <start>, the targets are all words after <start>, up to <end>
            targets = caps_sorted[:, 1:]

            # remove timesteps that we didn't decode at, or are pads
            # pack_padded_sequence is an easy trick to do this
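            # e.g. with decode_lengths = [3, 2], packing turns the (2, max_len, vocab_size)
            # scores into a flat (3 + 2, vocab_size) tensor; because scores and targets are
            # packed the same way, each row still lines up with its target word while every
            # padded position is dropped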
            scores = pack_padded_sequence(scores, decode_lengths, batch_first=True)[0]
            targets = pack_padded_sequence(targets, decode_lengths, batch_first=True)[0]

            # calc loss
            loss = self.loss_function(scores, targets)

            # doubly stochastic attention regularization (in paper: show, attend and tell)
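            # the penalty tau * sum_i (1 - sum_t alpha_{t,i})^2 pushes the attention weights
            # at each spatial location to sum to roughly 1 over all decode steps, so the model
            # attends to every part of the image at some point during decoding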
            if self.caption_model == 'att2all':
                loss += self.tau * ((1. - alphas.sum(dim = 1)) ** 2).mean()

            # clear gradient of last batch
            self.decoder_optimizer.zero_grad()
            if self.encoder_optimizer is not None:
                self.encoder_optimizer.zero_grad()

            # backward
            loss.backward()

            # clip gradients
            if self.grad_clip is not None:
                clip_gradient(self.decoder_optimizer, self.grad_clip)
                if self.encoder_optimizer is not None:
                    clip_gradient(self.encoder_optimizer, self.grad_clip)

            # update weights
            self.decoder_optimizer.step()
            if self.encoder_optimizer is not None:
                self.encoder_optimizer.step()

            # set step for tensorboard
            step = (epoch - 1) * self.len_epoch + i
            self.writer.set_step(step=step, mode='train')

            # keep track of metrics
            top5 = accuracy(scores, targets, 5)
            losses.update(loss.item(), sum(decode_lengths))
            top5accs.update(top5, sum(decode_lengths))
            batch_time.update(time.time() - start)

            start = time.time()

            # print status
            if i % self.print_freq == 0:
                print(
                    'Epoch: [{0}][{1}/{2}]\t'
                    'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                    'Data Load Time {data_time.val:.3f} ({data_time.avg:.3f})\t'
                    'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                    'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})'.format(
                        epoch, i, len(self.train_loader),
                        batch_time = batch_time,
                        data_time = data_time,
                        loss = losses,
                        top5 = top5accs
                    )
                )

    def validate(self) -> float:
        """
        Validate an epoch.

        Returns
        -------
        bleu4 : float
            BLEU-4 score
        """
        self.decoder.eval()  # eval mode (no dropout or batchnorm)
        if self.encoder is not None:
            self.encoder.eval()

        batch_time = AverageMeter()
        losses = AverageMeter()
        top5accs = AverageMeter()

        start = time.time()

        ground_truth = list()  # ground_truth (true captions) for calculating BLEU-4 score
        prediction = list()  # prediction (predicted captions)

        # explicitly disable gradient calculation to avoid CUDA memory error
        # solves the issue #57
        with torch.no_grad():
            # Batches
            for i, (imgs, caps, caplens, allcaps) in enumerate(self.val_loader):

                # move to device, if available
                imgs = imgs.to(self.device)
                caps = caps.to(self.device)
                caplens = caplens.to(self.device)

                # forward encoder
                if self.encoder is not None:
                    imgs = self.encoder(imgs)

                # forward decoder
                if self.caption_model == 'att2all':
                    scores, caps_sorted, decode_lengths, alphas, sort_ind = self.decoder(imgs, caps, caplens)
                else:
                    scores, caps_sorted, decode_lengths, sort_ind = self.decoder(imgs, caps, caplens)

                # since we decoded starting with <start>, the targets are all words after <start>, up to <end>
                targets = caps_sorted[:, 1:]

                # remove timesteps that we didn't decode at, or are pads
                # pack_padded_sequence is an easy trick to do this
                scores_copy = scores.clone()
                scores = pack_padded_sequence(scores, decode_lengths, batch_first = True)[0]
                targets = pack_padded_sequence(targets, decode_lengths, batch_first = True)[0]

                # calc loss
                loss = self.loss_function(scores, targets)

                # doubly stochastic attention regularization (in paper: show, attend and tell)
                if self.caption_model == 'att2all':
                    loss += self.tau * ((1. - alphas.sum(dim = 1)) ** 2).mean()

                # keep track of metrics
                losses.update(loss.item(), sum(decode_lengths))
                top5 = accuracy(scores, targets, 5)
                top5accs.update(top5, sum(decode_lengths))
                batch_time.update(time.time() - start)

                start = time.time()

                if i % self.print_freq == 0:
                    print('Validation: [{0}/{1}]\t'
                        'Batch Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                        'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                        'Top-5 Accuracy {top5.val:.3f} ({top5.avg:.3f})\t'.format(i, len(self.val_loader),
                                                                                    batch_time = batch_time,
                                                                                    loss = losses,
                                                                                    top5 = top5accs)
                    )

                # store ground truth captions and predicted captions of each image
                # for n images, each of them has one prediction and multiple ground truths (a, b, c...):
                # prediction = [ [hyp1], [hyp2], ..., [hypn] ]
                # ground_truth = [ [ [ref1a], [ref1b], [ref1c] ], ..., [ [refna], [refnb] ] ]

                # ground truth
                allcaps = allcaps[sort_ind]  # because images were sorted in the decoder
                for j in range(allcaps.shape[0]):
                    img_caps = allcaps[j].tolist()
                    img_captions = list(
                        map(
                            lambda c: [w for w in c if w not in {self.word_map['<start>'], self.word_map['<pad>']}],
                            img_caps
                        )
                    )  # remove <start> and pads
                    ground_truth.append(img_captions)

                # prediction
                _, preds = torch.max(scores_copy, dim = 2)
                preds = preds.tolist()
                temp_preds = list()
                for j, p in enumerate(preds):
                    temp_preds.append(preds[j][:decode_lengths[j]])  # remove pads
                preds = temp_preds
                prediction.extend(preds)

                assert len(ground_truth) == len(prediction)

            # calc BLEU-4 and CIDEr score
            metrics = Metrics(ground_truth, prediction, self.rev_word_map)
            bleu4 = metrics.belu[3]  # BLEU-4
            cider = metrics.cider  # CIDEr

            print(
                '\n * LOSS - {loss.avg:.3f}, TOP-5 ACCURACY - {top5.avg:.3f}, BLEU-4 - {bleu}, CIDEr - {cider}\n'.format(
                    loss = losses,
                    top5 = top5accs,
                    bleu = bleu4,
                    cider = cider
                )
            )

        return bleu4

    def run_train(self) -> None:
        # epochs
        for epoch in range(self.start_epoch, self.epochs):

            # decay learning rate if there is no improvement for 8 consecutive epochs
            # terminate training if there is no improvement for 20 consecutive epochs
            if self.epochs_since_improvement == 20:
                break
            if self.epochs_since_improvement > 0 and self.epochs_since_improvement % 8 == 0:
                adjust_learning_rate(self.decoder_optimizer, 0.8)
                if self.fine_tune_encoder:
                    adjust_learning_rate(self.encoder_optimizer, 0.8)

            # train an epoch
            self.train(epoch = epoch)

            # validate an epoch
            recent_bleu4 = self.validate()

            # epochs num since last improvement
            is_best = recent_bleu4 > self.best_bleu4
            self.best_bleu4 = max(recent_bleu4, self.best_bleu4)
            if not is_best:
                self.epochs_since_improvement += 1
                print("\nEpochs since last improvement: %d\n" % (self.epochs_since_improvement,))
            else:
                self.epochs_since_improvement = 0

            # save checkpoint
            save_checkpoint(
                epoch = epoch,
                epochs_since_improvement = self.epochs_since_improvement,
                encoder = self.encoder,
                decoder = self.decoder,
                encoder_optimizer = self.encoder_optimizer,
                decoder_optimizer = self.decoder_optimizer,
                caption_model = self.caption_model,
                bleu4 = recent_bleu4,
                is_best = is_best
            )
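The LR-decay and early-termination logic in run_train can be isolated into a small helper. The following is a hedged standalone sketch of the same schedule (the function name and signature are illustrative, not part of the original project):

def patience_schedule(bleu4_per_epoch, decay_every=8, stop_after=20):
    """Replay the run_train schedule on a list of per-epoch BLEU-4 scores and
    return the (epoch, action) events it would trigger."""
    best = float('-inf')
    epochs_since_improvement = 0
    events = []
    for epoch, score in enumerate(bleu4_per_epoch):
        if epochs_since_improvement == stop_after:
            events.append((epoch, 'stop'))
            break
        if epochs_since_improvement > 0 and epochs_since_improvement % decay_every == 0:
            events.append((epoch, 'decay_lr'))  # adjust_learning_rate(optimizer, 0.8) in run_train
        if score > best:
            best = score
            epochs_since_improvement = 0
        else:
            epochs_since_improvement += 1
    return events

# e.g. patience_schedule([0.10, 0.12, 0.11, 0.11]) -> []  (no decay yet, training continues)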