Example #1
0
def train_cnn(model_cnn, images, bus, fc_expander, att_expander, bu_expander,
              use_reinforce):
    """Run the CNN on ``images`` and collect the features for the caption model.

    The set of features produced depends on the module-level
    ``opt.caption_model`` as classified by the ``models`` predicates. Every
    produced feature is passed through its expander when
    ``opt.seq_per_img > 1`` and ``use_reinforce`` is false.

    Args:
        model_cnn: Callable applied to ``images``. Returns one object for
            fc-only / att-only models, three for ``has_sub_region_bu``
            models and two otherwise.
        images: Input handed to ``model_cnn`` unchanged.
        bus: Externally supplied features used as ``bu_feats`` when
            ``models.has_bu`` is true.
        fc_expander: Callable applied to ``fc_feats`` when expanding.
        att_expander: Callable applied to ``att_feats`` when expanding.
        bu_expander: Callable applied to ``bu_feats`` when expanding.
        use_reinforce: When true, no feature is expanded.

    Returns:
        Tuple ``(fc_feats, att_feats, bu_feats)``; features that were not
        produced are ``None``.
    """
    fc_feats = None
    att_feats = None
    bu_feats = None

    caption_model = opt.caption_model
    expand = opt.seq_per_img > 1 and not use_reinforce

    # train cnn
    if models.is_only_fc_feat(caption_model):
        fc_feats = model_cnn(images)
    elif models.is_only_att_feat(caption_model):
        att_feats = model_cnn(images)
    elif models.has_sub_region_bu(caption_model):
        fc_feats, att_feats, bu_feats = model_cnn(images)
    else:
        fc_feats, att_feats = model_cnn(images)

    if expand:
        if fc_feats is not None:
            fc_feats = fc_expander(fc_feats)
        if att_feats is not None:
            att_feats = att_expander(att_feats)

    if models.has_bu(caption_model):
        # Externally supplied features take precedence over anything the CNN
        # returned. They must be forwarded even when no expansion happens;
        # previously they were dropped in that case.
        bu_feats = bu_expander(bus) if expand else bus
    elif expand and bu_feats is not None:
        bu_feats = bu_expander(bu_feats)

    return fc_feats, att_feats, bu_feats
Example #2
0
def compute_output(caption_model, beam_size, model, fc_feats, att_feats,
                   bu_feats):
    """Call ``model.sample`` with the feature set ``caption_model`` requires.

    Args:
        caption_model: Model identifier passed to the ``models`` predicates.
        beam_size: Forwarded to ``model.sample`` as ``{'beam_size': ...}``.
        model: Object exposing ``sample``.
        fc_feats, att_feats, bu_feats: Candidate feature inputs; only the
            ones selected for ``caption_model`` are passed on.

    Returns:
        Whatever ``model.sample`` returns.
    """
    sample_opts = {'beam_size': beam_size}

    if models.is_only_fc_feat(caption_model):
        return model.sample(fc_feats, sample_opts)

    if models.is_only_att_feat(caption_model):
        return model.sample(att_feats, sample_opts)

    needs_bu = (models.has_bu(caption_model)
                or models.has_sub_region_bu(caption_model)
                or models.is_prob_weight_mul_out(caption_model))
    if needs_bu:
        return model.sample(fc_feats, att_feats, bu_feats, sample_opts)

    return model.sample(fc_feats, att_feats, sample_opts)
