Beispiel #1
0
def compute_output(caption_model, beam_size, model, fc_feats, att_feats,
                   bu_feats):
    """Sample captions from ``model`` using beam search of width ``beam_size``.

    The positional features handed to ``model.sample`` depend on the
    ``caption_model`` family:

    * fc-only models receive ``fc_feats``;
    * att-only models receive ``att_feats``;
    * bottom-up, sub-region bottom-up and prob-weight mul-out models receive
      ``fc_feats, att_feats, bu_feats``;
    * every other model receives ``fc_feats, att_feats``.

    Returns:
        Whatever ``model.sample`` returns.
    """
    sample_opts = {'beam_size': beam_size}

    if models.is_only_fc_feat(caption_model):
        return model.sample(fc_feats, sample_opts)
    if models.is_only_att_feat(caption_model):
        return model.sample(att_feats, sample_opts)

    # NOTE(review): compute_cnn_feats never fills bu_feats for prob-weight
    # mul-out models, so bu_feats may be None here for them — confirm that
    # model.sample tolerates that.
    wants_bu = (models.has_bu(caption_model)
                or models.has_sub_region_bu(caption_model)
                or models.is_prob_weight_mul_out(caption_model))
    if wants_bu:
        feats = (fc_feats, att_feats, bu_feats)
    else:
        feats = (fc_feats, att_feats)
    return model.sample(*feats, sample_opts)
Beispiel #2
0
def get_criterion():
    """Build the language-model training criterion for ``opt.caption_model``.

    Model-family predicates are checked in a fixed priority order; the first
    match wins. Weighted variants are constructed with
    ``opt.prob_weight_alpha``. Falls back to
    ``Criterion.LanguageModelWeightCriterion`` when nothing matches.
    """
    caption_model = opt.caption_model

    if models.is_mul_out_with_weight(caption_model):
        return Criterion.LanguageModelWeightMulOutWithWeightCriterion(
            opt.prob_weight_alpha)
    if models.is_mul_out(caption_model):
        return Criterion.LanguageModelWeightMulOutCriterion()
    if models.is_prob_weight(caption_model):
        return Criterion.LanguageModelWithProbWeightCriterion(
            opt.prob_weight_alpha)
    if models.is_prob_weight_mul_out(caption_model):
        return Criterion.LanguageModelWithProbWeightMulOutCriterion(
            opt.prob_weight_alpha)

    return Criterion.LanguageModelWeightCriterion()
Beispiel #3
0
def compute_cnn_feats(caption_model, model_cnn, images):
    """Run ``model_cnn`` on ``images`` and unpack its output by model family.

    Args:
        caption_model: name of the captioning model; selects how many outputs
            ``model_cnn`` is expected to produce.
        model_cnn: callable feature extractor.
        images: input batch passed straight to ``model_cnn``.

    Returns:
        Tuple ``(fc_feats, att_feats, bu_feats)``; entries the model family
        does not use are ``None``.
    """
    if models.is_only_fc_feat(caption_model):
        return model_cnn(images), None, None
    if models.is_only_att_feat(caption_model):
        return None, model_cnn(images), None

    # Only non-SCST prob-weight models with sub-region bottom-up features
    # get a third (bu) output from the CNN. SCST, prob-weight mul-out and all
    # remaining families share the same two-output path below.
    if (caption_model != "SCST"
            and models.is_prob_weight(caption_model)
            and models.has_sub_region_bu(caption_model)):
        fc_feats, att_feats, bu_feats = model_cnn(images)
        return fc_feats, att_feats, bu_feats

    fc_feats, att_feats = model_cnn(images)
    return fc_feats, att_feats, None
Beispiel #4
0
def _expand_feats(feats, expander, seq_per_img):
    """Return ``expander(feats)`` when ``seq_per_img > 1``, else ``feats`` unchanged."""
    if seq_per_img > 1:
        return expander(feats)
    return feats


