Пример #1
0
def train_cnn(model_cnn, images, bus, fc_expander, att_expander, bu_expander,
              use_reinforce):
    """Run the CNN on a batch of images and prepare features for the decoder.

    Which features ``model_cnn`` produces is chosen from the module-level
    ``opt.caption_model``:

    * fc-only models   -> ``fc_feats``
    * att-only models  -> ``att_feats``
    * sub-region bottom-up models -> ``fc_feats, att_feats, bu_feats``
    * everything else  -> ``fc_feats, att_feats``

    For bottom-up models (``models.has_bu``) the supplied ``bus`` are used
    as ``bu_feats``.

    Each produced feature is passed through its expander when
    ``opt.seq_per_img > 1`` and ``use_reinforce`` is false.

    Returns:
        ``(fc_feats, att_feats, bu_feats)``; features that are not produced
        are ``None``.
    """
    fc_feats = None
    att_feats = None
    bu_feats = None

    # Expansion presumably replicates per-image features to match the
    # seq_per_img captions per image; it is skipped for REINFORCE training.
    expand = opt.seq_per_img > 1 and not use_reinforce
    caption_model = opt.caption_model

    # train cnn
    if models.is_only_fc_feat(caption_model):
        fc_feats = model_cnn(images)
        if expand:
            fc_feats = fc_expander(fc_feats)
    elif models.is_only_att_feat(caption_model):
        att_feats = model_cnn(images)
        if expand:
            att_feats = att_expander(att_feats)
    elif models.has_sub_region_bu(caption_model):
        fc_feats, att_feats, bu_feats = model_cnn(images)
        if expand:
            fc_feats = fc_expander(fc_feats)
            att_feats = att_expander(att_feats)
            bu_feats = bu_expander(bu_feats)
    else:
        fc_feats, att_feats = model_cnn(images)
        if expand:
            fc_feats = fc_expander(fc_feats)
            att_feats = att_expander(att_feats)

    if models.has_bu(caption_model):
        # Always use the supplied bottom-up features; expand them only when
        # the other features are expanded as well (otherwise the batch
        # dimensions would disagree).
        bu_feats = bus
        if expand:
            bu_feats = bu_expander(bu_feats)

    return fc_feats, att_feats, bu_feats
Пример #2
0
def get_loader():
    """Build the data loader that matches the configured caption model.

    Models that need bottom-up or sub-region features get a
    ``DataLoaderThreadBu``; all others get a ``DataLoaderThreadNew``. The
    chosen class name is printed. Reads the module-level ``opt``.
    """
    name = opt.caption_model
    needs_bu_loader = (models.has_bu(name)
                       or models.has_sub_regions(name)
                       or models.has_sub_region_bu(name))

    loader_cls, label = ((DataLoaderThreadBu, "DataLoaderThreadBu")
                         if needs_bu_loader
                         else (DataLoaderThreadNew, "DataLoaderThreadNew"))
    loader = loader_cls(opt)
    print(label)
    return loader
Пример #3
0
def get_expander():
    """Create the feature expanders used when several captions share an image.

    Returns ``(fc_expander, att_expander, bu_expander)``, each a
    ``utils.FeatExpander(opt.seq_per_img)``. All three are ``None`` when
    ``opt.seq_per_img <= 1``. ``bu_expander`` is only created for
    bottom-up / sub-region-bottom-up models. Reads the module-level ``opt``.
    """
    if not opt.seq_per_img > 1:
        return None, None, None

    fc_exp = utils.FeatExpander(opt.seq_per_img)
    att_exp = utils.FeatExpander(opt.seq_per_img)
    wants_bu = (models.has_bu(opt.caption_model)
                or models.has_sub_region_bu(opt.caption_model))
    bu_exp = utils.FeatExpander(opt.seq_per_img) if wants_bu else None
    return fc_exp, att_exp, bu_exp
Пример #4
0
def compute_output(caption_model, beam_size, model, fc_feats, att_feats,
                   bu_feats):
    """Decode captions with ``model.sample`` using the model's feature signature.

    The positional feature arguments depend on ``caption_model``:
    fc-only -> ``(fc_feats,)``; att-only -> ``(att_feats,)``;
    bottom-up / sub-region-bottom-up / prob-weight-mul-out ->
    ``(fc_feats, att_feats, bu_feats)``; otherwise ``(fc_feats, att_feats)``.
    The options dict ``{'beam_size': beam_size}`` is passed last.

    Returns whatever ``model.sample`` returns.
    """
    sample_opts = {'beam_size': beam_size}

    if models.is_only_fc_feat(caption_model):
        feats = (fc_feats,)
    elif models.is_only_att_feat(caption_model):
        feats = (att_feats,)
    elif (models.has_bu(caption_model)
          or models.has_sub_region_bu(caption_model)
          or models.is_prob_weight_mul_out(caption_model)):
        feats = (fc_feats, att_feats, bu_feats)
    else:
        feats = (fc_feats, att_feats)

    return model.sample(*feats, sample_opts)