def train_normal(params, opt):
    """Run one forward/backward pass using the plain criterion ``crit``.

    Args:
        params: Dict providing ``'model'``, ``'fc_feats'``, ``'att_feats'``,
            ``'labels'``, ``'targets'``, ``'masks'`` and ``'crit'``.
            ``'bu_feats'`` is read only for models that take it (see below).
        opt: Options object; ``caption_model`` and ``verbose`` are read.

    Returns:
        Tuple ``(train_loss, reward_mean)``. ``train_loss`` is the scalar
        loss value and ``reward_mean`` is always ``0``.

    Side effects:
        Calls ``loss.backward()`` and, when ``opt.verbose`` is set, prints
        the time spent in each phase.
    """
    model = params['model']
    fc_feats = params['fc_feats']
    att_feats = params['att_feats']
    labels = params['labels']
    targets = params['targets']
    masks = params['masks']
    crit = params['crit']

    caption_model = opt.caption_model
    is_prob_weight = models.is_prob_weight(caption_model)

    # forward
    start = time.time()
    if models.is_transformer(caption_model):
        output = model(att_feats, targets, masks)
    elif models.is_ctransformer(caption_model):
        output = model(fc_feats, att_feats, targets, masks)
    elif models.is_only_fc_feat(caption_model):
        output = model(fc_feats, labels)
    elif models.is_only_att_feat(caption_model):
        output = model(att_feats, labels)
    elif models.has_bu(caption_model) or (
            is_prob_weight and models.has_sub_region_bu(caption_model)):
        # Same rule as compute_loss: prob-weight sub-region models also take
        # bu_feats in their forward call.
        output = model(fc_feats, att_feats, params['bu_feats'], labels)
    else:
        output = model(fc_feats, att_feats, labels)

    if opt.verbose:
        print('model {:.3f}'.format(time.time() - start))

    # compute the loss
    start = time.time()

    if is_prob_weight:
        # Only the first element of the model output is fed to the criterion.
        output = output[0]

    loss = crit(output, labels, masks)
    if opt.verbose:
        print('crit {:.3f}'.format(time.time() - start))

    # backward
    start = time.time()
    loss.backward()
    if opt.verbose:
        print('loss {:.3f}'.format(time.time() - start))

    # ``loss.data[0]`` fails on 0-dim tensors (PyTorch >= 0.4); use item()
    # when it exists and keep the legacy indexing as a fallback.
    train_loss = loss.item() if hasattr(loss, 'item') else loss.data[0]
    reward_mean = 0

    return train_loss, reward_mean
Example #4
0
def compute_cnn_feats(caption_model, model_cnn, images):
    """Run ``model_cnn`` on ``images`` and unpack its result per model type.

    Args:
        caption_model: Model identifier passed to the ``models`` predicates.
        model_cnn: Callable applied to ``images``.
        images: Input handed to ``model_cnn`` unchanged.

    Returns:
        Tuple ``(fc_feats, att_feats, bu_feats)``; slots the CNN does not
        produce for this model type are ``None``.
    """
    if models.is_only_fc_feat(caption_model):
        return model_cnn(images), None, None

    if models.is_only_att_feat(caption_model):
        return None, model_cnn(images), None

    # Only prob-weight models with sub-region features get a third output;
    # "SCST" is matched before the prob-weight check, so it never does.
    yields_bu = (caption_model != "SCST"
                 and models.is_prob_weight(caption_model)
                 and models.has_sub_region_bu(caption_model))
    if yields_bu:
        fc, att, bu = model_cnn(images)
        return fc, att, bu

    fc, att = model_cnn(images)
    return fc, att, None