def compute_loss(crit, model, caption_model, seq_per_img, fc_expander,
                 att_expander, bu_expander, fc_feats, att_feats, bu_feats,
                 labels, masks, tokens):
    """Run a forward and backward pass for one batch and return the loss value.

    Features are passed through their matching ``*_expander`` when
    ``seq_per_img > 1`` (presumably to repeat per-image features for each
    caption — confirm against the expander implementation). Model inputs by
    ``caption_model`` family:

    * fc-only: ``model(fc, labels)``
    * att-only: ``model(att, labels)``
    * SCST: ``model(fc, att, labels, "train")``; the second output is ignored
    * prob-weight: ``model(fc, att[, bu], labels)`` -> ``(outputs, prob_w)``,
      with bu when the model has sub-region or plain bottom-up features
    * prob-weight mul-out: same as prob-weight, bu only when ``has_bu``
    * otherwise: ``model(fc, att[, bu], labels)``, bu only when ``has_bu``

    Prob-weight families call ``crit(outputs, labels, masks, prob_w, tokens)``;
    all others call ``crit(outputs, labels, masks)``. ``loss.backward()`` is
    invoked before returning.

    Returns:
        The scalar loss as a Python number.
    """
    if models.is_only_fc_feat(caption_model):
        batch_outputs = model(
            _expand_feats(fc_feats, fc_expander, seq_per_img), labels)
    elif models.is_only_att_feat(caption_model):
        batch_outputs = model(
            _expand_feats(att_feats, att_expander, seq_per_img), labels)
    else:
        fc_feats_ext = _expand_feats(fc_feats, fc_expander, seq_per_img)
        att_feats_ext = _expand_feats(att_feats, att_expander, seq_per_img)

        if caption_model == "SCST":
            batch_outputs, _ = model(fc_feats_ext, att_feats_ext, labels,
                                     "train")
        else:
            if models.is_prob_weight(caption_model):
                returns_prob_w = True
                use_bu = (models.has_sub_region_bu(caption_model)
                          or models.has_bu(caption_model))
            elif models.is_prob_weight_mul_out(caption_model):
                returns_prob_w = True
                use_bu = models.has_bu(caption_model)
            else:
                returns_prob_w = False
                use_bu = models.has_bu(caption_model)

            inputs = [fc_feats_ext, att_feats_ext]
            if use_bu:
                inputs.append(_expand_feats(bu_feats, bu_expander, seq_per_img))
            inputs.append(labels)

            if returns_prob_w:
                batch_outputs, prob_w = model(*inputs)
            else:
                batch_outputs = model(*inputs)

    if models.is_prob_weight(caption_model) or models.is_prob_weight_mul_out(
            caption_model):
        loss = crit(batch_outputs, labels, masks, prob_w, tokens)
    else:
        loss = crit(batch_outputs, labels, masks)
    loss.backward()

    # From PyTorch 0.4 losses are 0-dim tensors and ``loss.data[0]`` raises
    # IndexError; ``item()`` is the supported accessor. Releases without
    # ``item`` keep the original ``data[0]`` path.
    if hasattr(loss, 'item'):
        return loss.item()
    return loss.data[0]
Beispiel #5
0
def _log_train_scalars(board, iteration, scalars):
    """Write ``(name, value)`` pairs to ``board`` when tensorboard logging is due."""
    if opt.use_tensorboard:
        if iteration % opt.tensorboard_for_train_every == 0:
            for name, value in scalars:
                board.val(name, value, iteration)