def train_with_prob_weight(params, opt):
    """One training step for a prob-weight caption model.

    Runs the model forward, computes
    ``params['crit'](output, labels, masks, prob_w, tokens)`` and calls
    ``backward()`` on the result. The model's positional inputs depend on
    ``opt.caption_model``:

    * transformer  -> ``(att_feats, targets, masks)``
    * ctransformer -> ``(fc_feats, att_feats, targets, masks)``
    * bottom-up / sub-region bottom-up -> ``(fc_feats, att_feats, bu_feats, labels)``
    * otherwise    -> ``(fc_feats, att_feats, labels)``

    When ``opt.verbose`` is set, the time of each phase is printed.

    Returns:
        ``(loss.data[0], 0)``; the second item is a placeholder reward mean.
    """
    net = params['model']
    fc = params['fc_feats']
    att = params['att_feats']
    lbl = params['labels']
    tgt = params['targets']
    msk = params['masks']
    tok = params['tokens']
    criterion = params['crit']

    # forward
    t0 = time.time()
    kind = opt.caption_model
    if models.is_transformer(kind):
        model_inputs = (att, tgt, msk)
    elif models.is_ctransformer(kind):
        model_inputs = (fc, att, tgt, msk)
    elif models.has_bu(kind) or models.has_sub_region_bu(kind):
        model_inputs = (fc, att, params['bu_feats'], lbl)
    else:
        model_inputs = (fc, att, lbl)
    output, prob_w = net(*model_inputs)
    if opt.verbose:
        print('model {:.3f}'.format(time.time() - t0))

    # loss: crit(input, target, mask, prob_w, token)
    t0 = time.time()
    loss = criterion(output, lbl, msk, prob_w, tok)
    if opt.verbose:
        print('crit {:.3f}'.format(time.time() - t0))

    # backward
    t0 = time.time()
    loss.backward()
    if opt.verbose:
        print('loss {:.3f}'.format(time.time() - t0))

    return loss.data[0], 0
Пример #6
0
def compute_cnn_feats(caption_model, model_cnn, images):
    """Run the CNN and unpack its output according to ``caption_model``.

    Returns ``(fc_feats, att_feats, bu_feats)``:

    * fc-only models: ``(model_cnn(images), None, None)``
    * att-only models: ``(None, model_cnn(images), None)``
    * prob-weight models with sub-region bottom-up features (never for
      ``"SCST"``): the three values produced by ``model_cnn``
    * all other models: the two values produced by ``model_cnn`` plus
      ``None`` for ``bu_feats``
    """
    if models.is_only_fc_feat(caption_model):
        return model_cnn(images), None, None
    if models.is_only_att_feat(caption_model):
        return None, model_cnn(images), None

    # "SCST" is decided before the prob-weight check, so it always takes
    # the two-feature path.
    three_feats = (caption_model != "SCST"
                   and models.is_prob_weight(caption_model)
                   and models.has_sub_region_bu(caption_model))
    if three_feats:
        fc, att, bu = model_cnn(images)
        return fc, att, bu

    fc, att = model_cnn(images)
    return fc, att, None