def train_reinforce(params, opt):
    """Run one self-critical training step (``opt.reinforce_type == 1``).

    The model is sampled twice, once with ``{'sample_max': 0}`` and once with
    ``{'sample_max': 1}``. ``crit_rl`` turns both results into a loss, which
    is then back-propagated.

    Args:
        params: Dict providing ``'model'``, ``'fc_feats'``, ``'att_feats'``,
            ``'masks'``, ``'crit_rl'`` and ``'gts'``. ``'bu_feats'`` is read
            only when the model is sampled with it.
        opt: Options object; ``reinforce_type``, ``caption_model`` and
            ``verbose`` are read.

    Returns:
        Tuple ``(train_loss, reward_mean, sample_mean, greedy_mean)``.

    Raises:
        Exception: If ``opt.reinforce_type`` is 0 (deprecated).
        ValueError: If ``opt.reinforce_type`` is any other value than 1.
            Previously this fell through to the return statement and failed
            with an unbound-local error.
    """
    if opt.reinforce_type == 0:
        raise Exception('reinforce_type error, 0 is deprecated')
    if opt.reinforce_type != 1:
        raise ValueError(
            'unsupported reinforce_type: {}'.format(opt.reinforce_type))

    model = params['model']
    fc_feats = params['fc_feats']
    att_feats = params['att_feats']
    masks = params['masks']
    crit_rl = params['crit_rl']
    gts = params['gts']

    caption_model = opt.caption_model

    # Pick the positional feature arguments for model.sample.
    if models.is_only_fc_feat(caption_model):
        feats = (fc_feats,)
    elif models.is_only_att_feat(caption_model):
        feats = (att_feats,)
    elif models.has_bu(caption_model) or models.has_sub_region_bu(
            caption_model):
        feats = (fc_feats, att_feats, params['bu_feats'])
    else:
        feats = (fc_feats, att_feats)

    # forward
    start = time.time()
    sample_output = model.sample(*(feats + ({'sample_max': 0},)))
    greedy_output = model.sample(*(feats + ({'sample_max': 1},)))

    # Index instead of unpacking so extra return values are tolerated.
    sample_seq = sample_output[0]
    sample_seqLogprobs = sample_output[1]
    greedy_seq = greedy_output[0]

    if opt.verbose:
        print('model {:.3f}'.format(time.time() - start))

    # compute the loss
    start = time.time()
    loss, reward_mean, sample_mean, greedy_mean = crit_rl(
        sample_seq, sample_seqLogprobs, greedy_seq, gts, masks)
    if opt.verbose:
        print('crit {:.3f}'.format(time.time() - start))

    # backward
    start = time.time()
    loss.backward()
    if opt.verbose:
        print('loss {:.3f}'.format(time.time() - start))

    # ``loss.data[0]`` fails on 0-dim tensors (PyTorch >= 0.4); use item()
    # when it exists and keep the legacy indexing as a fallback.
    train_loss = loss.item() if hasattr(loss, 'item') else loss.data[0]

    return train_loss, reward_mean, sample_mean, greedy_mean
Example #6
0
def eval_split_only(model_cnn, model, crit, loader, eval_kwargs=None):
    """Generate captions for one split and optionally score them.

    Both networks are put in eval mode for the loop and switched back to
    train mode before returning.

    Args:
        model_cnn: Feature extractor, called on ``data['images']``.
        model: Caption model exposing ``sample``.
        crit: Unused; kept for interface compatibility.
        loader: Data loader exposing ``reset_iterator``, ``get_vocab`` and
            ``get_batch``.
        eval_kwargs: Optional dict of settings. Keys read: ``verbose_eval``,
            ``val_images_use`` (-1 means no limit), ``split``,
            ``language_eval``, ``dataset``, ``beam_size``,
            ``coco_caption_path``, ``caption_model`` and ``batch_size``.

    Returns:
        Tuple ``(0, predictions, lang_stats, str_stats)``. ``predictions`` is
        a list of ``{'image_id', 'caption'}`` dicts. ``lang_stats`` and
        ``str_stats`` are ``None`` unless ``language_eval == 1``.
    """
    if eval_kwargs is None:
        eval_kwargs = {}

    verbose_eval = eval_kwargs.get('verbose_eval', True)
    val_images_use = eval_kwargs.get('val_images_use', -1)
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 1)
    dataset = eval_kwargs.get('dataset', 'coco')
    beam_size = eval_kwargs.get('beam_size', 1)
    coco_caption_path = eval_kwargs.get('coco_caption_path', 'coco-caption')
    caption_model = eval_kwargs.get('caption_model', '')
    batch_size = eval_kwargs.get('batch_size', 2)

    # Make sure in the evaluation mode
    model_cnn.eval()
    model.eval()

    loader.reset_iterator(split)

    n = 0
    predictions = []
    vocab = loader.get_vocab()
    while True:
        data = loader.get_batch(split, batch_size)
        n = n + batch_size

        images = data['images']
        sample_opts = {'beam_size': beam_size}

        if models.is_only_fc_feat(caption_model):
            fc_feats = model_cnn(images)
            seq, _ = model.sample(fc_feats, sample_opts)
        elif models.is_only_att_feat(caption_model):
            att_feats = model_cnn(images)
            seq, _ = model.sample(att_feats, sample_opts)
        else:
            fc_feats, att_feats = model_cnn(images)
            seq, _ = model.sample(fc_feats, att_feats, sample_opts)

        sents = utils.decode_sequence(vocab, seq)

        for k, sent in enumerate(sents):
            entry = {'image_id': data['infos'][k]['id'], 'caption': sent}
            predictions.append(entry)
            if verbose_eval:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']

        if val_images_use != -1:
            ix1 = min(ix1, val_images_use)
        # Drop predictions made beyond the number of images to evaluate.
        for _ in range(n - ix1):
            predictions.pop()
        if verbose_eval:
            print('evaluating validation preformance... %d/%d' %
                  (ix0 - 1, ix1))

        if data['bounds']['wrapped']:
            break
        # Apply the image limit only when one is set. With the default -1
        # the bare ``n >= val_images_use`` test was always true and stopped
        # evaluation after the first batch.
        if val_images_use >= 0 and n >= val_images_use:
            break

    # Defined up front so the return below works when scoring is disabled.
    lang_stats = None
    str_stats = None
    if lang_eval == 1:
        lang_stats, str_stats = language_eval(dataset, predictions,
                                              coco_caption_path)

    # Switch back to training mode
    model_cnn.train()
    model.train()

    return 0, predictions, lang_stats, str_stats
