Exemple #1
0
def train(args, train_dataset, model, tokenizer):
    """Fine-tune ``model`` on ``train_dataset``; return ``(global_step, mean_loss)``.

    Several fields of ``args`` are overwritten here: ``train_batch_size``,
    ``warmup_steps`` (1% of the total optimization steps), ``logging_steps``
    and ``save_steps`` (both one dataloader length), ``current_epoch`` and,
    when ``max_steps`` is positive, ``num_train_epochs``.

    Args:
        args: Run configuration namespace (hyper-parameters, device, paths).
        train_dataset: Dataset whose batches unpack as
            ``(input_ids, attention_mask, token_type_ids, labels)``.
        model: Model to train; wrapped in ``DataParallelModel`` when
            ``args.n_gpu > 1``.
        tokenizer: Only forwarded to ``evaluate``.

    Returns:
        Tuple of the number of optimizer updates performed and the
        accumulated loss divided by that number (the raw accumulated loss
        when no update was performed).

    Raises:
        ImportError: If ``args.fp16`` is set and apex is not installed.
    """
    # Only the master process (or a non-distributed run) owns a TensorBoard
    # writer; every later use is guarded so other ranks do not hit a NameError.
    is_master = args.local_rank in [-1, 0]
    tb_writer = SummaryWriter() if is_master else None

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
    args.warmup_steps = t_total // 100

    # Prepare optimizer and schedule (linear warmup and decay)
    optimizer_grouped_parameters = get_param_groups(args, model)
    optimizer = RAdam(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            ) from err
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = DataParallelModel(model)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    # Log and checkpoint once per pass over the dataloader.
    args.logging_steps = len(train_dataloader)
    args.save_steps = args.logging_steps

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc="Epoch")
    set_seed(args)

    # The criterion holds no state, so build it once instead of per batch.
    loss_fct = DataParallelCriterion(CrossEntropyLoss())

    for epoch in train_iterator:
        args.current_epoch = epoch
        epoch_iterator = tqdm(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                # XLM and RoBERTa don't use segment_ids
                'token_type_ids':
                batch[2] if args.model_type in ['bert', 'xlnet'] else None,
            }
            outputs = model(**inputs)
            # Keep the first element of each entry of the model output.
            # NOTE(review): this indexing presumes the list-of-per-replica
            # outputs produced by DataParallelModel; confirm the behaviour
            # for args.n_gpu <= 1, where the model is not wrapped.
            outputs = [replica_output[0] for replica_output in outputs]

            loss = loss_fct(outputs, batch[3])

            if args.n_gpu > 1:
                # mean() to average on multi-gpu parallel training
                loss = loss.mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not
                    # average well. local_rank == -1 implies tb_writer exists.
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    if tb_writer is not None:
                        tb_writer.add_scalar('lr',
                                             scheduler.get_lr()[0],
                                             global_step)
                        tb_writer.add_scalar(
                            'loss',
                            (tr_loss - logging_loss) / args.logging_steps,
                            global_step)
                    logging_loss = tr_loss

                # Only the master process writes checkpoints so that several
                # ranks never write the same directory concurrently.
                if (is_master and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    os.makedirs(output_dir, exist_ok=True)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(
                        model, 'module') else model
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            # NOTE(review): `>` lets one update beyond max_steps run before
            # stopping; kept as-is, confirm whether `>=` was intended.
            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if tb_writer is not None:
        tb_writer.close()

    # Guard against a run in which no optimizer update was ever performed.
    return global_step, tr_loss / max(global_step, 1)
Exemple #2
0
class LMTrainer:
    """Masked-language-model trainer with periodic validation and checkpointing.

    The model is moved to CUDA when available and wrapped in
    ``DataParallelModel`` (with a matching ``DataParallelCriterion``) when
    more than one GPU is visible.
    """

    def __init__(self,
                 model,
                 mask_prob: float = 0.15,
                 clip: int = 1,
                 optimizer=None):
        """Set up device placement, the loss and the loss bookkeeping.

        Args:
            model: Language model exposing ``text_processor.pad_token_id()``.
            mask_prob: Probability handed to ``mask_text``.
            clip: Maximum gradient norm used before each optimizer step.
            optimizer: Optimizer to step; when ``None`` gradients are
                computed but no update is applied.
        """
        self.model = model
        self.clip = clip
        self.optimizer = optimizer

        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")
        self.model = self.model.to(self.device)

        self.mask_prob = mask_prob
        self.criterion = nn.NLLLoss(
            ignore_index=model.text_processor.pad_token_id())

        num_gpu = torch.cuda.device_count()
        if num_gpu > 1:
            print("Let's use", num_gpu, "GPUs!")
            self.model = DataParallelModel(self.model)
            self.criterion = DataParallelCriterion(self.criterion)

        self.best_dev_loss = float("inf")
        self.best_train_loss = float("inf")
        self.last_train_loss = float("inf")

    def _unwrapped_model(self):
        """Return the underlying model, stripping a parallel wrapper if any."""
        return self.model.module if hasattr(self.model,
                                            "module") else self.model

    def _save_checkpoint(self, path: str):
        """Save the unwrapped model to ``path`` and pickle the optimizer next to it."""
        self._unwrapped_model().save(path)
        with open(os.path.join(path, "optim"), "wb") as fp:
            pickle.dump(self.optimizer, fp)

    def train_epoch(self, data_iter: data_utils.DataLoader,
                    dev_data_iter: data_utils.DataLoader, saving_path: str,
                    step: int):
        """Run one pass over ``data_iter`` with logging and validation.

        Progress is printed every 50 steps and the dev set is evaluated
        every 500 steps and once more at the end of the epoch. When the
        epoch's per-token training loss improves on the best seen so far,
        the model is saved to ``saving_path + ".latest"``.

        Args:
            data_iter: Training batches with "pad_mask", "texts" and "langs".
            dev_data_iter: Validation batches with the same keys.
            saving_path: Directory used for checkpoints.
            step: Number of optimizer steps taken before this epoch.

        Returns:
            The updated optimizer step count.
        """
        start = time.time()
        total_tokens, total_loss, tokens, cur_loss = 0, 0, 0, 0
        model = self._unwrapped_model()

        for batch in data_iter:
            if self.optimizer is not None:
                self.optimizer.zero_grad()
            mask, target, texts = mask_text(self.mask_prob, batch["pad_mask"],
                                            batch["texts"],
                                            model.text_processor)
            try:
                predictions = self.model(mask=mask,
                                         texts=texts,
                                         pads=batch["pad_mask"],
                                         langs=batch["langs"])
                ntokens = target.size(0)

                if ntokens == 0:  # Nothing to predict!
                    continue

                loss = self.criterion(predictions, target).mean()
                loss.backward()

                unmask_text(mask, target, texts)

                if self.optimizer is not None:
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.clip)

                    self.optimizer.step()
                    step += 1

                # Weight the mean loss by token count so that epoch totals
                # can be turned back into a per-token average.
                batch_loss = float(loss.data) * ntokens
                total_loss += batch_loss
                cur_loss += batch_loss
                total_tokens += ntokens
                tokens += ntokens

                if step % 50 == 0:
                    elapsed = time.time() - start
                    print(
                        datetime.datetime.now(),
                        "Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                        (step, cur_loss / tokens, tokens / elapsed))

                    if step % 500 == 0:
                        self.validate_and_save(saving_path, dev_data_iter)

                    start, tokens, cur_loss = time.time(), 0, 0
            except RuntimeError as err:
                # Best effort: skip the offending batch (e.g. out of memory)
                # but report what went wrong instead of hiding it.
                print("Problem with batch item", texts.size(), repr(err))
                torch.cuda.empty_cache()

        if total_tokens > 0:
            current_loss = total_loss / total_tokens
        else:
            # Empty loader or every batch skipped: there is no loss to report.
            print("No tokens were trained on in this epoch")
            current_loss = float("inf")
        print("Total loss in this epoch: %f" % current_loss)
        if current_loss < self.best_train_loss:
            self.best_train_loss = current_loss
            self._save_checkpoint(saving_path + ".latest")
        self.last_train_loss = current_loss

        self.validate_and_save(saving_path, dev_data_iter)
        return step

    def validate_and_save(self, saving_path, dev_data_iter):
        """Evaluate on the dev set and checkpoint when the dev loss improves.

        The model is put in eval mode for the evaluation and always switched
        back to train mode afterwards, even if the evaluation raises.

        Args:
            saving_path: Directory receiving the best-dev checkpoint.
            dev_data_iter: Validation batches.
        """
        model = self._unwrapped_model()
        model.eval()
        try:
            with torch.no_grad():
                total_dev_loss, total_dev_tokens = 0, 0
                for batch in dev_data_iter:
                    # Clone so that masking does not alter the dev batch.
                    mask, target, texts = mask_text(self.mask_prob,
                                                    batch["pad_mask"],
                                                    batch["texts"].clone(),
                                                    model.text_processor)
                    predictions = self.model(mask=mask,
                                             texts=texts,
                                             pads=batch["pad_mask"],
                                             langs=batch["langs"])
                    ntokens = target.size(0)

                    if ntokens == 0:  # Nothing to predict!
                        continue
                    loss = self.criterion(predictions,
                                          target).mean().data * ntokens
                    total_dev_loss += float(loss)
                    total_dev_tokens += ntokens

                if total_dev_tokens == 0:
                    # Nothing was evaluated, so there is no dev loss to compare.
                    print("No dev tokens to evaluate; skipping dev checkpoint")
                    return

                dev_loss = total_dev_loss / total_dev_tokens
                print("Current dev loss", dev_loss)
                if self.best_dev_loss > float(dev_loss):
                    self.best_dev_loss = float(dev_loss)
                    print("saving best dev loss", self.best_dev_loss)
                    self._save_checkpoint(saving_path)
        finally:
            model.train()

    @staticmethod
    def config_dropout(model, dropout):
        """Set hidden and attention dropout on ``model.encoder.config``."""
        model.encoder.config.hidden_dropout_prob = dropout
        model.encoder.config.attention_probs_dropout_prob = dropout

    @staticmethod
    def train(options):
        """Build model, data and optimizer from ``options`` and train.

        Epochs are run until the optimizer step count exceeds
        ``options.step``, or until an epoch makes no progress at all.

        Args:
            options: Parsed options (paths, model size, dropout, batch size,
                learning rate, warmup, step budget, ...).
        """
        os.makedirs(options.model_path, exist_ok=True)

        text_processor = TextProcessor(options.tokenizer_path)

        lm_class = ReformerLM if options.reformer else LM
        if options.pretrained_path is None:
            lm = lm_class(text_processor=text_processor,
                          size=options.model_size)
        else:
            lm = lm_class.load(options.pretrained_path)

        if options.reformer:
            lm.config.hidden_dropout_prob = options.dropout
            lm.config.local_attention_probs_dropout_prob = options.dropout
            lm.config.lsh_attention_probs_dropout_prob = options.dropout
        else:
            LMTrainer.config_dropout(lm, options.dropout)

        train_data = dataset.TextDataset(save_cache_dir=options.train_path,
                                         max_cache_size=options.cache_size)
        dev_data = dataset.TextDataset(save_cache_dir=options.dev_path,
                                       max_cache_size=options.cache_size,
                                       load_all=True)

        if options.continue_train:
            # NOTE: unpickling can execute arbitrary code; only resume from
            # checkpoints that come from a trusted source.
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            optimizer = build_optimizer(lm, options.learning_rate,
                                        options.warmup)

        trainer = LMTrainer(model=lm,
                            mask_prob=options.mask_prob,
                            optimizer=optimizer,
                            clip=options.clip)

        collator = dataset.TextCollator(pad_idx=text_processor.pad_token_id())
        train_sampler, dev_sampler = None, None

        pin_memory = torch.cuda.is_available()
        loader = data_utils.DataLoader(train_data,
                                       batch_size=options.batch,
                                       shuffle=False,
                                       pin_memory=pin_memory,
                                       collate_fn=collator,
                                       sampler=train_sampler)
        dev_loader = data_utils.DataLoader(dev_data,
                                           batch_size=options.batch,
                                           shuffle=False,
                                           pin_memory=pin_memory,
                                           collate_fn=collator,
                                           sampler=dev_sampler)

        step, epoch_num = 0, 1
        while step <= options.step:
            print("train epoch", epoch_num)
            previous_step = step
            step = trainer.train_epoch(data_iter=loader,
                                       dev_data_iter=dev_loader,
                                       saving_path=options.model_path,
                                       step=step)
            epoch_num += 1
            if step == previous_step:
                # Without this guard an empty (or always-failing) training
                # set would keep the loop spinning forever.
                print("No optimizer step was taken in this epoch; "
                      "stopping training")
                break
Exemple #3
0
class BERTTrainer:
    """
    BERTTrainer make the pretrained BERT model with two LM training method.
        1. Masked Language Model : 3.3.1 Task #1: Masked LM
        2. Next Sentence prediction : 3.3.2 Task #2: Next Sentence Prediction
    """
    def __init__(self,
                 model,
                 vocab_size,
                 train_dataloader,
                 test_dataloader=None,
                 lr: float = 1e-4,
                 betas=(0.9, 0.999),
                 weight_decay: float = 0.01,
                 warmup_steps=10000,
                 with_cuda: bool = True,
                 cuda_devices=None,
                 log_freq: int = 10,
                 include_next=False,
                 include_vision=True,
                 total_epochs=1):
        """
        :param model: model to train; must expose a ``bert`` attribute
        :param vocab_size: total word vocab size (not used by the trainer)
        :param train_dataloader: train dataset data loader
        :param test_dataloader: test dataset data loader [can be None]
        :param lr: learning rate of optimizer
        :param betas: Adamax optimizer betas
        :param weight_decay: Adamax optimizer weight decay param
        :param warmup_steps: warmup steps handed to ScheduledOptim
        :param with_cuda: training with cuda
        :param cuda_devices: accepted for compatibility, currently unused
        :param log_freq: logging frequency of the batch iteration
        :param include_next: also train/score the "is next" prediction head
        :param include_vision: feed the ``feature_all`` tensor to the model
        :param total_epochs: number of epochs, used for progress reporting
        """

        # Setup cuda device for BERT training, argument -c, --cuda should be true
        cuda_condition = torch.cuda.is_available() and with_cuda
        self.device = torch.device("cuda:0" if cuda_condition else "cpu")

        n_gpu = torch.cuda.device_count()
        print("device", self.device, "n_gpu", n_gpu)

        # Initialize the BERT Language Model, with BERT model
        self.model = model.to(self.device)
        self.bert = self.model.bert
        self.padding_idx = 0
        self.include_next = include_next
        self.include_vision = include_vision

        # Model and criterion are wrapped together, so one flag describes both.
        self._parallel = bool(with_cuda and n_gpu > 1)

        # Distributed GPU training if CUDA can detect more than 1 GPU
        if self._parallel:
            print("Using %d GPUS for BERT" % n_gpu)
            self.model = DataParallelModel(self.model,
                                           device_ids=range(n_gpu))

        # Setting the train and test data loader
        self.train_data = train_dataloader
        self.test_data = test_dataloader

        # Setting the Adamax optimizer with hyper-param
        self.optim = optim.Adamax(self.model.parameters(),
                                  lr=lr,
                                  betas=betas,
                                  weight_decay=weight_decay)
        # self.bert is the unwrapped sub-module whether or not the model
        # has been wrapped for multi-GPU training.
        self.optim_schedule = ScheduledOptim(
            self.optim,
            self.bert.transformer_hidden_size,
            n_warmup_steps=warmup_steps)

        # Using Negative Log Likelihood Loss function for predicting the masked_token
        self.criterion = nn.NLLLoss(ignore_index=0)
        if self._parallel:
            self.criterion = DataParallelCriterion(self.criterion,
                                                   device_ids=range(n_gpu))

        self.log_freq = log_freq
        self.total_iters = total_epochs * len(train_dataloader)

        print("Total Parameters:",
              sum([p.nelement() for p in self.model.parameters()]))

    def train(self, epoch):
        """Run one training epoch over the train data loader."""
        self.iteration(epoch, self.train_data)

    def test(self, epoch):
        """Run one evaluation pass over the test data loader."""
        self.iteration(epoch, self.test_data, train=False)

    def _criterion_loss(self, outputs, target):
        """Apply the criterion to per-replica ``outputs`` against ``target``.

        NOTE(review): assumes DataParallelCriterion takes a list with one
        tuple of inputs per replica and scatters ``target`` itself; confirm
        against the parallel helper in use.
        """
        if self._parallel:
            return self.criterion([(out, ) for out in outputs], target)
        return self.criterion(outputs[0], target)

    def iteration(self, epoch, data_loader, train=True):
        """
        loop over the data_loader for training or testing
        if on train status, backward operation is activated
        :param epoch: current epoch index
        :param data_loader: torch.utils.data.DataLoader for iteration
        :param train: boolean value of is train or test
        :return: None
        """
        str_code = "train" if train else "test"

        # Setting the tqdm progress bar
        data_iter = tqdm.tqdm(enumerate(data_loader),
                              desc="EP_%s:%d" % (str_code, epoch),
                              total=len(data_loader),
                              bar_format="{l_bar}{r_bar}",
                              disable=True)

        avg_loss = 0.0
        total_correct = 0
        total_element = 0

        for i, data in data_iter:
            # 0. prepare the text sequence tensor
            seq_tensor = data['masked_text_seq']
            labels = data['masked_text_label']

            # Index of the first padding token in each row; rows where that
            # index is 0 are treated as full length.
            seq_lengths = np.argmax(seq_tensor == self.padding_idx, axis=1)
            seq_lengths[seq_lengths == 0] = seq_tensor.shape[1]  # Full length

            # Sort sequences by lengths (descending)
            seq_lengths, perm_idx = seq_lengths.sort(0, True)
            sorted_tensor = seq_tensor[perm_idx]
            mask = (sorted_tensor == self.padding_idx)[:, :seq_lengths[0]]

            f_t_all = data['feature_all'][perm_idx]
            isnext = data["isnext"][perm_idx].to(self.device)
            labels = labels[perm_idx][:, :seq_lengths[0]].to(self.device)

            features = f_t_all.to(self.device) if self.include_vision else None

            # No autograd graph is needed when only evaluating.
            with torch.set_grad_enabled(train):
                # 1. forward the next_sentence_prediction and masked_lm model
                output = self.model.forward(sorted_tensor.to(self.device),
                                            mask.to(self.device),
                                            seq_lengths.to(self.device),
                                            features)
                print("You got %d outputs" % (len(output)))

                if self._parallel:
                    # NOTE(review): assumes DataParallelModel returns one
                    # (next_sent, mask_lm) pair per replica -- confirm.
                    next_outputs, lm_outputs = zip(*output)
                else:
                    next_sent, mask_lm = output
                    next_outputs, lm_outputs = (next_sent, ), (mask_lm, )

                if self.include_vision:
                    print("vision test shape is %s " %
                          (tuple(next_outputs[0].shape), ))
                    print("lm test shape is %s " %
                          (tuple(lm_outputs[0].shape), ))

                # 2-1. NLL(negative log likelihood) loss of is_next classification result
                next_loss = 0
                if self.include_vision and self.include_next:
                    next_loss = self._criterion_loss(next_outputs, isnext)

                # 2-2. NLLLoss of predicting masked token word
                mask_loss = self._criterion_loss(
                    [out.transpose(1, 2) for out in lm_outputs], labels)

                # 2-3. Adding next_loss and mask_loss : 3.4 Pre-training Procedure
                loss = (next_loss + mask_loss).mean()

            # 3. backward and optimization only in train
            if train:
                self.optim_schedule.zero_grad()
                loss.backward()
                self.optim_schedule.step_and_update_lr()

            # next vision prediction accuracy
            if self.include_next:
                # Gather replica outputs on one device before comparing.
                next_scores = torch.cat(
                    [out.to(self.device) for out in next_outputs], dim=0)
                correct = next_scores.argmax(dim=-1).eq(isnext).sum().item()
                total_correct += correct
                total_element += isnext.nelement()
            avg_loss += loss.item()

            # Per-iteration summary; the tqdm write that consumed it is
            # currently disabled.
            post_fix = {
                "epoch": epoch,
                "iter": i,
                "avg_loss": avg_loss / (i + 1),
                "loss": loss.item()
            }
            if self.include_next:
                post_fix["avg_acc"] = total_correct / total_element * 100

            if i % 100 == 0:
                print("\n")
                print("PROGRESS: {}%".format(
                    round((epoch * len(data_loader) + i) * 100 /
                          self.total_iters, 4)))
                print("EVALERR: {}%".format(avg_loss / (i + 1)))

    def save(self, epoch, file_path="pretrained_models/addbert_trained.model"):
        """
        Saving the current BERT model on file_path
        :param epoch: current epoch number
        :param file_path: model output path which gonna be file_path+"ep%d" % epoch
        :return: final_output_path
        """
        output_path = file_path + ".ep%d" % epoch
        # Serialize from the CPU, then move the module back to its device.
        torch.save(self.bert.cpu(), output_path)
        self.bert.to(self.device)
        print("EP:%d Model Saved on:" % epoch, output_path)
        return output_path
Exemple #4
0
class Trainer:
    """
    Trainer for a two-label BERT sequence classifier.

    Splits the data 80/20 into train and validation, weights the second
    class by the ratio of examples with target < 0.5 to those with
    target >= 0.5, and trains with early stopping on the validation F-score.
    """
    def __init__(self, cfg: Namespace, data: Dataset):
        """
        Args:
            cfg:  configuration
            data:  train dataset
        """
        self.cfg = cfg
        self.train, self.valid = data.split(0.8)
        RATING_FIELD.build_vocab(self.train)

        self.device = torch.device(
            'cuda') if torch.cuda.is_available() else torch.device('cpu')  # pylint: disable=no-member
        self.batch_size = cfg.batch_size
        if torch.cuda.is_available():
            # cfg.batch_size is treated as per-GPU; scale by the GPU count.
            self.batch_size *= torch.cuda.device_count()

        self.trn_itr = BucketIterator(
            self.train,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=True,
            train=True,
            sort_within_batch=True,
            sort_key=lambda exam: -len(exam.comment_text))
        self.vld_itr = BucketIterator(
            self.valid,
            device=self.device,
            batch_size=self.batch_size,
            shuffle=False,
            train=False,
            sort_within_batch=True,
            sort_key=lambda exam: -len(exam.comment_text))

        # Update the progress description more often for small validation sets.
        self.log_step = 1000
        if len(self.vld_itr) < 100:
            self.log_step = 10
        elif len(self.vld_itr) < 1000:
            self.log_step = 100

        bert_path = cfg.bert_path if cfg.bert_path else 'bert-base-cased'
        self.model = BertForSequenceClassification.from_pretrained(
            bert_path, num_labels=2)

        num_neg = sum(1 for exam in self.train.examples if exam.target < 0.5)
        num_pos = sum(1 for exam in self.train.examples if exam.target >= 0.5)
        # Without any positive example the ratio is undefined; fall back to
        # an unweighted loss instead of dividing by zero.
        pos_weight = num_neg / num_pos if num_pos > 0 else 1.0
        pos_wgt_tensor = torch.tensor([1.0, pos_weight], device=self.device)  # pylint: disable=not-callable
        self.criterion = nn.CrossEntropyLoss(weight=pos_wgt_tensor)
        if torch.cuda.is_available():
            self.model = DataParallelModel(self.model.cuda())
            self.criterion = DataParallelCriterion(self.criterion)
        self.optimizer = optim.Adam(self.model.parameters(), cfg.learning_rate)

    def run(self):
        """
        Train for up to ``cfg.epoch`` epochs, saving the model state to
        ``cfg.model_path`` whenever the validation F-score improves and
        stopping early after ``cfg.patience`` epochs without improvement.
        """
        max_f_score = -9e10
        max_epoch = -1
        for epoch in range(self.cfg.epoch):
            train_loss = self._train_epoch(epoch)
            metrics = self._evaluate(epoch)
            max_f_score_str = f' < {max_f_score:.2f}'
            if metrics['f_score'] > max_f_score:
                max_f_score_str = ' is max'
                max_f_score = metrics['f_score']
                max_epoch = epoch
                torch.save(self.model.state_dict(), self.cfg.model_path)
            logging.info('EPOCH[%d]: train loss: %.6f, valid loss: %.6f, acc: %.2f,' \
                         ' F: %.2f%s', epoch, train_loss, metrics['loss'],
                         metrics['accuracy'], metrics['f_score'], max_f_score_str)
            if (epoch - max_epoch) >= self.cfg.patience:
                logging.info('early stopping...')
                break
        logging.info('epoch: %d, f-score: %.2f', max_epoch, max_f_score)

    @staticmethod
    def _mean(values: List[float]) -> float:
        """Return the arithmetic mean of ``values`` (NaN when empty)."""
        return sum(values) / len(values) if values else float('nan')

    @staticmethod
    def _predict_labels(logits) -> List[int]:
        """Return the index of the larger logit for every row of ``logits``.

        The loss trains the logit at the gold index to be the largest, so
        the predicted label is the argmax over the last dimension.
        Assumes ``logits`` is a (batch, 2) tensor -- TODO confirm.
        """
        return logits.argmax(dim=-1).tolist()

    def _compute_loss(self, outputs, target):
        """Apply the criterion to the model outputs for one batch."""
        if isinstance(self.model, DataParallelModel):
            # output of model wrapped with DataParallelModel is a list of outputs from each GPU
            # make input of DataParallelCriterion as a list of tuples
            return self.criterion([(output, ) for output in outputs], target)
        return self.criterion(outputs, target)

    def _train_epoch(self, epoch: int) -> float:
        """
        train single epoch
        Args:
            epoch:  epoch number
        Returns:
            average loss (NaN when the iterator yields no batch)
        """
        self.model.train()
        progress = tqdm(self.trn_itr,
                        f'EPOCH[{epoch}]',
                        mininterval=1,
                        ncols=100)
        losses = []
        for step, batch in enumerate(progress, start=1):
            outputs = self.model(batch.comment_text)
            loss = self._compute_loss(outputs, batch.target)
            losses.append(loss.item())
            if step % self.log_step == 0:
                avg_loss = self._mean(losses)
                progress.set_description(f'EPOCH[{epoch}] ({avg_loss:.6f})')
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        return self._mean(losses)

    def _evaluate(self, epoch: int) -> Dict[str, float]:
        """
        evaluate on validation data
        Args:
            epoch:  epoch number
        Returns:
            metrics (accuracy, precision, recall, f_score and loss)
        """
        self.model.eval()
        progress = tqdm(self.vld_itr,
                        f' EVAL[{epoch}]',
                        mininterval=1,
                        ncols=100)
        losses = []
        preds = []
        golds = []
        for step, batch in enumerate(progress, start=1):
            with torch.no_grad():
                outputs = self.model(batch.comment_text)
                loss = self._compute_loss(outputs, batch.target)
                if isinstance(self.model, DataParallelModel):
                    # one logits tensor per GPU, in batch order
                    for output in outputs:
                        preds.extend(self._predict_labels(output))
                else:
                    preds.extend(self._predict_labels(outputs))
                losses.append(loss.item())
                golds.extend([gold.item() for gold in batch.target])
                if step % self.log_step == 0:
                    avg_loss = self._mean(losses)
                    progress.set_description(
                        f' EVAL[{epoch}] ({avg_loss:.6f})')
        metrics = self._get_metrics(preds, golds)
        metrics['loss'] = self._mean(losses)
        return metrics

    @classmethod
    def _get_metrics(cls, preds: List[float],
                     golds: List[float]) -> Dict[str, float]:
        """
        get metric values
        Args:
            preds:  predictions (values >= 0.5 count as positive)
            golds:  gold standards (values >= 0.5 count as positive)
        Returns:
            accuracy, precision, recall and f_score as percentages;
            all zero when there is nothing to score
        Raises:
            ValueError: if preds and golds differ in length
        """
        if len(preds) != len(golds):
            raise ValueError('preds and golds must have the same length')
        true_pos = 0
        false_pos = 0
        false_neg = 0
        true_neg = 0
        for pred, gold in zip(preds, golds):
            pred_pos = pred >= 0.5
            gold_pos = gold >= 0.5
            if pred_pos and gold_pos:
                true_pos += 1
            elif pred_pos:
                false_pos += 1
            elif gold_pos:
                false_neg += 1
            else:
                true_neg += 1
        total = true_pos + false_pos + false_neg + true_neg
        accuracy = (true_pos + true_neg) / total if total > 0 else 0.0
        precision = 0.0
        if (true_pos + false_pos) > 0:
            precision = true_pos / (true_pos + false_pos)
        recall = 0.0
        if (true_pos + false_neg) > 0:
            recall = true_pos / (true_pos + false_neg)
        f_score = 0.0
        if (precision + recall) > 0.0:
            f_score = 2.0 * precision * recall / (precision + recall)
        return {
            'accuracy': 100.0 * accuracy,
            'precision': 100.0 * precision,
            'recall': 100.0 * recall,
            'f_score': 100.0 * f_score,
        }
class ImageMTTrainer:
    def __init__(self,
                 model,
                 mask_prob: float = 0.3,
                 clip: int = 1,
                 optimizer=None,
                 beam_width: int = 5,
                 max_len_a: float = 1.1,
                 max_len_b: int = 5,
                 len_penalty_ratio: float = 0.8,
                 nll_loss: bool = False,
                 fp16: bool = False,
                 mm_mode="mixed",
                 rank: int = -1):
        """
        Place the model on its device and build the loss, the beam decoder
        and the (multi-)GPU wrappers.

        Args:
            model:  model to train; must expose `text_processor`
            mask_prob:  stored as `self.mask_prob`
            clip:  stored as `self.clip`
            optimizer:  optimizer to use; replaced by the apex-wrapped one
                when `fp16` is set
            beam_width:  forwarded to `BeamDecoder`
            max_len_a:  forwarded to `BeamDecoder`
            max_len_b:  forwarded to `BeamDecoder`
            len_penalty_ratio:  forwarded to `BeamDecoder`
            nll_loss:  use `nn.NLLLoss` instead of `SmoothedNLLLoss`; both
                ignore the pad token id
            fp16:  initialise model and optimizer through `amp.initialize`
                with opt_level "O2"
            mm_mode:  stored as `self.mm_mode`
            rank:  when >= 0, bind to CUDA device `rank` and wrap model and
                decoder in `DistributedDataParallel`; otherwise fall back to
                data-parallel wrappers if more than one GPU is visible
        """
        # Plain bookkeeping attributes first; none of them has side effects.
        self.model = model
        self.optimizer = optimizer
        self.clip = clip
        self.mask_prob = mask_prob
        self.mm_mode = mm_mode
        self.rank = rank
        self.reference = None
        self.best_bleu = -1.0

        has_cuda = torch.cuda.is_available()
        self.device = torch.device("cuda" if has_cuda else "cpu")
        self.num_gpu = torch.cuda.device_count()

        loss_cls = nn.NLLLoss if nll_loss else SmoothedNLLLoss
        self.criterion = loss_cls(
            ignore_index=model.text_processor.pad_token_id())

        distributed = rank >= 0
        if distributed:
            # One process per GPU: pin this process to its own device.
            self.device = torch.device('cuda', rank)
            torch.cuda.set_device(self.device)

        self.model = self.model.to(self.device)

        # Mixed precision has to be set up after the model is on its device
        # and before any parallel wrapper is applied.
        self.fp16 = bool(fp16)
        if self.fp16:
            self.model, self.optimizer = amp.initialize(self.model,
                                                        self.optimizer,
                                                        opt_level="O2")

        self.generator = BeamDecoder(self.model,
                                     beam_width=beam_width,
                                     max_len_a=max_len_a,
                                     max_len_b=max_len_b,
                                     len_penalty_ratio=len_penalty_ratio)

        if distributed:

            def wrap_distributed(module):
                return DistributedDataParallel(module,
                                               device_ids=[self.rank],
                                               output_device=self.rank,
                                               find_unused_parameters=True)

            self.model = wrap_distributed(self.model)
            self.generator = wrap_distributed(self.generator)
        elif self.num_gpu > 1:
            print("Let's use", self.num_gpu, "GPUs!")
            self.model = DataParallelModel(self.model)
            self.criterion = DataParallelCriterion(self.criterion)
            self.generator = DataParallelModel(self.generator)

    def train_epoch(self,
                    img_data_iter: List[data_utils.DataLoader],
                    step: int,
                    saving_path: str = None,
                    mass_data_iter: List[data_utils.DataLoader] = None,
                    mt_dev_iter: List[data_utils.DataLoader] = None,
                    mt_train_iter: List[data_utils.DataLoader] = None,
                    max_step: int = 300000,
                    accum=1,
                    beam_width=1,
                    fine_tune: bool = False,
                    lang_directions: dict = False,
                    lex_dict=None,
                    save_opt: bool = False,
                    **kwargs):
        """
        Run one pass over the zipped training loaders and update the model.

        Every batch is classified as one of three kinds:
          * image batch: a list whose first element contains "captions";
          * MASS batch: not an image batch and without "dst_texts";
          * MT batch: anything else (parallel data with "dst_texts").

        With `fine_tune` set, image and MASS batches are used for
        back-translation: the beam decoder produces translations without
        gradients, then the model is trained to recover the original source
        from them. Without `fine_tune`, image batches are trained either with
        masked captions or with the contrastive objective (chosen by
        `self.mm_mode`), MASS batches with masked-span recovery, and MT
        batches with the usual translation loss.

        A `RuntimeError` raised while handling a batch is printed and that
        batch is skipped. Progress is printed every 50 steps; inside that
        check BLEU is evaluated every 5000 steps and a ".latest" checkpoint is
        written every 10000 steps. At the end of the pass rank <= 0 prints the
        epoch loss, saves a ".latest" checkpoint and evaluates BLEU.

        Args:
            img_data_iter:  image-caption loaders (may be None)
            step:  global step count on entry; incremented once per batch
            saving_path:  checkpoints go to `saving_path + ".latest"`; also
                handed to `eval_bleu`
            mass_data_iter:  monolingual (MASS) loaders (may be None)
            mt_dev_iter:  dev loaders for BLEU evaluation (may be None)
            mt_train_iter:  parallel-data loaders (may be None)
            max_step:  stop once `step` reaches this value
            accum:  the optimizer steps whenever `step % accum == 0`
            beam_width:  beam width passed to the decoder while fine-tuning
            fine_tune:  enable the back-translation path described above
            lang_directions:  only read when `fine_tune` is set; indexed with
                the int value of each source sequence's first token to obtain
                the first token of the translation
            lex_dict:  when not None, "proposal" entries are read from the
                batches and lexical suggestions are built for translations
            save_opt:  also pickle the optimizer next to periodic checkpoints
            **kwargs:  accepted and ignored
        Returns:
            the updated global step count
        """
        start = time.time()
        total_tokens, total_loss, tokens, cur_loss = 0, 0, 0, 0
        batch_zip, shortest = self.get_batch_zip(img_data_iter, mass_data_iter,
                                                 mt_train_iter)

        # Unwrapped model, used for attribute access and checkpointing; the
        # forward passes below go through self.model where it may be wrapped.
        model = (self.model.module
                 if hasattr(self.model, "module") else self.model)
        self.optimizer.zero_grad()
        for i, batches in enumerate(batch_zip):
            for batch in batches:
                is_img_batch = isinstance(batch,
                                          list) and "captions" in batch[0]
                is_mass_batch = not is_img_batch and "dst_texts" not in batch
                is_contrastive = False
                try:
                    if fine_tune and (is_img_batch or is_mass_batch):
                        id2lid = lambda r: model.text_processor.languages[
                            model.text_processor.id2token(lang_directions[int(
                                r)])]
                        if is_mass_batch:
                            src_inputs = batch["src_texts"].squeeze(0)
                            src_pad_mask = src_inputs != model.text_processor.pad_token_id(
                            )
                            pad_indices = batch["pad_idx"].squeeze(0)
                            proposal = batch["proposal"].squeeze(
                                0) if lex_dict is not None else None
                            target_langs = torch.LongTensor([
                                lang_directions[int(l)]
                                for l in src_inputs[:, 0]
                            ])
                            dst_langs = torch.LongTensor(
                                [id2lid(l) for l in src_inputs[:, 0]])
                        else:
                            src_inputs = [b["captions"] for b in batch]
                            src_pad_mask = [b["caption_mask"] for b in batch]
                            pad_indices = [b["pad_idx"] for b in batch]
                            proposal = [
                                b["proposal"] if lex_dict is not None else None
                                for b in batch
                            ]
                            target_langs = [
                                torch.LongTensor([
                                    lang_directions[int(l)] for l in src[:, 0]
                                ]) for src in src_inputs
                            ]
                            dst_langs = [
                                torch.LongTensor(
                                    [id2lid(l) for l in src[:, 0]])
                                for src in src_inputs
                            ]
                        # Skip batches that cannot be split over all GPUs.
                        if len(src_inputs) < self.num_gpu:
                            continue

                        if is_mass_batch:
                            langs = batch["langs"].squeeze(0)
                        else:
                            langs = [b["langs"] for b in batch]

                        model.eval()
                        with torch.no_grad():
                            # We do not backpropagate the data generator following the MASS paper.
                            images = None
                            if is_img_batch:
                                images = [b["images"] for b in batch]
                            outputs = self.generator(
                                src_inputs=src_inputs,
                                src_sizes=pad_indices,
                                first_tokens=target_langs,
                                src_langs=langs,
                                tgt_langs=dst_langs,
                                pad_idx=model.text_processor.pad_token_id(),
                                src_mask=src_pad_mask,
                                unpad_output=False,
                                beam_width=beam_width,
                                images=images,
                                proposals=proposal)
                            if self.num_gpu > 1 and self.rank < 0:
                                if is_mass_batch:
                                    # Flatten the per-device output lists.
                                    new_outputs = []
                                    for output in outputs:
                                        new_outputs += output
                                    outputs = new_outputs

                            if is_mass_batch or self.num_gpu <= 1:
                                translations = pad_sequence(
                                    outputs,
                                    batch_first=True,
                                    padding_value=model.text_processor.
                                    pad_token_id())
                                translation_proposals = None
                                if lex_dict is not None:
                                    translation_proposals = list(
                                        map(
                                            lambda o: dataset.
                                            get_lex_suggestions(
                                                lex_dict, o,
                                                model.text_processor.
                                                pad_token_id()), outputs))
                                    translation_proposals = pad_sequence(
                                        translation_proposals,
                                        batch_first=True,
                                        padding_value=model.text_processor.
                                        pad_token_id())
                                translation_pad_mask = (
                                    translations !=
                                    model.text_processor.pad_token_id())
                            else:
                                # Multi-GPU image batches stay as one padded
                                # tensor per sub-batch.
                                translation_proposals = None
                                if lex_dict is not None:
                                    translation_proposals = [
                                        pad_sequence(
                                            list(
                                                map(
                                                    lambda o: dataset.
                                                    get_lex_suggestions(
                                                        lex_dict, o,
                                                        model.text_processor.
                                                        pad_token_id()),
                                                    output)),
                                            batch_first=True,
                                            padding_value=model.text_processor.
                                            pad_token_id())
                                        for output in outputs
                                    ]

                                translations = [
                                    pad_sequence(output,
                                                 batch_first=True,
                                                 padding_value=model.
                                                 text_processor.pad_token_id())
                                    for output in outputs
                                ]
                                translation_pad_mask = [
                                    t != model.text_processor.pad_token_id()
                                    for t in translations
                                ]
                        model.train()

                        if is_mass_batch:
                            langs = batch["langs"].squeeze(0)
                        else:
                            langs = torch.cat([b["langs"] for b in batch])
                        # Now use it for back-translation loss.
                        predictions = model(
                            src_inputs=translations,
                            tgt_inputs=src_inputs,
                            src_pads=translation_pad_mask,
                            pad_idx=model.text_processor.pad_token_id(),
                            src_langs=dst_langs,
                            tgt_langs=langs,
                            proposals=translation_proposals,
                            log_softmax=True)
                        if is_mass_batch:
                            src_targets = src_inputs[:,
                                                     1:].contiguous().view(-1)
                            src_mask_flat = src_pad_mask[:,
                                                         1:].contiguous().view(
                                                             -1)
                        else:
                            src_targets = torch.cat(
                                list(map(lambda s: s[:, 1:], src_inputs)))
                            src_mask_flat = torch.cat(
                                list(map(lambda s: s[:, 1:], src_pad_mask)))
                        targets = src_targets[src_mask_flat]

                        ntokens = targets.size(0)
                    elif is_img_batch:
                        src_inputs = [b["captions"] for b in batch]
                        src_pad_mask = [b["caption_mask"] for b in batch]
                        proposals = [b["proposal"] for b in batch
                                     ] if lex_dict is not None else None
                        langs = [b["langs"] for b in batch]
                        # "masked" always masks captions, "mixed" does so for
                        # about half of the batches; everything else is
                        # trained with the contrastive objective.
                        if (self.mm_mode == "mixed" and random.random() <= .5
                            ) or self.mm_mode == "masked":
                            pad_indices = [b["pad_idx"] for b in batch]
                            if len(batch) < self.num_gpu:
                                continue

                            # For image masking, we are allowed to mask more than mask_prob
                            mask_prob = random.uniform(self.mask_prob, 1.0)

                            masked_info = list(
                                map(
                                    lambda pi, si: mass_mask(
                                        mask_prob, pi, si, model.text_processor
                                    ), pad_indices, src_inputs))
                            predictions = self.model(
                                src_inputs=list(
                                    map(lambda m: m["src_text"], masked_info)),
                                tgt_inputs=list(
                                    map(lambda m: m["to_recover"],
                                        masked_info)),
                                tgt_positions=list(
                                    map(lambda m: m["positions"],
                                        masked_info)),
                                src_pads=src_pad_mask,
                                pad_idx=model.text_processor.pad_token_id(),
                                src_langs=langs,
                                batch=batch,
                                proposals=proposals,
                                log_softmax=True)
                            targets = torch.cat(
                                list(map(lambda m: m["targets"], masked_info)))
                            ntokens = targets.size(0)
                        else:
                            neg_samples = [b["neg"] for b in batch]
                            neg_mask = [b["neg_mask"] for b in batch]
                            # In this mode the model returns the loss itself.
                            loss = self.model(
                                src_inputs=src_inputs,
                                src_pads=src_pad_mask,
                                neg_samples=neg_samples,
                                neg_mask=neg_mask,
                                pad_idx=model.text_processor.pad_token_id(),
                                src_langs=langs,
                                batch=batch,
                                proposals=proposals,
                                log_softmax=True)
                            is_contrastive = True

                    elif not is_mass_batch:  # MT data
                        src_inputs = batch["src_texts"].squeeze(0)
                        src_mask = batch["src_pad_mask"].squeeze(0)
                        tgt_inputs = batch["dst_texts"].squeeze(0)
                        tgt_mask = batch["dst_pad_mask"].squeeze(0)
                        src_langs = batch["src_langs"].squeeze(0)
                        dst_langs = batch["dst_langs"].squeeze(0)
                        proposals = batch["proposal"].squeeze(
                            0) if lex_dict is not None else None
                        if src_inputs.size(0) < self.num_gpu:
                            continue
                        predictions = self.model(
                            src_inputs=src_inputs,
                            tgt_inputs=tgt_inputs,
                            src_pads=src_mask,
                            tgt_mask=tgt_mask,
                            src_langs=src_langs,
                            tgt_langs=dst_langs,
                            proposals=proposals,
                            pad_idx=model.text_processor.pad_token_id(),
                            log_softmax=True)
                        targets = tgt_inputs[:, 1:].contiguous().view(-1)
                        tgt_mask_flat = tgt_mask[:, 1:].contiguous().view(-1)
                        targets = targets[tgt_mask_flat]
                        ntokens = targets.size(0)
                    else:  # MASS data
                        src_inputs = batch["src_texts"].squeeze(0)
                        pad_indices = batch["pad_idx"].squeeze(0)
                        proposals = batch["proposal"].squeeze(
                            0) if lex_dict is not None else None
                        if src_inputs.size(0) < self.num_gpu:
                            continue

                        masked_info = mass_mask(self.mask_prob, pad_indices,
                                                src_inputs,
                                                model.text_processor)
                        predictions = self.model(
                            src_inputs=masked_info["src_text"],
                            tgt_inputs=masked_info["to_recover"],
                            tgt_positions=masked_info["positions"],
                            pad_idx=model.text_processor.pad_token_id(),
                            src_langs=batch["langs"].squeeze(0),
                            proposals=proposals,
                            log_softmax=True)
                        targets = masked_info["targets"]
                        ntokens = targets.size(0)

                    # Loss contributed by this batch only. Starting from zero
                    # keeps a batch without target tokens from re-adding the
                    # previous batch's loss (or hitting an unbound `loss`).
                    batch_loss = 0
                    if is_contrastive:  # Nothing to predict!
                        backward(loss, self.optimizer, self.fp16)
                        batch_loss = loss.data
                    elif ntokens > 0:
                        if self.num_gpu == 1:
                            targets = targets.to(predictions.device)
                        if self.rank >= 0:
                            targets = targets.to(self.device)

                        loss = self.criterion(predictions, targets).mean()
                        backward(loss, self.optimizer, self.fp16)

                        # Token-weighted so the running sums can be divided
                        # by the token counts when reporting.
                        batch_loss = float(loss.data) * ntokens
                        tokens += ntokens
                        total_tokens += ntokens
                    total_loss += batch_loss
                    cur_loss += batch_loss

                    # NOTE(review): clipping runs on every batch, i.e. on
                    # partially accumulated gradients when accum > 1 - confirm
                    # this is intended.
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   self.clip)
                    step += 1
                    if step % accum == 0:
                        self.optimizer.step()
                        self.optimizer.zero_grad()

                    # Presumably mass_mask edits the source tensors in place
                    # and mass_unmask restores them - confirm against those
                    # helpers.
                    if is_mass_batch and not fine_tune:
                        mass_unmask(masked_info["src_text"],
                                    masked_info["src_mask"],
                                    masked_info["mask_idx"])
                    if not is_contrastive and is_img_batch and not fine_tune:
                        # An explicit loop is required here: map() is lazy in
                        # Python 3, so an unconsumed map() never calls
                        # mass_unmask.
                        for info in masked_info:
                            mass_unmask(info["src_text"], info["src_mask"],
                                        info["mask_idx"])

                    if step % 50 == 0 and tokens > 0:
                        elapsed = time.time() - start
                        print(
                            self.rank, "->", datetime.datetime.now(),
                            "Epoch Step: %d Loss: %f Tokens per Sec: %f " %
                            (step, cur_loss / tokens, tokens / elapsed))

                        if mt_dev_iter is not None and step % 5000 == 0 and self.rank <= 0:
                            bleu = self.eval_bleu(mt_dev_iter, saving_path)
                            print("BLEU:", bleu)

                        if step % 10000 == 0:
                            if self.rank <= 0:
                                if self.rank < 0:
                                    model.cpu().save(saving_path + ".latest")
                                elif self.rank == 0:
                                    model.save(saving_path + ".latest")

                                if save_opt:
                                    with open(
                                            os.path.join(
                                                saving_path + ".latest",
                                                "optim"), "wb") as fp:
                                        pickle.dump(self.optimizer, fp)
                                if self.rank < 0:
                                    model = model.to(self.device)

                        # Restart the reporting window.
                        start, tokens, cur_loss = time.time(), 0, 0

                except RuntimeError as err:
                    # Report the failure and carry on with the next batch.
                    print(repr(err))
                    print("Error processing", is_img_batch)
                    if (isinstance(model, ImageMassSeq2Seq)) and is_img_batch:
                        for b in batch:
                            print("->", len(b["images"]), b["captions"].size())
                    torch.cuda.empty_cache()

            if i == shortest - 1:
                break
            if step >= max_step:
                break

        try:
            if self.rank <= 0:
                # total_tokens stays 0 when only contrastive batches were
                # processed; dividing by at least 1 keeps the report (and the
                # checkpoint below) from being lost to a ZeroDivisionError.
                print("Total loss in this epoch: %f" %
                      (total_loss / max(total_tokens, 1)))
                if self.rank < 0:
                    model.cpu().save(saving_path + ".latest")
                    model = model.to(self.device)
                elif self.rank == 0:
                    model.save(saving_path + ".latest")

                if mt_dev_iter is not None:
                    bleu = self.eval_bleu(mt_dev_iter, saving_path)
                    print("BLEU:", bleu)
        except RuntimeError as err:
            print(repr(err))

        return step

    def get_batch_zip(self, img_data_iter, mass_data_iter, mt_train_iter):
        """
        Combine every available loader into one lock-step iterator.

        Args:
            img_data_iter:  list of image loaders, or None
            mass_data_iter:  list of MASS loaders, or None
            mt_train_iter:  list of MT loaders, or None
        Returns:
            (batch_zip, shortest): `batch_zip` yields one tuple per step with
            one batch from every loader, in the order image, MASS, MT;
            `shortest` is the length of the shortest loader.
        Raises:
            ValueError: if no loader is supplied at all
        """
        loaders = [
            loader
            for group in (img_data_iter, mass_data_iter, mt_train_iter)
            if group is not None  # identity check, not `!= None`
            for loader in group
        ]
        shortest = min(len(loader) for loader in loaders)
        return zip(*loaders), shortest

    def eval_bleu(self, dev_data_iter, saving_path, save_opt: bool = False):
        """
        Translate the dev data, score it with BLEU and keep the best model.

        Writes "bleu.output" under `saving_path` on every call. When the
        score beats `self.best_bleu`, the same report is also written to
        "bleu.best.output", the model is saved to `saving_path` (rank <= 0)
        and, if `save_opt` is set, the optimizer is pickled to
        `saving_path`/optim.

        Args:
            dev_data_iter:  iterable of dev loaders yielding MT batches
            saving_path:  directory for the reports and the best model
            save_opt:  also pickle the optimizer on a new best score
        Returns:
            the corpus BLEU score
        """
        model = getattr(self.model, "module", self.model)
        text_processor = model.text_processor
        decode = text_processor.tokenizer.decode

        sources = []
        hypotheses = []
        model.eval()
        with torch.no_grad():
            for loader in dev_data_iter:
                for batch in loader:
                    src_inputs = batch["src_texts"].squeeze(0)
                    src_mask = batch["src_pad_mask"].squeeze(0)
                    tgt_inputs = batch["dst_texts"].squeeze(0)
                    src_langs = batch["src_langs"].squeeze(0)
                    dst_langs = batch["dst_langs"].squeeze(0)
                    src_pad_idx = batch["pad_idx"].squeeze(0)
                    proposal = None
                    if batch["proposal"] is not None:
                        proposal = batch["proposal"].squeeze(0)

                    src_ids = get_outputs_until_eos(
                        text_processor.sep_token_id(),
                        src_inputs,
                        remove_first_token=True)
                    sources += [decode(ids.numpy()) for ids in src_ids]

                    outputs = self.generator(
                        src_inputs=src_inputs,
                        src_sizes=src_pad_idx,
                        first_tokens=tgt_inputs[:, 0],
                        src_mask=src_mask,
                        src_langs=src_langs,
                        tgt_langs=dst_langs,
                        pad_idx=text_processor.pad_token_id(),
                        proposals=proposal)
                    if self.num_gpu > 1 and self.rank < 0:
                        # Flatten the per-device output lists.
                        outputs = [
                            seq for device_out in outputs
                            for seq in device_out
                        ]

                    # The first token of each output is dropped before
                    # decoding.
                    hypotheses += [decode(seq[1:].numpy()) for seq in outputs]
        model.train()

        bleu = sacrebleu.corpus_bleu(hypotheses,
                                     [self.reference[:len(hypotheses)]],
                                     lowercase=True,
                                     tokenize="intl")

        def write_report(file_name):
            # Each entry lists source, system output, then reference.
            with open(os.path.join(saving_path, file_name), "w") as writer:
                writer.write("\n".join([
                    src + "\n" + hyp + "\n" + ref + "\n\n***************\n"
                    for src, hyp, ref in zip(
                        sources, hypotheses,
                        self.reference[:len(hypotheses)])
                ]))

        write_report("bleu.output")

        if bleu.score > self.best_bleu:
            self.best_bleu = bleu.score
            print("Saving best BLEU", self.best_bleu)
            write_report("bleu.best.output")
            if self.rank < 0:
                # Save from CPU, then move the model back to its device.
                model.cpu().save(saving_path)
                model.to(self.device)
            elif self.rank == 0:
                model.save(saving_path)

            if save_opt:
                with open(os.path.join(saving_path, "optim"), "wb") as fp:
                    pickle.dump(self.optimizer, fp)

        return bleu.score

    @staticmethod
    def train(options):
        lex_dict = None
        if options.dict_path is not None:
            lex_dict = get_lex_dict(options.dict_path)
        if options.local_rank <= 0 and not os.path.exists(options.model_path):
            os.makedirs(options.model_path)

        text_processor = TextProcessor(options.tokenizer_path)
        assert text_processor.pad_token_id() == 0
        num_processors = max(torch.cuda.device_count(),
                             1) if options.local_rank < 0 else 1

        if options.pretrained_path is not None:
            mt_model = Seq2Seq.load(ImageMassSeq2Seq,
                                    options.pretrained_path,
                                    tok_dir=options.tokenizer_path)
        else:
            mt_model = ImageMassSeq2Seq(
                use_proposals=lex_dict is not None,
                tie_embed=options.tie_embed,
                text_processor=text_processor,
                resnet_depth=options.resnet_depth,
                lang_dec=options.lang_decoder,
                enc_layer=options.encoder_layer,
                dec_layer=options.decoder_layer,
                embed_dim=options.embed_dim,
                intermediate_dim=options.intermediate_layer_dim)

        if options.lm_path is not None:
            lm = LM(text_processor=text_processor,
                    enc_layer=options.encoder_layer,
                    embed_dim=options.embed_dim,
                    intermediate_dim=options.intermediate_layer_dim)
            mt_model.init_from_lm(lm)

        print("Model initialization done!")

        # We assume that the collator function returns a list with the size of number of gpus (in case of cpus,
        collator = dataset.ImageTextCollator()
        num_batches = max(1, torch.cuda.device_count())

        if options.continue_train:
            with open(os.path.join(options.pretrained_path, "optim"),
                      "rb") as fp:
                optimizer = pickle.load(fp)
        else:
            optimizer = build_optimizer(mt_model,
                                        options.learning_rate,
                                        warump_steps=options.warmup)
        trainer = ImageMTTrainer(model=mt_model,
                                 mask_prob=options.mask_prob,
                                 optimizer=optimizer,
                                 clip=options.clip,
                                 beam_width=options.beam_width,
                                 max_len_a=options.max_len_a,
                                 max_len_b=options.max_len_b,
                                 len_penalty_ratio=options.len_penalty_ratio,
                                 fp16=options.fp16,
                                 mm_mode=options.mm_mode,
                                 rank=options.local_rank)

        pin_memory = torch.cuda.is_available()
        img_train_loader = ImageMTTrainer.get_img_loader(
            collator,
            dataset.ImageCaptionDataset,
            options.train_path,
            mt_model,
            num_batches,
            options,
            pin_memory,
            lex_dict=lex_dict)

        mass_train_data, mass_train_loader, finetune_loader, mt_dev_loader = None, None, None, None
        if options.mass_train_path is not None:
            mass_train_paths = options.mass_train_path.strip().split(",")
            if options.step > 0:
                mass_train_data, mass_train_loader = ImageMTTrainer.get_mass_loader(
                    mass_train_paths,
                    mt_model,
                    num_processors,
                    options,
                    pin_memory,
                    keep_examples=options.finetune_step > 0,
                    lex_dict=lex_dict)

            if options.finetune_step > 0:
                finetune_loader, finetune_data = ImageMTTrainer.get_mass_finetune_data(
                    mass_train_data,
                    mass_train_paths,
                    mt_model,
                    num_processors,
                    options,
                    pin_memory,
                    lex_dict=lex_dict)

        mt_train_loader = None
        if options.mt_train_path is not None:
            mt_train_loader = ImageMTTrainer.get_mt_train_data(
                mt_model,
                num_processors,
                options,
                pin_memory,
                lex_dict=lex_dict)

        mt_dev_loader = None
        if options.mt_dev_path is not None:
            mt_dev_loader = ImageMTTrainer.get_mt_dev_data(mt_model,
                                                           options,
                                                           pin_memory,
                                                           text_processor,
                                                           trainer,
                                                           lex_dict=lex_dict)

        step, train_epoch = 0, 1
        while options.step > 0 and step < options.step:
            print("train epoch", train_epoch)
            step = trainer.train_epoch(img_data_iter=img_train_loader,
                                       mass_data_iter=mass_train_loader,
                                       mt_train_iter=mt_train_loader,
                                       max_step=options.step,
                                       lex_dict=lex_dict,
                                       mt_dev_iter=mt_dev_loader,
                                       saving_path=options.model_path,
                                       step=step,
                                       save_opt=options.save_opt,
                                       accum=options.accum)
            train_epoch += 1

        finetune_epoch = 0
        # Resetting the optimizer for the purpose of finetuning.
        trainer.optimizer.reset()

        lang_directions = ImageMTTrainer.get_lang_dirs(options.bt_langs,
                                                       text_processor)
        print(options.local_rank, "lang dirs", lang_directions)

        print(options.local_rank,
              "Reloading image train data with new batch size...")

        if options.finetune_step > 0 and img_train_loader is not None:
            img_train_loader = ImageMTTrainer.get_img_loader(
                collator,
                dataset.ImageCaptionDataset,
                options.train_path,
                mt_model,
                num_batches,
                options,
                pin_memory,
                denom=2,
                lex_dict=lex_dict)
        if options.ignore_mt_mass:
            mt_train_loader = None
        print(options.local_rank,
              "Reloading image train data with new batch size done!")

        while options.finetune_step > 0 and step <= options.finetune_step + options.step:
            print(options.local_rank, "finetune epoch", finetune_epoch)
            step = trainer.train_epoch(img_data_iter=img_train_loader,
                                       mass_data_iter=finetune_loader,
                                       mt_train_iter=mt_train_loader,
                                       max_step=options.finetune_step +
                                       options.step,
                                       mt_dev_iter=mt_dev_loader,
                                       saving_path=options.model_path,
                                       step=step,
                                       fine_tune=True,
                                       lang_directions=lang_directions,
                                       lex_dict=lex_dict,
                                       save_opt=options.save_opt,
                                       accum=options.accum,
                                       beam_width=options.bt_beam_width)
            finetune_epoch += 1

    @staticmethod
    def get_lang_dirs(bt_langs, text_processor: TextProcessor):
        """Pair up the two back-translation languages by token id.

        ``bt_langs`` is a comma-separated string of language codes. Each code
        is wrapped as ``<code>`` and resolved with
        ``text_processor.token_id``; duplicates collapse because the ids are
        collected in a set.

        Returns ``None`` when fewer than two distinct ids are found,
        otherwise a dict that maps each id to the other one. More than two
        distinct ids trip the assertion below.
        """
        lang_ids = {
            text_processor.token_id("<" + code + ">")
            for code in bt_langs.strip().split(",")
        }
        if len(lang_ids) < 2:
            return None
        assert len(lang_ids) <= 2
        # Assuming that we only have two languages, every id simply points
        # at the other id.
        return {
            source: target
            for source in lang_ids
            for target in lang_ids
            if source != target
        }

    @staticmethod
    def get_mt_dev_data(mt_model,
                        options,
                        pin_memory,
                        text_processor,
                        trainer,
                        lex_dict=None):
        """Build one dev DataLoader per MT dev path and gather the references.

        ``options.mt_dev_path`` is split on commas; each entry becomes a
        ``dataset.MTDataset`` wrapped in an unshuffled ``batch_size=1``
        loader. While the loaders are built, every batch is walked once and
        the decoded target texts are accumulated in ``trainer.reference``
        (which is reset to an empty list first).

        Returns the list of dev loaders, in path order.
        """
        paths = options.mt_dev_path.split(",")
        trainer.reference = []
        loaders = []
        for path in paths:
            dev_set = dataset.MTDataset(
                batch_pickle_dir=path,
                max_batch_capacity=options.total_capacity,
                keep_pad_idx=True,
                max_batch=int(options.batch / (options.beam_width * 2)),
                pad_idx=mt_model.text_processor.pad_token_id(),
                lex_dict=lex_dict)
            loader = data_utils.DataLoader(dev_set,
                                           batch_size=1,
                                           shuffle=False,
                                           pin_memory=pin_memory)
            loaders.append(loader)

            print(options.local_rank, "creating reference")

            # Unwrap a parallel wrapper (anything exposing ``.module``).
            generator = getattr(trainer.generator, "module",
                                trainer.generator)

            for batch in loader:
                # NOTE(review): squeeze() drops every size-1 dimension, not
                # only the one presumably added by the batch_size=1 loader
                # -- confirm get_outputs_until_eos copes with that.
                targets = batch["dst_texts"].squeeze()
                cut_refs = get_outputs_until_eos(
                    text_processor.sep_token_id(),
                    targets,
                    remove_first_token=True)
                decoded = [
                    generator.seq2seq_model.text_processor.tokenizer.decode(
                        item.numpy()) for item in cut_refs
                ]
                trainer.reference += decoded
        return loaders

    @staticmethod
    def get_mt_train_data(mt_model,
                          num_processors,
                          options,
                          pin_memory,
                          lex_dict=None):
        """Build one MT training DataLoader per path in ``options.mt_train_path``.

        Each comma-separated path is loaded as a ``dataset.MTDataset`` whose
        batch limits are half of ``num_processors`` times the configured
        ``total_capacity`` / ``batch``. With ``options.local_rank < 0`` the
        loader shuffles itself; otherwise a ``DistributedSampler`` for that
        rank is used and loader-level shuffling is off.

        Returns the list of loaders, in path order.
        """
        loaders = []
        for path in options.mt_train_path.split(","):
            train_set = dataset.MTDataset(
                batch_pickle_dir=path,
                max_batch_capacity=int(num_processors *
                                       options.total_capacity / 2),
                max_batch=int(num_processors * options.batch / 2),
                pad_idx=mt_model.text_processor.pad_token_id(),
                lex_dict=lex_dict,
                keep_pad_idx=False)

            if options.local_rank < 0:
                sampler = None
            else:
                sampler = DistributedSampler(train_set,
                                             rank=options.local_rank)

            loaders.append(
                data_utils.DataLoader(train_set,
                                      sampler=sampler,
                                      batch_size=1,
                                      shuffle=(options.local_rank < 0),
                                      pin_memory=pin_memory))
        return loaders

    @staticmethod
    def get_mass_finetune_data(mass_train_data,
                               mass_train_paths,
                               mt_model,
                               num_processors,
                               options,
                               pin_memory,
                               lex_dict=None):
        """Build the fine-tuning MassDatasets and their loaders.

        One ``dataset.MassDataset`` is created per entry of
        ``mass_train_paths``. Batch limits are the per-processor limits
        divided by ``max(2, options.bt_beam_width)``. When
        ``mass_train_data`` is given, the i-th dataset's ``examples_list`` is
        handed to the new dataset as ``example_list`` and then replaced by an
        empty list on the source dataset.

        Returns ``(finetune_loader, finetune_data)`` -- two parallel lists.
        """
        finetune_data = []
        finetune_loader = []
        for idx, path in enumerate(mass_train_paths):
            if mass_train_data is None:
                examples = None
            else:
                examples = mass_train_data[idx].examples_list

            finetune_set = dataset.MassDataset(
                batch_pickle_dir=path,
                max_batch_capacity=int(num_processors *
                                       options.total_capacity /
                                       max(2, options.bt_beam_width)),
                max_batch=int(num_processors * options.batch /
                              max(2, options.bt_beam_width)),
                pad_idx=mt_model.text_processor.pad_token_id(),
                max_seq_len=options.max_seq_len,
                keep_examples=False,
                example_list=examples,
                lex_dict=lex_dict)
            finetune_data.append(finetune_set)

            if options.local_rank < 0:
                sampler = None
            else:
                sampler = DistributedSampler(finetune_set,
                                             rank=options.local_rank)
            finetune_loader.append(
                data_utils.DataLoader(finetune_set,
                                      sampler=sampler,
                                      batch_size=1,
                                      shuffle=(options.local_rank < 0),
                                      pin_memory=pin_memory))

            # The examples now live in the fine-tuning dataset; drop the
            # reference held by the original training dataset.
            if mass_train_data is not None:
                mass_train_data[idx].examples_list = []
        return finetune_loader, finetune_data

    @staticmethod
    def get_mass_loader(mass_train_paths,
                        mt_model,
                        num_processors,
                        options,
                        pin_memory,
                        keep_examples,
                        lex_dict=None):
        """Build the MASS training datasets and one loader for each.

        Every path in ``mass_train_paths`` becomes a ``dataset.MassDataset``
        sized by ``num_processors`` times the configured ``total_capacity`` /
        ``batch``; ``keep_examples`` is forwarded to the dataset unchanged.
        With ``options.local_rank < 0`` the loader shuffles itself, otherwise
        a ``DistributedSampler`` for that rank is used.

        Returns ``(mass_train_data, mass_train_loader)`` -- two parallel
        lists.
        """
        mass_train_data = []
        mass_train_loader = []
        for path in mass_train_paths:
            train_set = dataset.MassDataset(
                batch_pickle_dir=path,
                max_batch_capacity=num_processors * options.total_capacity,
                max_batch=num_processors * options.batch,
                pad_idx=mt_model.text_processor.pad_token_id(),
                max_seq_len=options.max_seq_len,
                keep_examples=keep_examples,
                lex_dict=lex_dict)
            mass_train_data.append(train_set)

            if options.local_rank < 0:
                sampler = None
            else:
                sampler = DistributedSampler(train_set,
                                             rank=options.local_rank)
            mass_train_loader.append(
                data_utils.DataLoader(train_set,
                                      sampler=sampler,
                                      batch_size=1,
                                      shuffle=(options.local_rank < 0),
                                      pin_memory=pin_memory))
        return mass_train_data, mass_train_loader

    @staticmethod
    def get_img_loader(collator,
                       dataset_class,
                       paths,
                       mt_model,
                       num_batches,
                       options,
                       pin_memory,
                       denom=1,
                       lex_dict=None,
                       shuffle=True):
        """Build one image DataLoader per comma-separated path.

        Args:
            collator: ``collate_fn`` handed to every DataLoader.
            dataset_class: callable that builds a dataset from the keyword
                arguments ``root_img_dir``, ``data_bin_file``,
                ``max_capacity``, ``text_processor``, ``max_img_per_batch``
                and ``lex_dict``.
            paths: comma-separated data files, or ``None``.
            mt_model: object whose ``text_processor`` is passed to the
                dataset.
            num_batches: ``batch_size`` of every DataLoader.
            options: reads ``image_dir``, ``img_capacity``, ``max_image`` and
                ``local_rank``.
            pin_memory: forwarded to the DataLoader.
            denom: divisor applied to ``options.img_capacity`` and
                ``options.max_image``.
            lex_dict: forwarded to the dataset.
            shuffle: whether the DataLoader shuffles. Only honoured when no
                sampler is installed, i.e. when ``options.local_rank < 0``.

        Returns:
            A list of DataLoaders in path order, or ``None`` when ``paths``
            is ``None``.
        """
        if paths is None:
            return None

        img_loader = []
        for pth in paths.strip().split(","):
            data = dataset_class(
                root_img_dir=options.image_dir,
                data_bin_file=pth,
                max_capacity=int(options.img_capacity / denom),
                text_processor=mt_model.text_processor,
                # NOTE(review): true division, so this is always a float --
                # confirm dataset_class expects that.
                max_img_per_batch=options.max_image / denom,
                lex_dict=lex_dict)
            print(options.local_rank, pth, "Length of training data",
                  len(data))

            if options.local_rank < 0:
                sampler = None
            else:
                sampler = DistributedSampler(data, rank=options.local_rank)

            tl = data_utils.DataLoader(
                data,
                sampler=sampler,
                batch_size=num_batches,
                # DataLoader raises ValueError when a sampler is combined
                # with shuffle=True, so loader-level shuffling is only
                # enabled when no sampler is installed (same rule as the
                # other loaders in this class).
                shuffle=bool(shuffle) and sampler is None,
                pin_memory=pin_memory,
                collate_fn=collator)
            img_loader.append(tl)
        return img_loader