def train_actor_critic(params, opt, type, retain_graph=False):
    """One actor-critic training step on sampled captions.

    Samples captions from ``params['model']`` with ``{'sample_max': 0}``,
    computes the loss and calls ``loss.backward(retain_graph=retain_graph)``.

    Args:
        params: dict with 'model', 'fc_feats', 'att_feats', 'gts' and
            'crit_c' (when ``type == 0``) or 'crit_ac' (when ``type == 1``);
            'bu_feats' is also read for bottom-up / sub-region-bottom-up models.
        opt: options; reads ``caption_model`` and ``verbose``.
        type: 0 -> critic loss ``crit_c(seq, value, gts)``;
            1 -> actor+critic loss ``crit_ac(seq, logprobs, value, gts)``.
            (The name shadows the builtin but is kept for API compatibility.)
        retain_graph: forwarded to ``loss.backward``.

    Returns:
        ``(train_loss, reward_mean, sample_mean)``; ``train_loss`` is
        ``loss.data[0]`` and the other two come from the criterion.

    Raises:
        ValueError: if ``type`` is neither 0 nor 1 (checked before any work,
            instead of failing later on an unbound criterion).
    """
    if type not in (0, 1):
        raise ValueError(
            'train_actor_critic: type must be 0 or 1, got {!r}'.format(type))

    model = params['model']
    fc_feats = params['fc_feats']
    att_feats = params['att_feats']
    gts = params['gts']
    crit = params['crit_c'] if type == 0 else params['crit_ac']

    use_bu = (models.has_bu(opt.caption_model)
              or models.has_sub_region_bu(opt.caption_model))
    if use_bu:
        bu_feats = params['bu_feats']

    # forward
    start = time.time()
    if models.is_only_fc_feat(opt.caption_model):
        sample_seq, sample_seqLogprobs, sample_value = model.sample(
            fc_feats, {'sample_max': 0})
    elif use_bu:
        sample_seq, sample_seqLogprobs, sample_value = model.sample(
            fc_feats, att_feats, bu_feats, {'sample_max': 0})
    else:
        sample_output = model.sample(fc_feats, att_feats, {'sample_max': 0})
        # Only the first three outputs (seq, logprobs, value) are used.
        sample_seq = sample_output[0]
        sample_seqLogprobs = sample_output[1]
        sample_value = sample_output[2]

    if opt.verbose:
        print('model {:.3f}'.format(time.time() - start))

    # compute the loss: 0 -> critic only, 1 -> critic + actor
    start = time.time()
    if type == 0:
        loss, reward_mean, sample_mean = crit(sample_seq, sample_value, gts)
    else:
        loss, reward_mean, sample_mean = crit(sample_seq, sample_seqLogprobs,
                                              sample_value, gts)
    if opt.verbose:
        print('crit {:.3f}'.format(time.time() - start))

    # backward
    start = time.time()
    loss.backward(retain_graph=retain_graph)
    if opt.verbose:
        print('loss {:.3f}'.format(time.time() - start))

    train_loss = loss.data[0]

    return train_loss, reward_mean, sample_mean
def train_reinforce(params, opt):
    """One self-critical REINFORCE training step.

    Draws a sampled caption (``{'sample_max': 0}``) and a greedy caption
    (``{'sample_max': 1}``) from ``params['model']``, computes
    ``crit_rl(sample_seq, sample_seqLogprobs, greedy_seq, gts, masks)`` and
    calls ``backward()`` on the resulting loss.

    Args:
        params: dict with 'model', 'fc_feats', 'att_feats', 'masks',
            'crit_rl', 'gts'; 'bu_feats' is also read for bottom-up /
            sub-region-bottom-up models.
        opt: options; reads ``reinforce_type``, ``caption_model``, ``verbose``.

    Returns:
        ``(train_loss, reward_mean, sample_mean, greedy_mean)``;
        ``train_loss`` is ``loss.data[0]``, the rest come from ``crit_rl``.

    Raises:
        Exception: ``reinforce_type == 0`` (deprecated policy-gradient mode).
        ValueError: any ``reinforce_type`` other than 0 or 1; previously this
            fell through and failed on unbound result variables.
    """
    if opt.reinforce_type == 0:
        raise Exception('reinforce_type error, 0 is deprecated')
    if opt.reinforce_type != 1:
        raise ValueError('reinforce_type error, unsupported value {!r}'.format(
            opt.reinforce_type))

    model = params['model']
    fc_feats = params['fc_feats']
    att_feats = params['att_feats']
    masks = params['masks']
    crit_rl = params['crit_rl']
    gts = params['gts']

    caption_model = opt.caption_model
    use_bu = (models.has_bu(caption_model)
              or models.has_sub_region_bu(caption_model))
    if use_bu:
        bu_feats = params['bu_feats']

    # self-critical forward: sampled and greedy decodes
    start = time.time()
    if models.is_only_fc_feat(caption_model):
        sample_seq, sample_seqLogprobs = model.sample(
            fc_feats, {'sample_max': 0})
        greedy_seq, _ = model.sample(fc_feats, {'sample_max': 1})
    elif models.is_only_att_feat(caption_model):
        sample_seq, sample_seqLogprobs = model.sample(
            att_feats, {'sample_max': 0})
        greedy_seq, _ = model.sample(att_feats, {'sample_max': 1})
    elif use_bu:
        sample_seq, sample_seqLogprobs = model.sample(
            fc_feats, att_feats, bu_feats, {'sample_max': 0})
        greedy_seq, _ = model.sample(
            fc_feats, att_feats, bu_feats, {'sample_max': 1})
    else:
        sample_output = model.sample(fc_feats, att_feats, {'sample_max': 0})
        greedy_output = model.sample(fc_feats, att_feats, {'sample_max': 1})
        # Only the leading (seq, logprobs) entries are used here.
        sample_seq = sample_output[0]
        sample_seqLogprobs = sample_output[1]
        greedy_seq = greedy_output[0]

    if opt.verbose:
        print('model {:.3f}'.format(time.time() - start))

    # compute the loss
    start = time.time()
    loss, reward_mean, sample_mean, greedy_mean = crit_rl(
        sample_seq, sample_seqLogprobs, greedy_seq, gts, masks)
    if opt.verbose:
        print('crit {:.3f}'.format(time.time() - start))

    # backward
    start = time.time()
    loss.backward()
    if opt.verbose:
        print('loss {:.3f}'.format(time.time() - start))

    train_loss = loss.data[0]

    return train_loss, reward_mean, sample_mean, greedy_mean
