def test_folder():
    x = pickle_load(open('log_trans/infos_trans.pkl', 'rb'))
    dataset = CaptionDataset(x['opt'])
    ds = torch.utils.data.Subset(dataset, dataset.split_ix['train'])
    ds[0]
class LitModel(pl.LightningModule):
    def __init__(self, opt):
        super().__init__()
        self.opt = opt
        # Initialize dataset
        self.dataset = CaptionDataset(opt)
        opt.vocab_size = self.dataset.vocab_size
        opt.seq_length = self.dataset.seq_length
        self.batch_size = opt.batch_size

        # Build model
        opt.vocab = self.dataset.get_vocab()
        model = models.setup(opt)
        print(model)
        del opt.vocab

        # Wrap the model together with its training losses.
        lw_model = LossWrapper(model, opt)

        self.model = model
        self.lw_model = lw_model

        # Training-mode flags (self-critical / structure loss / drop-worst);
        # they are expected to be set externally, e.g. by a callback, before training.
        self.struc_flag = None
        self.sc_flag = None
        self.drop_worst_flag = None

    def forward(self, *args, **kwargs):
        """
        I hate this design. Never pretend it as a nn.Module
        """
        raise NotImplementedError

    def train_dataloader(self):
        train_dataset = torch.utils.data.Subset(self.dataset,
                                                self.dataset.split_ix['train'])

        train_loader = torch.utils.data.DataLoader(
            dataset=train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=self.dataset.collate_func)
        return train_loader

    def val_dataloader(self, split='val'):
        val_dataset = torch.utils.data.Subset(self.dataset,
                                              self.dataset.split_ix[split])
        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4,
            drop_last=False,
            collate_fn=self.dataset.collate_func)
        return val_loader

    def test_dataloader(self):
        return self.val_dataloader('test')

    def training_step(self, data, batch_idx):
        sc_flag, struc_flag, drop_worst_flag = self.sc_flag, self.struc_flag, self.drop_worst_flag

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        if int(os.getenv('M2_cider', '0')) != 0:
            data['gts'] = data['rawgts']
        model_out = self.lw_model(fc_feats, att_feats, labels, masks,
                                  att_masks, data['gts'],
                                  torch.arange(0, len(data['gts'])), sc_flag,
                                  struc_flag, drop_worst_flag)
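        # Drop-worst: when enabled, keep only the (1 - drop_worst_rate) fraction
        # of samples with the smallest losses before averaging.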
        if not drop_worst_flag:
            loss = model_out.pop('loss').mean()
        else:
            loss = model_out.pop('loss')
            loss = torch.topk(loss,
                              k=int(loss.shape[0] *
                                    (1 - self.opt.drop_worst_rate)),
                              largest=False)[0].mean()

        # Prepare for logging info
        data_time = self.trainer.profiler.recorded_durations[
            "get_train_batch"][-1]
        data_time = torch.tensor(data_time)

        logger_logs = model_out.copy()
        if struc_flag or sc_flag:
            logger_logs['reward'] = model_out['reward'].mean()
            logger_logs['reward_var'] = model_out['reward'].var(1).mean()

        logger_logs['scheduled_sampling_prob'] = torch.tensor(
            self.model.ss_prob)
        logger_logs['training_loss'] = loss
        logger_logs['data_time'] = data_time

        for k, v in logger_logs.items():
            self.log(k,
                     v,
                     on_epoch=(k == 'training_loss'),
                     prog_bar=(k == 'data_time'))
        # logged

        return loss

    def validation_step(self, data, batch_idx):
        model = self.model
        crit = self.lw_model.crit

        opt = self.opt
        eval_kwargs = {'dataset': opt.input_json}
        eval_kwargs.update(vars(opt))

        verbose = eval_kwargs.get('verbose', True)
        verbose_beam = eval_kwargs.get('verbose_beam', 0)
        verbose_loss = eval_kwargs.get('verbose_loss', 1)
        # num_images = eval_kwargs.get('num_images', eval_kwargs.get('val_images_use', -1))
        # lang_eval = eval_kwargs.get('language_eval', 0)
        dataset = eval_kwargs.get('dataset', 'coco')
        beam_size = eval_kwargs.get('beam_size', 1)
        sample_n = eval_kwargs.get('sample_n', 1)
        remove_bad_endings = eval_kwargs.get('remove_bad_endings', 0)
        # Pass this through an environment variable: it is effectively a global
        # configuration, and this keeps the rest of the code clean.
        os.environ["REMOVE_BAD_ENDINGS"] = str(remove_bad_endings)

        predictions = []
        n_predictions = []

        loss = torch.tensor(0)

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get loss
            loss = crit(
                model(fc_feats, att_feats, labels[..., :-1], att_masks),
                labels[..., 1:], masks[..., 1:])

        # forward the model to also get generated samples for each image
        # Only leave one feature for each image, in case duplicate sample
        tmp_eval_kwargs = eval_kwargs.copy()
        tmp_eval_kwargs.update({'sample_n': 1})
        seq, seq_logprobs = model(fc_feats,
                                  att_feats,
                                  att_masks,
                                  opt=tmp_eval_kwargs,
                                  mode='sample')
        seq = seq.data
        # Per-sequence statistics, normalized by the generated length (+1):
        # entropy of the output distribution and negative mean log-likelihood.
        seq_len = (seq > 0).to(seq_logprobs).sum(1) + 1
        entropy = -(F.softmax(seq_logprobs, dim=2) *
                    seq_logprobs).sum(2).sum(1) / seq_len
        perplexity = -seq_logprobs.gather(
            2, seq.unsqueeze(2)).squeeze(2).sum(1) / seq_len

        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(fc_feats.shape[0]):
                print('\n'.join([
                    utils.decode_sequence(model.vocab,
                                          _['seq'].unsqueeze(0))[0]
                    for _ in model.done_beams[i]
                ]))
                print('--' * 10)
        sents = utils.decode_sequence(model.vocab, seq)

        for k, sent in enumerate(sents):
            entry = {
                'image_id': data['infos'][k]['id'],
                'caption': sent,
                'perplexity': perplexity[k].item(),
                'entropy': entropy[k].item()
            }
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(eval_kwargs['image_root'], data['infos'][k]['file_path']) + \
                    '" vis/imgs/img' + \
                    str(len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        if sample_n > 1:
            eval_utils.eval_split_n(model, n_predictions,
                                    [fc_feats, att_feats, att_masks, data],
                                    eval_kwargs)

        output = {
            'val_loss': loss,
            'predictions': predictions,
            'n_predictions': n_predictions,
        }
        self.log('val_loss', loss)
        return output

    def test_step(self, *args, **kwargs):
        return self.validation_step(*args, **kwargs)

    def validation_epoch_end(self, outputs, split='val'):
        outputs = d2comm.gather(outputs)
        # master node
        if d2comm.is_main_process():
            assert self.trainer.node_rank == 0 and self.trainer.local_rank == 0
            outputs = sum(outputs, [])

            opt = self.opt
            val_loss_mean = sum([_['val_loss'].item()
                                 for _ in outputs]) / len(outputs)

            predictions = sum([_['predictions'] for _ in outputs], [])
            if len(outputs[0]['n_predictions']) != 0:
                n_predictions = sum([_['n_predictions'] for _ in outputs], [])
            else:
                n_predictions = []

            lang_stats = None
            if len(n_predictions) > 0 and 'perplexity' in n_predictions[0]:
                n_predictions = sorted(n_predictions,
                                       key=lambda x: x['perplexity'])

            if not os.path.isdir('eval_results'):
                os.mkdir('eval_results')
            torch.save(
                (predictions, n_predictions),
                os.path.join('eval_results/',
                             '.saved_pred_' + opt.id + '_' + split + '.pth'))

            if opt.language_eval:
                lang_stats = eval_utils.language_eval(opt.input_json,
                                                      predictions,
                                                      n_predictions, vars(opt),
                                                      split)

            if opt.reduce_on_plateau:
                optimizer = self.trainer.optimizers[0]
                if lang_stats is not None and 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss_mean)

            out = {'val_loss': val_loss_mean}
            if lang_stats is not None:
                out.update(lang_stats)
            out['to_monitor'] = lang_stats[
                'CIDEr'] if lang_stats is not None else -val_loss_mean
        else:
            out = {}

        out = d2comm.all_gather(out)[0]  # Only the one from master node
        assert len(out) > 0  # make sure the head has index 0

        # must all be tensors
        out = {
            k: torch.tensor(v) if not torch.is_tensor(v) else v
            for k, v in out.items()
        }
        for k, v in out.items():
            self.log(k, v)

        return out

    def test_epoch_end(self, outputs):
        out = self.validation_epoch_end(outputs, 'test')
        out['test_loss'] = out['val_loss']
        del out['val_loss']
        del out['to_monitor']
        out = {
            'test_' + k if 'test' not in k else k: v
            for k, v in out.items()
        }
        return out

    def configure_optimizers(self):
        opt = self.opt
        model = self.model
        if opt.noamopt:
            # assert opt.caption_model in ['transformer', 'bert', 'm2transformer'], 'noamopt can only work with transformer'
            optimizer = utils.get_std_opt(model,
                                          optim_func=opt.optim,
                                          factor=opt.noamopt_factor,
                                          warmup=opt.noamopt_warmup)
        elif opt.reduce_on_plateau:
            optimizer = utils.build_optimizer(model.parameters(), opt)
            optimizer = utils.ReduceLROnPlateau(
                optimizer,
                factor=opt.reduce_on_plateau_factor,
                patience=opt.reduce_on_plateau_patience)
        else:
            optimizer = utils.build_optimizer(model.parameters(), opt)
        return [optimizer], []

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, *args,
                       **kwargs):
        # warm up lr
        opt = self.opt
        iteration = self.trainer.global_step
        if opt.use_warmup and (iteration < opt.noamopt_warmup):
            opt.current_lr = opt.learning_rate * \
                (iteration+1) / opt.noamopt_warmup
            utils.set_lr(optimizer, opt.current_lr)

        super().optimizer_step(epoch, batch_idx, optimizer, optimizer_idx,
                               *args, **kwargs)

    def state_dict(self, *args, **kwargs):
        """
        Save the model state dict as well as opt and vocab
        """
        state_dict = self.model.state_dict(*args, **kwargs)
        device = next(iter(state_dict.values())).device
        assert '_vocab' not in state_dict and '_opt' not in state_dict, 'Just in case'
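        # Serialize vocab and opt into tensors so they are stored inside the
        # checkpoint itself and can be recovered in load_state_dict below.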
        state_dict.update({
            '_vocab':
            utils.serialize_to_tensor(self.model.vocab).to(device),
            '_opt':
            utils.serialize_to_tensor(self.opt).to(device)
        })
        return state_dict

    def load_state_dict(self, state_dict=None, strict=True):
        if '_vocab' in state_dict:
            self.model.vocab = utils.deserialize(state_dict['_vocab'])
            del state_dict['_vocab']
        elif strict:
            raise KeyError
        if '_opt' in state_dict:
            saved_model_opt = utils.deserialize(state_dict['_opt'])
            del state_dict['_opt']
            opt = self.opt
            # Make sure the saved opt is compatible with the current opt
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                if getattr(saved_model_opt, checkme) in ['updown', 'topdown'] and \
                        getattr(opt, checkme) in ['updown', 'topdown']:
                    continue
                assert getattr(saved_model_opt, checkme) == getattr(
                    opt, checkme
                ), "Command line argument and saved model disagree on '%s' " % checkme
        elif strict:
            raise KeyError
        self.model.load_state_dict(state_dict, strict)

    def get_progress_bar_dict(self):
        # don't show the version number
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)
        return items
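

def example_fit():
    # Hedged usage sketch (not part of the original training entry point):
    # it assumes the same pickled `opt` that test_folder()/test_lmdb() load and
    # a stock pytorch_lightning.Trainer; the argument values are illustrative.
    x = pickle_load(open('log_trans/infos_trans.pkl', 'rb'))
    lit = LitModel(x['opt'])
    trainer = pl.Trainer(max_epochs=1,
                         limit_train_batches=2,
                         limit_val_batches=2)
    trainer.fit(lit)
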
def test_lmdb():
    x = pickle_load(open('log_trans/infos_trans.pkl', 'rb'))
    x['opt'].input_att_dir = 'data/vilbert_att.lmdb'
    dataset = CaptionDataset(x['opt'])
    ds = torch.utils.data.Subset(dataset, dataset.split_ix['train'])
    ds[0]