Example #1
def load(model, opt):
    # check compatibility if training is continued from previously saved model
    if vars(opt).get('start_from', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            opt.start_from), "%s must be a path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")
        ), "infos.pkl file does not exist in path %s" % opt.start_from
        utils.load_state_dict(
            model, torch.load(os.path.join(opt.start_from, 'model.pth')))
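A minimal sketch of how this helper might be invoked, assuming opt is an argparse-style namespace carrying the start_from and id attributes the function reads (the directory name and model id below are placeholders, and the commented-out models/utils calls refer to project modules not shown here):

import argparse

# Hypothetical options; start_from points at a checkpoint directory that is
# expected to contain model.pth and infos_<id>.pkl.
opt = argparse.Namespace(start_from='log_fc_con', id='fc_con')

# model = models.JointModel(opt)   # build the model from the same options
# load(model, opt)                 # restores weights from log_fc_con/model.pth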
Example #2
    def __init__(self, opt):
        super(JointModel, self).__init__()
        self.opt = opt
        self.use_word_weights = getattr(opt, 'use_word_weights', 0)

        self.caption_generator = setup(opt, opt.caption_model, True)

        if opt.vse_model != 'None':
            self.vse = setup(opt, opt.vse_model, False)
            self.share_embed = opt.share_embed
            self.share_fc = opt.share_fc
            if self.share_embed:
                self.vse.txt_enc.embed = self.caption_generator.embed
            if self.share_fc:
                assert self.vse.embed_size == self.caption_generator.input_encoding_size
                if hasattr(self.caption_generator, 'img_embed'):
                    self.vse.img_enc.fc = self.caption_generator.img_embed
                else:
                    self.vse.img_enc.fc = self.caption_generator.att_embed
        else:
            self.vse = lambda x, y, z, w, u: Variable(torch.zeros(1)).cuda()

        if opt.vse_loss_weight == 0 and isinstance(self.vse, nn.Module):
            for p in self.vse.parameters():
                p.requires_grad = False

        self.vse_loss_weight = opt.vse_loss_weight
        self.caption_loss_weight = opt.caption_loss_weight

        self.retrieval_reward = opt.retrieval_reward  # none, reinforce, gumbel
        self.retrieval_reward_weight = opt.retrieval_reward_weight

        self.reinforce_baseline_type = getattr(opt, 'reinforce_baseline_type',
                                               'greedy')

        self.only_one_retrieval = getattr(opt, 'only_one_retrieval', 'off')

        self.cider_optimization = getattr(opt, 'cider_optimization', 0)

        self._loss = {}

        load(self, opt)
        if getattr(opt, 'initialize_retrieval', None) is not None:
            print("Make sure the vse opt are the same !!!!!\n" * 100)
            utils.load_state_dict(
                self, {
                    k: v
                    for k, v in torch.load(opt.initialize_retrieval).items()
                    if 'vse.' in k
                })
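When opt.vse_model is 'None', the constructor above replaces the retrieval branch with a five-argument lambda that returns a zero tensor, so the joint loss can be assembled through the same code path either way. A standalone sketch of that pattern, with placeholder argument names and loss values (not the project's real call signature):

import torch

# Stand-in for the retrieval branch: same arity as the real VSE forward pass,
# but it always contributes zero to the loss.
vse = lambda fc_feats, att_feats, seq, masks, only_one: torch.zeros(1)

caption_loss = torch.tensor(2.5)   # placeholder caption loss
loss = 1.0 * caption_loss + 0.0 * vse(None, None, None, None, None).sum()
print(loss.item())                 # 2.5, the retrieval term is a no-op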
Example #3
            assert vars(opt)[k] == vars(
                infos['opt'])[k], k + ' option not consistent'
        else:
            vars(opt).update({k: vars(infos['opt'])[k]
                              })  # copy over options from model

vocab = infos['vocab']  # ix -> word mapping

assert opt.seq_per_img == 5

opt.vse_loss_weight = vars(opt).get('vse_loss_weight', 1)
opt.caption_loss_weight = vars(opt).get('caption_loss_weight', 1)

# Setup the model
model = models.JointModel(opt)
utils.load_state_dict(model, torch.load(opt.model))
if opt.initialize_retrieval is not None:
    print("Make sure the vse opt are the same !!!!!\n" * 100)
    utils.load_state_dict(
        model, {
            k: v
            for k, v in torch.load(opt.initialize_retrieval).items()
            if 'vse' in k
        })
model.cuda()
model.eval()

# Create the Data Loader instance
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
else:
Example #4
def eval(opt, model_name, infos_name, annFile, listener, split, iteration):
    # Input arguments and options
    # Load infos

    with open(infos_name, 'rb') as f:
        infos = cPickle.load(f, encoding='latin1')

    # Handle the case where eval is not run immediately after training, so some
    # arguments may not exist. 'att_hid_size' is just one possible probe to find out.
    if not hasattr(opt, 'att_hid_size'):
        opt = infos['opt']
    opt.split = split
    opt.beam_size = 2

    np.random.seed(123)

    # override and collect parameters
    if len(opt.input_fc_dir) == 0:
        opt.input_fc_dir = infos['opt'].input_fc_dir
        opt.input_att_dir = infos['opt'].input_att_dir
        opt.input_label_h5 = infos['opt'].input_label_h5
    if len(opt.input_json) == 0:
        opt.input_json = infos['opt'].input_json
    if opt.batch_size == 0:
        opt.batch_size = infos['opt'].batch_size
    if len(opt.id) == 0:
        opt.id = infos['opt'].id
    # if opt.initialize_retrieval == None:
    #     opt.initialize_retrieval = infos['opt'].initialize_retrieval
    ignore = [
        "id", "batch_size", "beam_size", "start_from", "language_eval",
        "initialize_retrieval", 'decoding_constraint', 'evaluation_retrieval',
        "input_fc_dir", "input_att_dir", "input_label_h5", 'seq_per_img',
        'closest_num', 'closest_file'
    ]
    # for k in vars(infos['opt']).keys():
    #     if k not in ignore:
    #         if k in vars(opt) and getattr(opt, k) is not None:
    #             assert vars(opt)[k] == vars(infos['opt'])[k], k + ' option not consistent:' + str(vars(opt)[k])+' '+ str(vars(infos['opt'])[k])
    #         else:
    #             vars(opt).update({k: vars(infos['opt'])[k]}) # copy over options from model

    vocab = infos['vocab']  # ix -> word mapping

    # assert opt.closest_num == opt.seq_per_img
    opt.vse_loss_weight = vars(opt).get('vse_loss_weight', 1)
    opt.caption_loss_weight = vars(opt).get('caption_loss_weight', 1)

    opt.cider_optimization = 0

    # Setup the model
    model = models.AlternatingJointModel(opt, iteration)
    # model = models.JointModel(opt)
    utils.load_state_dict(model, torch.load(model_name))
    if listener == 'gt':
        print('gt listener is loaded for evaluation')
        # utils.load_state_dict(model.vse, torch.load(opt.initialize_retrieval))
        utils.load_state_dict(
            model, {
                k: v
                for k, v in torch.load(opt.initialize_retrieval).items()
                if 'vse.' in k
            })

    model.cuda()
    model.eval()

    # Create the Data Loader instance
    loader = DataLoader(opt)
    # Set sample options
    loss, split_predictions, lang_stats = eval_utils.eval_split(
        model, loader, vars(opt), annFile, useGenSent=True)

    return {
        'loss': loss,
        'split_predictions': split_predictions,
        'lang_stats': lang_stats
    }
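The dictionary returned above is consumed by whatever script drives evaluation; a hypothetical call is sketched below (the file names, split, and iteration are placeholders, not values taken from the snippet):

# results = eval(opt, 'model-best.pth', 'infos-best.pkl', annFile,
#                listener='gt', split='test', iteration='20000')
# print('loss: %.3f' % results['loss'])
# print(results['lang_stats'])   # language metrics from eval_utils.eval_split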
Example #5

    def __init__(self, opt, iteration=None):
        super(AlternatingJointModel, self).__init__()
        self.opt = opt
        self.use_word_weights = getattr(opt, 'use_word_weights', 0)

        # self.caption_generator = setup(opt, opt.caption_model, True)
        self.caption_generator = setup(opt, opt.caption_model, 'caption_model')

        if opt.vse_model != 'None':
            # self.vse = setup(opt, opt.vse_model, False)
            self.vse = setup(opt, opt.vse_model, 'vse_model')
            self.share_embed = opt.share_embed
            if self.share_embed:
                self.caption_generator.embed[0] = self.vse.txt_enc.embed
                if self.opt.phase == 2:  # second phase (MLE) only
                    for p in self.caption_generator.embed.parameters():
                        p.requires_grad = False
        else:
            if torch.cuda.is_available():
                self.vse = lambda x, y, z, w, u: Variable(torch.zeros(1)).cuda()
            else:  # CPU
                self.vse = lambda x, y, z, w, u: Variable(torch.zeros(1))

        if opt.retrieval_reward == 'reinforce':
            if opt.vse_loss_weight == 0 and isinstance(self.vse, nn.Module):
                for p in self.vse.parameters():
                    p.requires_grad = False

        self.batch_size = opt.batch_size
        self.vse_loss_weight = opt.vse_loss_weight
        self.caption_loss_weight = opt.caption_loss_weight
        self.df = getattr(opt, 'df', 'coco-val')
        # none, reinforce, gumbel, multinomial
        self.retrieval_reward = opt.retrieval_reward
        # Covers training the listener after the speaker was trained with
        # reinforce_speaker; that optimization is named reinforce_listener
        # in run_joint.sh.
        if opt.alternating_turn is not None:
            if len(opt.alternating_turn) == 1 and \
                    opt.retrieval_reward == 'reinforce':
                if opt.alternating_turn[0] == 'listener':
                    opt.retrieval_reward_weight = 0
        self.retrieval_reward_weight = opt.retrieval_reward_weight

        self.reinforce_baseline_type = getattr(opt, 'reinforce_baseline_type',
                                               'greedy')
        self.sheriff_baseline_type = getattr(opt, 'sheriff_baseline_type',
                                             'greedy')

        self.only_one_retrieval = getattr(opt, 'only_one_retrieval', 'off')

        self.cider_optimization = getattr(opt, 'cider_optimization', 0)

        self.use_gen_cider_scores = getattr(opt, 'use_gen_cider_scores', 0)

        self._loss = {}

        # Load model
        if opt.is_alternating:  # In case of alternating training
            if opt.continue_from_existing_models:
                # If we already have a previous version of the
                # alternating model.
                if os.path.isfile(
                        os.path.join(opt.start_from, 'alternatingModel.pth')):
                    if iteration:  # For evaluation choose specific iteration
                        old_alternating_model_path = os.path.join(
                            opt.start_from,
                            'alternatingModel-' + iteration + '.pth')
                    else:
                        old_alternating_model_path = os.path.join(
                            opt.start_from, 'alternatingModel.pth')
                    if torch.cuda.is_available():
                        utils.load_state_dict(
                            self, torch.load(old_alternating_model_path))
                    else:
                        utils.load_state_dict(
                            self,
                            torch.load(old_alternating_model_path,
                                       map_location='cpu'))
                    print('Loaded alternating model from {}'.format(
                        old_alternating_model_path))
                else:  # initialize from stage 2 model
                    # load the pre-trained speaker model from stage 2
                    old_speaker_path = opt.speaker_stage_2_model_path
                    if torch.cuda.is_available():
                        utils.load_state_dict(self,
                                              torch.load(old_speaker_path))
                    else:  # CPU
                        utils.load_state_dict(
                            self,
                            torch.load(old_speaker_path, map_location='cpu'))
                    print(f'Loaded pre-trained "speaker" model, '
                          f'after stage 2 from {old_speaker_path}')
        else:  # No alternating

            load(self, opt, iteration)
            if getattr(opt, 'initialize_retrieval', None) is not None:
                print("Make sure the vse opt are the same !!!!!")
                if torch.cuda.is_available():
                    utils.load_state_dict(
                        self, {
                            k: v
                            for k, v in torch.load(opt.initialize_retrieval).
                            items() if 'vse.' in k
                        })
                else:  # CPU
                    utils.load_state_dict(
                        self, {
                            k: v
                            for k, v in torch.load(opt.initialize_retrieval,
                                                   map_location='cpu').items()
                            if 'vse.' in k
                        })
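A small sketch of the checkpoint-name resolution done in the alternating branch above; the directory and iteration values are placeholders, only the alternatingModel file names come from the snippet:

import os

start_from = 'log_alternating'   # hypothetical checkpoint directory
iteration = '20000'              # iteration tag passed in by the caller, as a string

latest = os.path.join(start_from, 'alternatingModel.pth')
specific = os.path.join(start_from, 'alternatingModel-' + iteration + '.pth')
checkpoint_path = specific if iteration else latest
print(checkpoint_path)           # log_alternating/alternatingModel-20000.pth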
Example #6
        if k in vars(opt) and getattr(opt, k) is not None:
            assert vars(opt)[k] == vars(infos['opt'])[k], \
                k + ' option not consistent: ' + str(vars(opt)[k]) + ' ' + \
                str(vars(infos['opt'])[k])
        else:
            # copy over options from model
            vars(opt).update({k: vars(infos['opt'])[k]})

vocab = infos['vocab'] # ix -> word mapping

assert opt.closest_num == opt.seq_per_img
opt.vse_loss_weight = vars(opt).get('vse_loss_weight', 1)
opt.caption_loss_weight = vars(opt).get('caption_loss_weight', 1)

opt.cider_optimization = 0

# Setup the model
model = models.JointModel(opt)
utils.load_state_dict(model, torch.load(opt.model))

model.cuda()
model.eval()

# Create the Data Loader instance
if len(opt.image_folder) == 0:
    loader = DataLoader(opt)
else:
    loader = DataLoaderRaw({'folder_path': opt.image_folder,
                            'coco_json': opt.coco_json,
                            'batch_size': opt.batch_size,
                            'cnn_model': opt.cnn_model})
    loader.ix_to_word = infos['vocab']

def eval_split(model, crit, loader, eval_kwargs={}):
    verbose = eval_kwargs.get('verbose', True)
    verbose_beam = eval_kwargs.get('verbose_beam', 1)
    verbose_loss = eval_kwargs.get('verbose_loss', 1)
    num_images = eval_kwargs.get('num_images',
                                 eval_kwargs.get('val_images_use', -1))
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 0)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)

    if eval_kwargs.get('rank', 0):
        infos_path = 'log_fc_con/infos_vse_fc_con-best.pkl'  # 'log_fc_con_discsplit/infos_vse_fc_con_discsplit-best.pkl'
        model_path = 'log_fc_con/model_vse-best.pth'  # 'log_fc_con_discsplit/model_vse-best.pth'
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)

        rank_model = models.JointModel(infos['opt'])
        utils.load_state_dict(rank_model, torch.load(model_path))
        rank_model.cuda()
        rank_model.eval()
        print('success loaded retrieval model !')

    # Make sure in the evaluation mode
    model.eval()

    loader.reset_iterator(split)

    n = 0
    loss = 0
    loss_sum = 0
    loss_evals = 1e-8
    predictions = []
    seqs = []
    while True:
        data = loader.get_batch(split)
        n = n + loader.batch_size
        sys.stdout.flush()
        if data.get('labels', None) is not None and verbose_loss:
            # forward the model to get loss
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [
                torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp
            ]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            with torch.no_grad():
                loss = crit(model(fc_feats, att_feats, labels, att_masks),
                            labels[:, 1:], masks[:, 1:]).item()
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1

        # forward the model to also get generated samples for each image
        # Only leave one feature for each image, in case duplicate sample
        tmp = [
            data['fc_feats'][np.arange(loader.batch_size) *
                             loader.seq_per_img],
            data['att_feats'][np.arange(loader.batch_size) *
                              loader.seq_per_img],
            data['att_masks'][np.arange(loader.batch_size) *
                              loader.seq_per_img]
            if data['att_masks'] is not None else None
        ]
        tmp = [torch.from_numpy(_).cuda() if _ is not None else _ for _ in tmp]
        fc_feats, att_feats, att_masks = tmp
        # forward the model to also get generated samples for each image
        with torch.no_grad():
            seq = model(fc_feats,
                        att_feats,
                        att_masks,
                        opt=eval_kwargs,
                        mode='sample')[0].data

        # Print beam search
        if beam_size > 1 and verbose_beam:
            for i in range(loader.batch_size):
                pass  #print('\n'.join([utils.decode_sequence(loader.get_vocab(), _['seq'].unsqueeze(0))[0] for _ in model.done_beams[i]]))
                #print('--' * 10)
        sents = utils.decode_sequence(loader.get_vocab(), seq)

        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
            if eval_kwargs.get('dump_path', 0) == 1:
                entry['file_name'] = data['infos'][k]['file_path']
            predictions.append(entry)
            if eval_kwargs.get('dump_images', 0) == 1:
                # dump the raw image to vis/ folder
                cmd = 'cp "' + os.path.join(
                    eval_kwargs['image_root'],
                    data['infos'][k]['file_path']) + '" vis/imgs/img' + str(
                        len(predictions)) + '.jpg'  # bit gross
                print(cmd)
                os.system(cmd)

            if verbose:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        # if we wrapped around the split or used up val imgs budget then bail
        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if eval_kwargs.get('rank', 0):
            seqs.append(padding(seq, 30))
        if num_images != -1:
            ix1 = min(ix1, num_images)
        for i in range(n - ix1):
            predictions.pop()
        if n > ix1:
            seq = seq[:(ix1 - n) * loader.seq_per_img]

        if verbose:
            print('evaluating validation performance... %d/%d (%f)' %
                  (ix0 - 1, ix1, loss))

        if data['bounds']['wrapped']:
            break
        if num_images >= 0 and n >= num_images:
            break

    if eval_kwargs.get('rank', 0):
        seqs = torch.cat(seqs, 0).contiguous()
        seqs = change_seq(seqs, loader.ix_to_word)

    if eval_kwargs.get('vsepp', 0):
        from eval_vsepp import evalrank_vsepp
        from eval_utils_pair import get_transform
        import torchvision.transforms as transforms
        from PIL import Image

        imgids = [_['image_id'] for _ in predictions]
        seqs = seqs[:num_images]

        transform = get_transform('COCO', 'val', None)
        imgs = []
        for i, imgid in enumerate(imgids):
            img_path = '../imgcap/data/raw_images/val2014/COCO_val2014_' + str(
                imgid).zfill(12) + '.jpg'
            if i % 100 == 0:
                print('load %d images' % i)
            image = Image.open(img_path).convert('RGB')
            image = transform(image)
            imgs.append(image.unsqueeze(0))
        imgs = torch.cat(imgs, 0).contiguous()
        lengths = torch.sum((seqs > 0), 1) + 1
        lengths = lengths.cpu()
        with torch.no_grad():
            evalrank_vsepp(imgs, loader.ix_to_word, seqs, lengths)

    lang_stats = None
    if lang_eval == 1:
        lang_stats = language_eval(eval_kwargs.get('data', 'coco'),
                                   predictions, eval_kwargs['id'], split)

    if eval_kwargs.get('rank', 0):
        ranks = evalrank(rank_model, loader, seqs, eval_kwargs)

    # Switch back to training mode
    model.train()
    return loss_sum / loss_evals, predictions, lang_stats
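For reference, every option that eval_split reads from eval_kwargs (via .get() or direct indexing) appears in the body above; the sketch below collects them into a single dictionary with illustrative values, which are placeholders rather than the project's defaults:

eval_kwargs = {
    'verbose': True,        # print each generated caption
    'verbose_beam': 1,      # beam printing (the print itself is commented out above)
    'verbose_loss': 1,      # compute the caption loss on labelled batches
    'num_images': 5000,     # falls back to 'val_images_use'; -1 means the whole split
    'split': 'val',
    'language_eval': 1,     # run language_eval() at the end
    'dataset': 'coco',
    'data': 'coco',         # dataset name passed to language_eval()
    'beam_size': 1,
    'id': 'fc_con',         # experiment id used to name language_eval output
    'rank': 0,              # 1 also loads the retrieval model and runs evalrank()
    'vsepp': 0,             # 1 runs the VSE++ ranking evaluation on raw images
    'dump_path': 0,
    'dump_images': 0,       # 1 also requires 'image_root'
}
# loss, predictions, lang_stats = eval_split(model, crit, loader, eval_kwargs)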