def get_batch_worker(split, loader, batch_size):
    """Assemble one batch of images, bottom-up features and captions.

    Takes ``batch_size`` consecutive entries of ``loader.split_ix[split]``
    starting at ``loader.iterators[split]``. It advances that cursor, wraps it
    at the end of the split and then shuffles the split. For each image it
    loads:

    * the image (pre-extracted features from ``loader.lmdb`` when
      ``loader.opt.use_pre_feat``; otherwise the raw file). Raw files are
      resized, randomly cropped and optionally mirrored for ``'train'``.
    * bottom-up features and boxes from ``loader.lmdb_bu``.
    * ``seq_per_img`` captions from ``loader.input_h5['labels']``.

    Args:
        split: split name used to index ``loader.split_ix`` / ``loader.iterators``.
        loader: data loader object providing the attributes read below.
        batch_size: number of images in the batch.

    Returns:
        dict with 'labels', 'masks', 'tokens', 'bus', 'gts', 'bounds',
        'infos' and, when ``loader.use_image``, 'images'. For sub-region
        models, 'images' is a dict with 'images' and 'boxes'. Tensors are
        moved to the GPU and wrapped in ``Variable(requires_grad=False)``.
    """
    seq_per_img = loader.seq_per_img
    split_ix = loader.split_ix[split]

    if loader.use_image:
        if loader.opt.use_pre_feat:
            img_batch = torch.FloatTensor(batch_size, loader.opt.att_feat_size,
                                          loader.opt.pool_size,
                                          loader.opt.pool_size)
        else:
            img_batch = torch.FloatTensor(batch_size, 3, loader.image_size,
                                          loader.image_size)

    bu_batch = np.zeros([batch_size, loader.bu_size, loader.bu_feat_size],
                        dtype='float32')

    label_batch = np.zeros([batch_size * seq_per_img, loader.seq_length + 1],
                           dtype='int')
    mask_batch = np.zeros([batch_size * seq_per_img, loader.seq_length + 1],
                          dtype='float32')
    token_batch = np.zeros([batch_size * seq_per_img, loader.vocab_size + 1],
                           dtype='int')
    boxes_batch = []

    max_index = len(split_ix)
    wrapped = False

    infos = []
    gts = []

    for i in range(batch_size):
        # Advance the per-split cursor; flag a wrap when it passes the end.
        ri = loader.iterators[split]
        ri_next = ri + 1
        if ri_next >= max_index:
            ri_next = 0
            wrapped = True
        loader.iterators[split] = ri_next
        ix = split_ix[ri]

        img_info = loader.info['images'][ix]

        if loader.use_image:
            if loader.opt.use_pre_feat:
                # Pre-extracted features: npz archive with an 'x' entry.
                feat_npz = np.load(io.BytesIO(loader.lmdb.get(str(img_info['id']))))
                img = torch.from_numpy(feat_npz['x'])
            else:
                I = imread(os.path.join(loader.images_root, img_info['file_path']))
                # Grayscale image: replicate the single channel three times.
                if len(I.shape) == 2:
                    I = I[:, :, np.newaxis]
                    I = np.concatenate((I, I, I), axis=2)

                try:
                    if split == 'train':
                        # Random-crop augmentation: resize to a randomly
                        # padded size, then cut an image_size window.
                        span_width = random.randint(0, loader.img_padding_max)
                        Ir = imresize(I, (loader.image_size + span_width,
                                          loader.image_size + span_width))
                        rx = random.randint(0, span_width)
                        ry = random.randint(0, span_width)
                        Ir = Ir[rx: loader.image_size + rx,
                                ry: loader.image_size + ry, :]
                    else:
                        Ir = imresize(I, (loader.image_size, loader.image_size))
                except Exception:
                    print('failed resizing image %s - see http://git.io/vBIE0' % (img_info['file_path']))
                    raise

                # Divide by 255 (presumably uint8 pixels) and swap axes
                # from (H, W, C) to (C, H, W).
                Ir = Ir.astype('float32') / 255.0
                Ir = Ir.transpose(2, 0, 1)

                # Random horizontal flip (axis 2 is width) during training.
                if split == 'train' and loader.opt.use_mirror and random.randint(0, 99) >= 50:
                    Ir = np.flip(Ir, 2).copy()

                img = torch.from_numpy(Ir)
                img = loader.preprocess(img)

            img_batch[i] = img

        # Bottom-up features and boxes for this image.
        if 'image_id' in img_info:
            image_id = str(img_info['image_id']) + ".jpg"
        else:
            image_id = str(img_info["id"])
        bu_npz = np.load(io.BytesIO(loader.lmdb_bu.get(image_id)))
        data_x = bu_npz['x'].tolist()
        bu_batch[i] = data_x['features']
        boxes = data_x['boxes']

        # NOTE(review): `I` is only assigned on the raw-image path above.
        # With use_pre_feat (or use_image disabled) this raises
        # UnboundLocalError on the first image; the image size needs another
        # source in those modes.
        img_h = float(I.shape[0])
        img_w = float(I.shape[1])

        # Normalise box coordinates by image size (boxes presumably in
        # (x1, y1, x2, y2) pixel order, given the [w, h, w, h] divisor).
        box_size = np.array([img_w, img_h, img_w, img_h])
        boxes_batch.append(np.array(boxes) / box_size)

        # Caption row range for this image, converted to 0-based inclusive.
        ix1 = loader.label_start_ix[ix] - 1
        ix2 = loader.label_end_ix[ix] - 1
        ncap = ix2 - ix1 + 1
        assert ncap > 0,'an image does not have any label. this can be handled but right now isn\'t'

        token = np.zeros([seq_per_img, loader.vocab_size + 1], dtype='int')
        if ncap < seq_per_img:
            # Too few captions: sample every slot independently (with
            # replacement) from the full [ix1, ix2] range. A separate index
            # variable keeps the lower bound fixed across iterations.
            seq = np.zeros([seq_per_img, loader.seq_length], dtype='int')
            for q in range(seq_per_img):
                ixl = random.randint(ix1, ix2)
                seq[q, :] = loader.input_h5['labels'][ixl, :loader.seq_length]
        else:
            # Enough captions: take a random contiguous run of seq_per_img.
            ixl = random.randint(ix1, ix2 - seq_per_img + 1)
            seq = loader.input_h5['labels'][ixl: ixl + seq_per_img, :loader.seq_length]

        # Per-caption indicator of which word ids occur (padding id 0 is
        # also marked whenever the caption is shorter than seq_length).
        for k in range(seq_per_img):
            token[k, seq[k]] = 1

        label_batch[i * seq_per_img: (i + 1) * seq_per_img, :loader.seq_length] = seq
        token_batch[i * seq_per_img: (i + 1) * seq_per_img] = token

        # All reference captions of the image, used for reward evaluation.
        gts.append(loader.input_h5['labels'][loader.label_start_ix[ix] - 1: loader.label_end_ix[ix]])

        info_dict = {'id': img_info['id'], 'file_path': img_info['file_path']}
        if 'image_id' in img_info:
            info_dict['image_id'] = img_info['image_id']
        infos.append(info_dict)

    # Mask length per row: non-zero tokens plus one (presumably the end
    # token). A list comprehension (not map) keeps this correct on Python 3,
    # where map returns an iterator.
    nonzeros = np.array([(row != 0).sum() + 1 for row in label_batch])

    if loader.opt.loss_weight_type == 0:
        # Linearly varying weights from loss_weight_start to loss_weight_stop.
        weight = np.linspace(loader.opt.loss_weight_start, loader.opt.loss_weight_stop, loader.seq_length + 1)
        for row_ix, row in enumerate(mask_batch):
            mask_len = nonzeros[row_ix]
            row[:mask_len] = weight[:mask_len]
    elif loader.opt.loss_weight_type == 1:
        # V-shaped weights: stop->start over the first half, start->stop after.
        # NOTE(review): with an odd seq_length a full-length caption needs
        # half_len + 2 entries from weight1 (which has half_len + 1), so the
        # assignment below fails; confirm seq_length is always even.
        half_len = loader.seq_length // 2
        weight = np.linspace(loader.opt.loss_weight_stop, loader.opt.loss_weight_start, half_len)
        weight1 = np.linspace(loader.opt.loss_weight_start, loader.opt.loss_weight_stop, half_len + 1)
        for row_ix, row in enumerate(mask_batch):
            mask_len = nonzeros[row_ix]
            if mask_len <= half_len:
                row[:mask_len] = weight[:mask_len]
            else:
                row[:half_len] = weight[:half_len]
                row[half_len:mask_len] = weight1[:mask_len - half_len]
    else:
        # Uniform mask over the valid positions.
        for row_ix, row in enumerate(mask_batch):
            mask_len = nonzeros[row_ix]
            row[:mask_len] = 1

    if wrapped:
        random.shuffle(loader.split_ix[split])

    data = {}

    if loader.use_image:
        images = Variable(img_batch.cuda(), requires_grad=False)

    labels = torch.from_numpy(label_batch).cuda()
    labels = Variable(labels, requires_grad=False)

    masks = torch.from_numpy(mask_batch).cuda()
    masks = Variable(masks, requires_grad=False)

    tokens = torch.from_numpy(token_batch).cuda().float()
    tokens = Variable(tokens, requires_grad=False)

    bus = torch.from_numpy(bu_batch).cuda().float()
    bus = Variable(bus, requires_grad=False)

    if loader.use_image:
        if models.has_sub_regions(loader.opt.caption_model) or models.has_sub_region_bu(loader.opt.caption_model):
            data['images'] = {}
            data['images']['images'] = images
            data['images']['boxes'] = boxes_batch
        else:
            data['images'] = images

    data['labels'] = labels
    data['masks'] = masks
    data['tokens'] = tokens
    data['bus'] = bus
    data['gts'] = gts
    data['bounds'] = {'it_pos_now':loader.iterators[split], 'it_max':len(loader.split_ix[split]), 'wrapped': wrapped}
    data['infos'] = infos

    return data