def train_model(params, iteration, epoch, board):
    """Run one training step, selecting the training regime from ``epoch``.

    Regimes are tried in this order, each enabled when its ``opt.*_start``
    threshold is non-negative and ``epoch`` has reached it:

    1. reinforce (``opt.reinforce_start``)
    2. mix (``opt.mix_start``) — deprecated, always raises
    3. CTC (``opt.ctc_start``)
    4. critic only (``opt.rl_critic_start``)
    5. actor-critic (``opt.rl_actor_critic_start``)

    Otherwise the language-model criterion from ``get_criterion`` is used.
    Criteria are created lazily and cached in ``params`` under the
    ``crit*`` keys.

    Args:
        params: mutable dict holding ``'vocab'`` and the ``crit*`` slots,
            also forwarded to the ``train_utils`` routines.
        iteration: global iteration counter, used for tensorboard logging.
        epoch: current epoch, compared against the ``opt.*_start`` thresholds.
        board: tensorboard writer exposing ``val(name, value, iteration)``.

    Returns:
        Tuple ``(train_loss, reward_mean, use_reinforce)``.

    Raises:
        Exception: if the deprecated mix regime is selected.
    """
    vocab = params['vocab']

    if opt.reinforce_start >= 0 and epoch >= opt.reinforce_start:
        use_reinforce = True

        if params['crit_pg'] is None:
            params['crit_pg'] = Criterion.PGCriterion(opt)
        if params['crit_rl'] is None:
            if opt.is_aic_data:
                if models.is_prob_weight_mul_out(opt.caption_model):
                    crit_rl = Criterion.RewardMulOutCriterionAIC(opt, vocab)
                else:
                    crit_rl = Criterion.RewardCriterionAIC(opt, vocab)
            else:
                crit_rl = Criterion.RewardCriterion(opt)
            params['crit_rl'] = crit_rl

        train_loss, reward_mean, sample_mean, greedy_mean = train_utils.train_reinforce(
            params, opt)
        _log_train_scalars(board, iteration, [("sample_mean", sample_mean),
                                              ("greedy_mean", greedy_mean)])

    elif opt.mix_start >= 0 and epoch >= opt.mix_start:
        # Mixed training is no longer supported; the code that used to follow
        # this raise was unreachable and has been removed.
        raise Exception('mix is deprecated')

    elif opt.ctc_start >= 0 and epoch >= opt.ctc_start:
        use_reinforce = False
        if params['crit_ctc'] is None:
            params['crit_ctc'] = Criterion.CTCCriterion()
        train_loss, reward_mean = train_utils.train_normal(params, opt)

    elif opt.rl_critic_start >= 0 and epoch >= opt.rl_critic_start:
        use_reinforce = True
        if params['crit_c'] is None:
            params['crit_c'] = Criterion.ActorCriticMSECriterionAIC(opt, vocab)
        train_loss, reward_mean, sample_mean = train_utils.train_actor_critic(
            params, opt, 0)
        _log_train_scalars(board, iteration, [("reward_mean", reward_mean),
                                              ("sample_mean", sample_mean)])

    elif opt.rl_actor_critic_start >= 0 and epoch >= opt.rl_actor_critic_start:
        use_reinforce = True
        if params['crit_c'] is None:
            params['crit_c'] = Criterion.ActorCriticMSECriterionAIC(opt, vocab)
        if params['crit_ac'] is None:
            params['crit_ac'] = Criterion.ActorCriticCriterionAIC(opt, vocab)

        # Critic pass first; retain_graph=True keeps the graph alive for the
        # actor-critic pass that follows.
        train_loss1, reward_mean1, sample_mean1 = train_utils.train_actor_critic(
            params, opt, 0, retain_graph=True)
        print("critic loss: {:.3f} reward_mean: {:.3f}  sample_mean: {:.3f} ".
              format(train_loss1, reward_mean1, sample_mean1))

        train_loss, reward_mean, sample_mean = train_utils.train_actor_critic(
            params, opt, 1)
        print(
            "actor critic loss: {:.3f} reward_mean: {:.3f}  sample_mean: {:.3f} "
            .format(train_loss, reward_mean, sample_mean))
        _log_train_scalars(board, iteration, [("reward_mean", reward_mean),
                                              ("sample_mean", sample_mean)])

    else:
        use_reinforce = False
        if params['crit'] is None:
            params['crit'] = get_criterion()
        if models.is_prob_weight(
                opt.caption_model) or models.is_prob_weight_mul_out(
                    opt.caption_model):
            train_loss, reward_mean = train_utils.train_with_prob_weight(
                params, opt)
        else:
            train_loss, reward_mean = train_utils.train_normal(params, opt)

    return train_loss, reward_mean, use_reinforce