Example #7
0
def compute_loss(crit, model, caption_model, seq_per_img, fc_expander,
                 att_expander, bu_expander, fc_feats, att_feats, bu_feats,
                 labels, masks, tokens):
    """Run the caption model forward, compute the loss and back-propagate.

    The features needed by ``caption_model`` (as classified by the ``models``
    predicates) are passed through their expander when ``seq_per_img > 1``
    and then handed to ``model`` together with ``labels``.

    Args:
        crit: Criterion. Called as ``crit(outputs, labels, masks)``, or as
            ``crit(outputs, labels, masks, prob_w, tokens)`` for prob-weight
            and prob-weight-mul-out models.
        model: Caption model, called with the selected features.
        caption_model: Model identifier passed to the ``models`` predicates.
        seq_per_img: Features are expanded only when this is greater than 1.
        fc_expander, att_expander, bu_expander: Callables applied to the
            matching features when expanding.
        fc_feats, att_feats, bu_feats: Candidate feature inputs.
        labels, masks: Forwarded to ``model`` / ``crit``.
        tokens: Forwarded to ``crit`` for prob-weight style models only.

    Returns:
        The scalar loss value.

    Side effects:
        Calls ``loss.backward()``.
    """

    def expand(expander, feats):
        return expander(feats) if seq_per_img > 1 else feats

    if models.is_only_fc_feat(caption_model):
        batch_outputs = model(expand(fc_expander, fc_feats), labels)
    elif models.is_only_att_feat(caption_model):
        batch_outputs = model(expand(att_expander, att_feats), labels)
    else:
        fc_feats_ext = expand(fc_expander, fc_feats)
        att_feats_ext = expand(att_expander, att_feats)

        if caption_model == "SCST":
            batch_outputs, _ = model(fc_feats_ext, att_feats_ext, labels,
                                     "train")
        else:
            is_prob_weight = models.is_prob_weight(caption_model)
            # These model families return (outputs, prob_w) from forward.
            returns_prob_w = (is_prob_weight
                              or models.is_prob_weight_mul_out(caption_model))
            # Sub-region features only count for prob-weight models; every
            # family passes bu_feats when has_bu is true.
            uses_bu = ((is_prob_weight
                        and models.has_sub_region_bu(caption_model))
                       or models.has_bu(caption_model))

            if uses_bu:
                result = model(fc_feats_ext, att_feats_ext,
                               expand(bu_expander, bu_feats), labels)
            else:
                result = model(fc_feats_ext, att_feats_ext, labels)

            if returns_prob_w:
                batch_outputs, prob_w = result
            else:
                batch_outputs = result

    # NOTE: prob_w is only bound by the generic branch above. This assumes
    # prob-weight models are never fc-only, att-only or "SCST".
    if models.is_prob_weight(caption_model) or models.is_prob_weight_mul_out(
            caption_model):
        loss = crit(batch_outputs, labels, masks, prob_w, tokens)
    else:
        loss = crit(batch_outputs, labels, masks)
    loss.backward()

    # ``loss.data[0]`` fails on 0-dim tensors (PyTorch >= 0.4); use item()
    # when it exists and keep the legacy indexing as a fallback.
    return loss.item() if hasattr(loss, 'item') else loss.data[0]