Пример #10
0
def eval_split(model_cnn, model, crit, loader, eval_kwargs=None):
    """Generate captions for one split, optionally computing loss and metrics.

    ``model_cnn`` and ``model`` are switched to eval mode for the run and
    back to train mode before returning.

    Args:
        model_cnn: image encoder.
        model: caption decoder.
        crit: loss criterion for the validation loss; ``None`` disables it.
        loader: data loader (``reset_iterator``, ``get_batch``, ``get_vocab``).
        eval_kwargs: optional dict of settings. The keys read and their
            defaults are listed in the ``.get`` calls below.
            ``val_images_use == -1`` means "evaluate the whole split".

    Returns:
        ``(final_loss, predictions, lang_stats, str_stats)``:

        * ``final_loss`` is 0 unless the validation loss is computed.
        * ``predictions`` is a list of ``{'image_id', 'caption'}`` dicts.
        * ``lang_stats`` and ``str_stats`` are ``None`` unless
          ``language_eval == 1``.
    """
    if eval_kwargs is None:
        eval_kwargs = {}

    verbose_eval = eval_kwargs.get('verbose_eval', True)
    val_images_use = eval_kwargs.get('val_images_use', -1)
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 1)
    beam_size = eval_kwargs.get('beam_size', 1)
    coco_caption_path = eval_kwargs.get('coco_caption_path', 'coco-caption')
    caption_model = eval_kwargs.get('caption_model', '')
    batch_size = eval_kwargs.get('batch_size', 2)
    seq_per_img = eval_kwargs.get('seq_per_img', 5)
    eval_id = eval_kwargs.get('id', '')
    input_anno = eval_kwargs.get('input_anno', '')
    is_compute_val_loss = eval_kwargs.get('is_compute_val_loss', 0)

    # aic caption path
    is_aic_data = eval_kwargs.get('is_aic_data', False)
    aic_caption_path = eval_kwargs.get('aic_caption_path', 'aic-caption')

    # Make sure in the evaluation mode
    model_cnn.eval()
    model.eval()

    if crit is None:
        is_compute_val_loss = 0

    # Expanders are only built for several captions per image. They are
    # always bound so compute_loss can be called when seq_per_img == 1
    # (it does not use them in that case).
    fc_expander = None
    att_expander = None
    bu_expander = None
    if is_compute_val_loss == 1 and seq_per_img > 1:
        fc_expander = utils.FeatExpander(seq_per_img)
        att_expander = utils.FeatExpander(seq_per_img)
        if models.has_bu(caption_model) or models.has_sub_region_bu(
                caption_model):
            bu_expander = utils.FeatExpander(seq_per_img)

    loader.reset_iterator(split)

    n = 0
    loss_sum = 0
    loss_evals = 0
    predictions = []
    vocab = loader.get_vocab()
    while True:

        start = time.time()

        data = loader.get_batch(split, batch_size)
        n = n + batch_size

        images = data['images']
        labels = data['labels']
        masks = data['masks']
        tokens = data['tokens']

        images.volatile = True
        labels.volatile = True
        masks.volatile = True
        tokens.volatile = True

        fc_feats, att_feats, bu_feats = compute_cnn_feats(
            caption_model, model_cnn, images)

        if models.has_bu(caption_model):
            bu_feats = data['bus']
            bu_feats.volatile = True

        if is_compute_val_loss == 1:
            loss = compute_loss(crit, model, caption_model, seq_per_img,
                                fc_expander, att_expander, bu_expander,
                                fc_feats, att_feats, bu_feats, labels, masks,
                                tokens)
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1

        output = compute_output(caption_model, beam_size, model, fc_feats,
                                att_feats, bu_feats)

        seq = output[0]

        # Some models return one sequence per step; keep the last one.
        if isinstance(seq, list):
            seq = seq[-1]
        if is_aic_data:
            sents = utils.decode_sequence_aic(vocab, seq)
        else:
            sents = utils.decode_sequence(vocab, seq)

        for k, sent in enumerate(sents):
            if is_aic_data:
                image_id = data['infos'][k]['image_id']
            else:
                image_id = data['infos'][k]['id']
            entry = {'image_id': image_id, 'caption': sent}
            predictions.append(entry)
            if verbose_eval:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']

        if val_images_use != -1:
            ix1 = min(ix1, val_images_use)
        # Drop predictions beyond the split end / image budget (the last
        # batch may overshoot).
        for _ in range(n - ix1):
            predictions.pop()
        if verbose_eval:
            span_time = time.time() - start
            left_time = (ix1 - ix0) * span_time / batch_size
            s_left_time = utils.format_time(left_time)
            print('evaluating validation preformance... %d/%d %.3fs left:%s' %
                  (ix0 - 1, ix1, span_time, s_left_time))

        if data['bounds']['wrapped']:
            break
        # -1 means "no limit": stop only when the split wraps.
        if val_images_use != -1 and n >= val_images_use:
            break

    lang_stats = None
    str_stats = None
    if lang_eval == 1:
        if is_aic_data:
            lang_stats, str_stats = language_eval_aic(eval_id, predictions,
                                                      aic_caption_path,
                                                      input_anno)
        else:
            lang_stats, str_stats = language_eval(eval_id, predictions,
                                                  coco_caption_path,
                                                  input_anno)

    # Switch back to training mode
    model_cnn.train()
    model.train()

    if is_compute_val_loss == 1:
        final_loss = loss_sum / loss_evals
    else:
        final_loss = 0

    return final_loss, predictions, lang_stats, str_stats