Exemple #6
0
def _miniseg_train_transform(data, args, scale, crop_base):
    """Compose one training transform pipeline.

    Applies ``Normalize`` (with ``data['mean']`` / ``data['std']``),
    ``Scale``, ``RandomCropResize``, ``RandomFlip`` and ``ToTensor`` from
    ``myTransforms``, in that order. ``Scale`` receives ``args.width`` and
    ``args.height`` multiplied by ``scale`` and truncated to int, or the
    untouched values when ``scale`` is ``None``. ``RandomCropResize``
    receives ``int(crop_base / 1024. * args.width)``.
    """
    if scale is None:
        width, height = args.width, args.height
    else:
        width, height = int(args.width * scale), int(args.height * scale)
    return myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(width, height),
        myTransforms.RandomCropResize(int(crop_base / 1024. * args.width)),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor()
    ])


def _miniseg_train_loader(data, args, transform):
    """Wrap the training split in a shuffling DataLoader that drops the last partial batch."""
    return torch.utils.data.DataLoader(
        myDataLoader.Dataset(data['trainIm'],
                             data['trainAnnot'],
                             transform=transform),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True,
        drop_last=True)


def main_tr(args, crossVal):
    """Train and validate MiniSeg for one cross-validation split.

    Args:
        args: run configuration. Reads ``data_dir``, ``classes``,
            ``data_name``, ``savedir``, ``max_epochs``, ``model_name``,
            ``gpu``, ``ignore_label``, ``width``, ``height``, ``batch_size``,
            ``num_workers``, ``pretrained``, ``resume`` and ``lr``.
            ``args.lr`` is overwritten when a checkpoint is resumed.
        crossVal: split identifier passed to ``processData``; it is logged
            and used in file names as ``crossVal + 1``.

    Returns:
        ``(maxEpoch, maxmIOU)``: the 1-based epoch with the highest
        validation mIOU (the latest one on ties) and that mIOU. Both are
        ``0`` when no epoch runs.

    Side effects:
        Creates ``<savedir>_mod<max_epochs>/<data_name>/<model_name>`` and
        writes a train/val log, one rolling checkpoint and one model file per
        epoch into it.
    """
    dataLoad = ld.LoadData(args.data_dir, args.classes)
    data = dataLoad.processData(crossVal, args.data_name)

    # load the model
    model = net.MiniSeg(args.classes, aux=True)

    # create the output directory (and any missing parents) if not exist
    saveDir = args.savedir + '_mod' + str(
        args.max_epochs) + '/' + args.data_name + '/' + args.model_name
    os.makedirs(saveDir, exist_ok=True)

    multi_gpu = args.gpu and torch.cuda.device_count() > 1
    if multi_gpu:
        model = DataParallelModel(model)
    if args.gpu:
        model = model.cuda()

    total_paramters = sum([np.prod(p.size()) for p in model.parameters()])
    print('Total network parameters: ' + str(total_paramters))

    # define optimization criteria
    weight = torch.from_numpy(
        data['classWeights'])  # convert the numpy array to torch
    if args.gpu:
        weight = weight.cuda()

    criteria = CrossEntropyLoss2d(weight, args.ignore_label)
    if multi_gpu:
        criteria = DataParallelCriterion(criteria)
    if args.gpu:
        criteria = criteria.cuda()

    # compose the data with transforms
    trainDataset_main = _miniseg_train_transform(data, args, None, 32.)
    trainDataset_scale1 = _miniseg_train_transform(data, args, 1.5, 100.)
    trainDataset_scale2 = _miniseg_train_transform(data, args, 1.25, 100.)
    trainDataset_scale3 = _miniseg_train_transform(data, args, 0.75, 32.)

    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(args.width, args.height),
        myTransforms.ToTensor()
    ])

    # since we training from scratch, we create data loaders at different scales
    # so that we can generate more augmented data and prevent the network from overfitting
    trainLoader = _miniseg_train_loader(data, args, trainDataset_main)
    trainLoader_scale1 = _miniseg_train_loader(data, args,
                                               trainDataset_scale1)
    trainLoader_scale2 = _miniseg_train_loader(data, args,
                                               trainDataset_scale2)
    trainLoader_scale3 = _miniseg_train_loader(data, args,
                                               trainDataset_scale3)

    valLoader = torch.utils.data.DataLoader(myDataLoader.Dataset(
        data['valIm'], data['valAnnot'], transform=valDataset),
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers,
                                            pin_memory=True)
    max_batches = len(trainLoader) + len(trainLoader_scale1) + len(
        trainLoader_scale2) + len(trainLoader_scale3)

    if args.gpu:
        cudnn.benchmark = True

    start_epoch = 0

    if args.pretrained is not None:
        state_dict = torch.load(args.pretrained)
        # Keep every entry whose key does not contain 'pred'; a single pass
        # over items() avoids rebuilding the value list for every key.
        new_dict = OrderedDict((key, value)
                               for key, value in state_dict.items()
                               if 'pred' not in key)
        model.load_state_dict(new_dict, strict=False)
        print('pretrained model loaded')

    if args.resume is not None:
        if osp.isfile(args.resume):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']
            args.lr = checkpoint['lr']
            # NOTE(review): checkpoints also store the optimizer state, but
            # only the model weights, epoch and lr are restored here.
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    log_file = osp.join(saveDir, 'trainValLog_' + args.model_name + '.txt')
    is_new_log = not osp.isfile(log_file)

    maxmIOU = 0
    maxEpoch = 0
    # The context manager guarantees the log is flushed and closed even when
    # training raises part-way through.
    with open(log_file, 'w' if is_new_log else 'a') as logger:
        if is_new_log:
            logger.write("Parameters: %s" % (str(total_paramters)))
            logger.write("\n%s\t%s\t\t%s\t%s\t%s\t%s\tlr" %
                         ('CrossVal', 'Epoch', 'Loss(Tr)', 'Loss(val)',
                          'mIOU (tr)', 'mIOU (val)'))
        logger.flush()

        optimizer = torch.optim.Adam(model.parameters(),
                                     args.lr, (0.9, 0.999),
                                     eps=1e-08,
                                     weight_decay=1e-4)
        print(args.model_name + '-CrossVal: ' + str(crossVal + 1))
        for epoch in range(start_epoch, args.max_epochs):
            # train for one epoch: the three rescaled loaders first, then
            # the main loader, whose metrics are the ones reported
            cur_iter = 0
            for aux_loader in (trainLoader_scale1, trainLoader_scale2,
                               trainLoader_scale3):
                train(args, aux_loader, model, criteria, optimizer, epoch,
                      max_batches, cur_iter)
                cur_iter += len(aux_loader)
            lossTr, _, _, _, mIOU_tr, lr = train(args, trainLoader, model,
                                                 criteria, optimizer, epoch,
                                                 max_batches, cur_iter)

            # evaluate on validation set
            lossVal, _, _, _, mIOU_val = val(args, valLoader, model,
                                             criteria)

            # rolling checkpoint, overwritten every epoch
            torch.save(
                {
                    'epoch': epoch + 1,
                    'arch': str(model),
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'lossTr': lossTr,
                    'lossVal': lossVal,
                    'iouTr': mIOU_tr,
                    'iouVal': mIOU_val,
                    'lr': lr
                },
                osp.join(
                    saveDir, 'checkpoint_' + args.model_name + '_crossVal' +
                    str(crossVal + 1) + '.pth.tar'))

            # save the model also (one file per epoch)
            model_file_name = osp.join(
                saveDir, 'model_' + args.model_name + '_crossVal' +
                str(crossVal + 1) + '_' + str(epoch + 1) + '.pth')
            torch.save(model.state_dict(), model_file_name)

            logger.write(
                "\n%d\t\t%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" %
                (crossVal + 1, epoch + 1, lossTr, lossVal, mIOU_tr, mIOU_val,
                 lr))
            logger.flush()
            print(
                "\nEpoch No. %d:\tTrain Loss = %.4f\tVal Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f\n"
                % (epoch + 1, lossTr, lossVal, mIOU_tr, mIOU_val))

            if mIOU_val >= maxmIOU:
                maxmIOU = mIOU_val
                maxEpoch = epoch + 1
            torch.cuda.empty_cache()
    return maxEpoch, maxmIOU