Пример #11
0
def compute_loss(crit, model, caption_model, seq_per_img, fc_expander,
                 att_expander, bu_expander, fc_feats, att_feats, bu_feats,
                 labels, masks, tokens):
    """Compute the captioning loss for one batch, back-propagate, return it.

    Each feature is passed through its expander when ``seq_per_img > 1``.
    The model's positional inputs are chosen from ``caption_model``:

    * fc-only  -> ``(fc, labels)``
    * att-only -> ``(att, labels)``
    * ``"SCST"`` -> ``(fc, att, labels, "train")``
    * otherwise -> ``(fc, att[, bu], labels)``

    For the last case, ``bu`` is included for bottom-up models and, in the
    prob-weight case, also for sub-region bottom-up models.

    Prob-weight models (``is_prob_weight`` / ``is_prob_weight_mul_out``)
    also return ``prob_w``. It is passed to
    ``crit(outputs, labels, masks, prob_w, tokens)``; other models use
    ``crit(outputs, labels, masks)``.

    Returns:
        ``loss.data[0]``.
    """
    def expand(feats, expander):
        # Replicate features once per caption only when there are several.
        if seq_per_img > 1:
            return expander(feats)
        return feats

    if models.is_only_fc_feat(caption_model):
        batch_outputs = model(expand(fc_feats, fc_expander), labels)
    elif models.is_only_att_feat(caption_model):
        batch_outputs = model(expand(att_feats, att_expander), labels)
    elif caption_model == "SCST":
        batch_outputs, _ = model(expand(fc_feats, fc_expander),
                                 expand(att_feats, att_expander),
                                 labels, "train")
    else:
        if models.is_prob_weight(caption_model):
            returns_prob_w = True
            uses_bu = (models.has_sub_region_bu(caption_model)
                       or models.has_bu(caption_model))
        elif models.is_prob_weight_mul_out(caption_model):
            returns_prob_w = True
            uses_bu = models.has_bu(caption_model)
        else:
            returns_prob_w = False
            uses_bu = models.has_bu(caption_model)

        model_inputs = [expand(fc_feats, fc_expander),
                        expand(att_feats, att_expander)]
        if uses_bu:
            model_inputs.append(expand(bu_feats, bu_expander))
        model_inputs.append(labels)

        if returns_prob_w:
            batch_outputs, prob_w = model(*model_inputs)
        else:
            batch_outputs = model(*model_inputs)

    if models.is_prob_weight(caption_model) or models.is_prob_weight_mul_out(
            caption_model):
        loss = crit(batch_outputs, labels, masks, prob_w, tokens)
    else:
        loss = crit(batch_outputs, labels, masks)
    # NOTE(review): eval_split calls this during evaluation. Back-propagating
    # there accumulates gradients into the model parameters (and may fail on
    # volatile inputs); confirm this backward() is intended.
    loss.backward()

    return loss.data[0]
Пример #12
0
def main():
    """Evaluate a trained captioning model once per beam size from 1 to 20.

    Steps:

    * Parse options and create ``opt.eval_result_path`` if it is missing.
    * Load the saved infos and check/merge options with the saved model
      options. Non-ignored options present in both must match; missing
      ones are copied over.
    * Build the CNN and caption models on the GPU in eval mode, and build
      the matching data loader.
    * For each beam size, run ``eval_utils.eval_split`` without a loss
      criterion and store the result with ``save_result``.
    """
    opt = parse_args()

    # make dirs
    result_dir = opt.eval_result_path
    print(result_dir)
    if not os.path.isdir(result_dir):
        os.makedirs(result_dir)

    # Load infos
    infos = load_infos(opt)

    ignore = [
        "id", "input_json", "input_h5", "input_anno", "images_root",
        "coco_caption_path", "batch_size", "beam_size", "start_from_best",
        "eval_result_path"
    ]
    saved_opts = vars(infos['opt'])
    current_opts = vars(opt)
    for key in saved_opts.keys():
        if key in ignore:
            continue
        if key in current_opts:
            assert current_opts[key] == saved_opts[key], key + ' option not consistent'
        else:
            # copy over options from model
            current_opts[key] = saved_opts[key]

    # Setup the models on the GPU, in evaluation mode
    model_cnn = models.setup_cnn(opt)
    model_cnn.cuda()

    model = models.setup(opt)
    model.cuda()

    model_cnn.eval()
    model.eval()

    needs_bu_loader = (models.has_bu(opt.caption_model)
                       or models.has_sub_regions(opt.caption_model)
                       or models.has_sub_region_bu(opt.caption_model))
    if needs_bu_loader:
        loader = DataLoaderThreadBu(opt)
        print("DataLoaderThreadBu")
    else:
        loader = DataLoaderThreadNew(opt)
        print("DataLoaderThreadNew")

    loader.ix_to_word = infos['vocab']

    eval_kwargs = {'split': opt.val_split, 'dataset': opt.input_json}
    eval_kwargs.update(vars(opt))

    # Beam sizes 1 through 20 inclusive.
    for beam_size in range(1, 21):
        opt.beam_size = beam_size
        eval_kwargs.update(vars(opt))
        print("beam_size: " + str(opt.beam_size))
        print("start eval ...")
        val_loss, predictions, lang_stats, str_stats = eval_utils.eval_split(
            model_cnn, model, None, loader, eval_kwargs)
        print("end eval ...")
        print("str_stats = {}".format(str_stats))
        save_result(str(opt.beam_size) + "," + str_stats, predictions, opt)