Example #1
0
def train_cnn(model_cnn, images, bus, fc_expander, att_expander, bu_expander,
              use_reinforce):
    """Forward `images` through the CNN and expand features per caption.

    Depending on opt.caption_model the CNN yields fc and/or attention
    (and possibly bottom-up sub-region) features.  When each image has
    several captions (opt.seq_per_img > 1) and we are not in reinforce
    mode, each feature tensor is replicated by its matching expander so
    the feature batch lines up with the caption batch.

    Returns:
        (fc_feats, att_feats, bu_feats); entries not produced by the
        selected caption model remain None.
    """
    fc_feats = None
    att_feats = None
    bu_feats = None

    # train cnn
    if models.is_only_fc_feat(opt.caption_model):
        fc_feats = model_cnn(images)
        if opt.seq_per_img > 1 and not use_reinforce:
            fc_feats = fc_expander(fc_feats)
    elif models.is_only_att_feat(opt.caption_model):
        att_feats = model_cnn(images)
        if opt.seq_per_img > 1 and not use_reinforce:
            att_feats = att_expander(att_feats)
    elif models.has_sub_region_bu(opt.caption_model):
        fc_feats, att_feats, bu_feats = model_cnn(images)
        if opt.seq_per_img > 1 and not use_reinforce:
            fc_feats = fc_expander(fc_feats)
            att_feats = att_expander(att_feats)
            bu_feats = bu_expander(bu_feats)
    else:
        fc_feats, att_feats = model_cnn(images)
        if opt.seq_per_img > 1 and not use_reinforce:
            fc_feats = fc_expander(fc_feats)
            att_feats = att_expander(att_feats)

    if models.has_bu(opt.caption_model):
        # BUG FIX: previously the un-expanded `bus` tensor was stored in
        # a throwaway local (`bus_feats`) and `bu_feats` was assigned
        # only inside the expander branch, so for seq_per_img == 1 or
        # reinforce mode bu_feats stayed None and `bus` was dropped.
        bu_feats = bus
        if opt.seq_per_img > 1 and not use_reinforce:
            bu_feats = bu_expander(bu_feats)

    return fc_feats, att_feats, bu_feats
Example #2
0
def get_loader():
    """Build the data loader matching the configured caption model.

    Models that consume bottom-up / sub-region features need the
    bottom-up-aware loader; everything else uses the plain one.
    """
    needs_bu_loader = (models.has_bu(opt.caption_model)
                       or models.has_sub_regions(opt.caption_model)
                       or models.has_sub_region_bu(opt.caption_model))
    if needs_bu_loader:
        loader = DataLoaderThreadBu(opt)
        print("DataLoaderThreadBu")
    else:
        loader = DataLoaderThreadNew(opt)
        print("DataLoaderThreadNew")
    return loader
Example #3
0
def get_expander():
    """Create feature expanders replicating features seq_per_img times.

    Returns:
        (fc_expander, att_expander, bu_expander); all None when
        opt.seq_per_img <= 1, and bu_expander is only built for models
        that use bottom-up features.
    """
    fc_expander, att_expander, bu_expander = None, None, None
    if opt.seq_per_img > 1:
        make_expander = utils.FeatExpander
        fc_expander = make_expander(opt.seq_per_img)
        att_expander = make_expander(opt.seq_per_img)
        uses_bu = (models.has_bu(opt.caption_model)
                   or models.has_sub_region_bu(opt.caption_model))
        if uses_bu:
            bu_expander = make_expander(opt.seq_per_img)
    return fc_expander, att_expander, bu_expander
Example #4
0
def compute_output(caption_model, beam_size, model, fc_feats, att_feats,
                   bu_feats):
    """Beam-search sample from `model` with the features it expects.

    The caption-model family decides which feature tensors are passed
    to model.sample(); the beam size is forwarded via the options dict.
    """
    if models.is_only_fc_feat(caption_model):
        return model.sample(fc_feats, {'beam_size': beam_size})
    if models.is_only_att_feat(caption_model):
        return model.sample(att_feats, {'beam_size': beam_size})
    uses_bu = (models.has_bu(caption_model)
               or models.has_sub_region_bu(caption_model)
               or models.is_prob_weight_mul_out(caption_model))
    if uses_bu:
        return model.sample(fc_feats, att_feats, bu_feats,
                            {'beam_size': beam_size})
    return model.sample(fc_feats, att_feats, {'beam_size': beam_size})
def train_normal(params, opt):
    """Run one standard cross-entropy training step.

    Forwards the batch through the caption model (selecting the call
    signature required by opt.caption_model), computes the criterion
    loss and backpropagates it.  Gradients accumulate on the model; the
    optimizer step is the caller's responsibility.

    Args:
        params: dict with 'model', 'fc_feats', 'att_feats', 'labels',
            'targets', 'masks', 'vocab', 'crit' and, for bottom-up
            models, 'bu_feats'.
        opt: options namespace (caption_model, verbose, ...).

    Returns:
        (train_loss, reward_mean); reward_mean is always 0 here since
        no reinforcement reward is involved.
    """

    model = params['model']
    fc_feats = params['fc_feats']
    att_feats = params['att_feats']
    labels = params['labels']
    targets = params['targets']
    masks = params['masks']
    vocab = params['vocab']
    crit = params['crit']

    # forward
    start = time.time()
    if models.is_transformer(opt.caption_model):
        # Transformer variants consume targets and masks directly.
        output = model(att_feats, targets, masks)
    elif models.is_ctransformer(opt.caption_model):
        output = model(fc_feats, att_feats, targets, masks)
    elif models.is_only_fc_feat(opt.caption_model):
        output = model(fc_feats, labels)
    elif models.is_only_att_feat(opt.caption_model):
        output = model(att_feats, labels)
    elif models.has_bu(opt.caption_model):
        bu_feats = params['bu_feats']
        output = model(fc_feats, att_feats, bu_feats, labels)
    else:
        output = model(fc_feats, att_feats, labels)

    if opt.verbose:
        print('model {:.3f}'.format(time.time() - start))

    # compute the loss
    start = time.time()

    # Prob-weight models return (output, prob_weight); only the first
    # element feeds the plain criterion in this code path.
    if models.is_prob_weight(opt.caption_model):
        output = output[0]

    loss = crit(output, labels, masks)
    if opt.verbose:
        print('crit {:.3f}'.format(time.time() - start))

    # backward
    start = time.time()
    loss.backward()
    if opt.verbose:
        print('loss {:.3f}'.format(time.time() - start))

    # show information
    # NOTE(review): loss.data[0] is pre-0.4 PyTorch style; newer
    # versions would use loss.item().
    train_loss = loss.data[0]
    reward_mean = 0

    return train_loss, reward_mean
def train_with_prob_weight(params, opt):
    """One training step for models that also predict probability weights.

    Same flow as train_normal (forward, criterion, backward) except the
    model returns (output, prob_w) and the criterion additionally
    consumes prob_w and the ground-truth tokens.

    Args:
        params: dict with 'model', 'fc_feats', 'att_feats', 'labels',
            'targets', 'masks', 'tokens', 'crit' and, for bottom-up
            models, 'bu_feats'.
        opt: options namespace (caption_model, verbose, ...).

    Returns:
        (train_loss, reward_mean); reward_mean is always 0 here.
    """

    model = params['model']
    fc_feats = params['fc_feats']
    att_feats = params['att_feats']
    labels = params['labels']
    targets = params['targets']
    masks = params['masks']
    tokens = params['tokens']
    crit = params['crit']

    # forward
    start = time.time()

    if models.is_transformer(opt.caption_model):
        output, prob_w = model(att_feats, targets, masks)
    elif models.is_ctransformer(opt.caption_model):
        output, prob_w = model(fc_feats, att_feats, targets, masks)
    elif models.has_bu(opt.caption_model) or models.has_sub_region_bu(
            opt.caption_model):
        # NOTE(review): this branch feeds `labels` while the transformer
        # branches feed `targets` -- confirm the asymmetry is intended.
        bu_feats = params['bu_feats']
        output, prob_w = model(fc_feats, att_feats, bu_feats, labels)
    else:
        output, prob_w = model(fc_feats, att_feats, labels)

    if opt.verbose:
        print('model {:.3f}'.format(time.time() - start))

    # compute the loss
    start = time.time()
    # input, target, mask, prob_w, token, alpha)
    loss = crit(output, labels, masks, prob_w, tokens)
    if opt.verbose:
        print('crit {:.3f}'.format(time.time() - start))

    # backward
    start = time.time()
    loss.backward()
    if opt.verbose:
        print('loss {:.3f}'.format(time.time() - start))

    # show information
    # NOTE(review): loss.data[0] is pre-0.4 PyTorch (loss.item() today).
    train_loss = loss.data[0]
    reward_mean = 0

    return train_loss, reward_mean
def train_actor_critic(params, opt, type, retain_graph=False):
    """One actor-critic training step.

    Samples a caption (sample_max=0) together with per-step value
    estimates, then trains either the critic alone (type == 0) or actor
    and critic jointly (type == 1).

    Args:
        params: dict with 'model', features, 'labels'/'masks', 'vocab',
            ground truths 'gts', and the criterion ('crit_c' for
            type 0, 'crit_ac' for type 1).
        opt: options namespace.
        type: 0 = critic only, 1 = actor + critic.  NOTE(review): the
            name shadows the `type` builtin, and any other value leaves
            the criterion and `loss` unbound (NameError at the crit /
            backward lines) -- confirm callers only pass 0 or 1.
        retain_graph: forwarded to loss.backward(); pass True when the
            same graph is backpropagated again by the caller.

    Returns:
        (train_loss, reward_mean, sample_mean).
    """

    model = params['model']
    fc_feats = params['fc_feats']
    att_feats = params['att_feats']
    labels = params['labels']
    masks = params['masks']
    vocab = params['vocab']
    gts = params['gts']

    if type == 0:
        crit_c = params['crit_c']
    elif type == 1:
        crit_ac = params['crit_ac']

    if models.has_bu(opt.caption_model) or models.has_sub_region_bu(
            opt.caption_model):
        bu_feats = params['bu_feats']

    # forward
    start = time.time()
    if models.is_only_fc_feat(opt.caption_model):
        sample_seq, sample_seqLogprobs, sample_value = model.sample(
            fc_feats, {'sample_max': 0})
    elif models.has_bu(opt.caption_model) or models.has_sub_region_bu(
            opt.caption_model):
        sample_seq, sample_seqLogprobs, sample_value = model.sample(
            fc_feats, att_feats, bu_feats, {'sample_max': 0})
    else:
        # sample_seq, sample_seqLogprobs = model.sample_forward(fc_feats, att_feats, labels, {'sample_max': 0})
        # greedy_seq, greedy_seqLogprobs = model.sample_forward(fc_feats, att_feats, labels, {'sample_max': 1})
        sample_output = model.sample(fc_feats, att_feats, {'sample_max': 0})

        sample_seq = sample_output[0]
        sample_seqLogprobs = sample_output[1]
        sample_value = sample_output[2]

    if opt.verbose:
        print('model {:.3f}'.format(time.time() - start))

    # compute the loss
    start = time.time()
    # 0. critic
    # 1. critic, actor
    if type == 0:
        # seq, seqLogprobs, seq1, target, vocab
        loss, reward_mean, sample_mean = crit_c(sample_seq, sample_value, gts)
    elif type == 1:
        # seq, seqLogprobs, seq1, target, vocab
        loss, reward_mean, sample_mean = crit_ac(sample_seq,
                                                 sample_seqLogprobs,
                                                 sample_value, gts)
    # loss, reward_mean = crit_rl(sample_seq, sample_seqLogprobs, gts)
    if opt.verbose:
        print('crit {:.3f}'.format(time.time() - start))

    # backward
    start = time.time()
    loss.backward(retain_graph=retain_graph)
    if opt.verbose:
        print('loss {:.3f}'.format(time.time() - start))

    # show information
    train_loss = loss.data[0]

    return train_loss, reward_mean, sample_mean
def train_reinforce(params, opt):
    """One REINFORCE / self-critical training step.

    For opt.reinforce_type == 1 (self-critical sequence training) a
    sampled caption and a greedy caption are generated; crit_rl turns
    the reward difference into a policy-gradient loss which is then
    backpropagated.  opt.reinforce_type == 0 is deprecated and raises.

    Args:
        params: dict with 'model', features, 'labels'/'masks'/'targets',
            'vocab', ground truths 'gts', criteria 'crit_pg'/'crit_rl'
            and, for bottom-up models, 'bu_feats'.
        opt: options namespace (caption_model, reinforce_type, verbose).

    Returns:
        (train_loss, reward_mean, sample_mean, greedy_mean).
    """

    model = params['model']
    fc_feats = params['fc_feats']
    att_feats = params['att_feats']
    labels = params['labels']
    masks = params['masks']
    vocab = params['vocab']
    crit_pg = params['crit_pg']
    crit_rl = params['crit_rl']
    targets = params['targets']
    gts = params['gts']

    if models.has_bu(opt.caption_model) or models.has_sub_region_bu(
            opt.caption_model):
        bu_feats = params['bu_feats']

    # compute policy gradient
    if opt.reinforce_type == 0:
        raise Exception('reinforce_type error, 0 is deprecated')
        # NOTE(review): everything below in this branch is unreachable
        # dead code kept from the deprecated path.
        # forward
        start = time.time()
        if models.is_only_fc_feat(opt.caption_model):
            output = model(fc_feats, labels)
        else:
            output = model(fc_feats, att_feats, labels)

        if opt.verbose:
            print('model {:.3f}'.format(time.time() - start))

        train_loss, reward_mean = crit_pg.forward_backward(
            output, labels, masks, vocab)
    # self-critical
    elif opt.reinforce_type == 1:
        # forward: draw one multinomial sample and one greedy sample
        # with the feature signature the caption model expects.
        start = time.time()
        if models.is_only_fc_feat(opt.caption_model):
            sample_seq, sample_seqLogprobs = model.sample(
                fc_feats, {'sample_max': 0})
            greedy_seq, greedy_seqLogprobs = model.sample(
                fc_feats, {'sample_max': 1})
        elif models.is_only_att_feat(opt.caption_model):
            sample_seq, sample_seqLogprobs = model.sample(
                att_feats, {'sample_max': 0})
            greedy_seq, greedy_seqLogprobs = model.sample(
                att_feats, {'sample_max': 1})
        elif models.has_bu(opt.caption_model) or models.has_sub_region_bu(
                opt.caption_model):
            sample_seq, sample_seqLogprobs = model.sample(
                fc_feats, att_feats, bu_feats, {'sample_max': 0})
            greedy_seq, greedy_seqLogprobs = model.sample(
                fc_feats, att_feats, bu_feats, {'sample_max': 1})
        else:
            # sample_seq, sample_seqLogprobs = model.sample_forward(fc_feats, att_feats, labels, {'sample_max': 0})
            # greedy_seq, greedy_seqLogprobs = model.sample_forward(fc_feats, att_feats, labels, {'sample_max': 1})
            sample_output = model.sample(fc_feats, att_feats,
                                         {'sample_max': 0})
            greedy_output = model.sample(fc_feats, att_feats,
                                         {'sample_max': 1})

            sample_seq = sample_output[0]
            sample_seqLogprobs = sample_output[1]

            greedy_seq = greedy_output[0]
            greedy_seqLogprobs = greedy_output[1]

        if opt.verbose:
            print('model {:.3f}'.format(time.time() - start))

        # compute the loss
        start = time.time()
        # seq, seqLogprobs, seq1, target, vocab
        loss, reward_mean, sample_mean, greedy_mean = crit_rl(
            sample_seq, sample_seqLogprobs, greedy_seq, gts, masks)
        # loss, reward_mean = crit_rl(sample_seq, sample_seqLogprobs, gts)
        if opt.verbose:
            print('crit {:.3f}'.format(time.time() - start))

        # backward
        start = time.time()
        loss.backward()
        if opt.verbose:
            print('loss {:.3f}'.format(time.time() - start))

        # show information
        train_loss = loss.data[0]

    # NOTE(review): if opt.reinforce_type is neither 0 nor 1 the names
    # below are unbound and this return raises NameError.
    return train_loss, reward_mean, sample_mean, greedy_mean
Example #9
0
def eval_split(model_cnn, model, crit, loader, eval_kwargs=None):
    """Evaluate a caption model on one data split.

    Runs the CNN and caption model over the split, optionally computes a
    validation loss, decodes predicted captions and (optionally) runs
    the language-metric evaluation.  Switches both models to eval mode
    for the duration and back to train mode before returning.

    Args:
        model_cnn: CNN feature extractor.
        model: caption model.
        crit: criterion for the validation loss; pass None to skip it.
        loader: data loader providing batches for the chosen split.
        eval_kwargs: dict of evaluation options (split, beam_size,
            language_eval, val_images_use, ...); defaults to {}.

    Returns:
        (final_loss, predictions, lang_stats, str_stats).  lang_stats
        and str_stats are None when language evaluation is disabled.
    """
    # Avoid the shared mutable-default pitfall of `eval_kwargs={}`.
    if eval_kwargs is None:
        eval_kwargs = {}

    verbose_eval = eval_kwargs.get('verbose_eval', True)
    val_images_use = eval_kwargs.get('val_images_use', -1)
    split = eval_kwargs.get('split', 'val')
    lang_eval = eval_kwargs.get('language_eval', 1)
    beam_size = eval_kwargs.get('beam_size', 1)
    coco_caption_path = eval_kwargs.get('coco_caption_path', 'coco-caption')
    caption_model = eval_kwargs.get('caption_model', '')
    batch_size = eval_kwargs.get('batch_size', 2)
    seq_per_img = eval_kwargs.get('seq_per_img', 5)
    eval_id = eval_kwargs.get('id', '')  # renamed local: 'id' shadowed a builtin
    input_anno = eval_kwargs.get('input_anno', '')
    is_compute_val_loss = eval_kwargs.get('is_compute_val_loss', 0)

    # aic caption path
    is_aic_data = eval_kwargs.get('is_aic_data', False)
    aic_caption_path = eval_kwargs.get('aic_caption_path', 'aic-caption')

    # Make sure in the evaluation mode
    model_cnn.eval()
    model.eval()

    if crit is None:
        is_compute_val_loss = 0

    if is_compute_val_loss == 1 and seq_per_img > 1:
        fc_expander = utils.FeatExpander(seq_per_img)
        att_expander = utils.FeatExpander(seq_per_img)
        bu_expander = None
        if models.has_bu(caption_model) or models.has_sub_region_bu(
                caption_model):
            bu_expander = utils.FeatExpander(seq_per_img)

    loader.reset_iterator(split)

    n = 0
    loss_sum = 0
    loss_evals = 0
    predictions = []
    # BUG FIX: lang_stats/str_stats were previously unbound whenever
    # lang_eval != 1, so the final return raised NameError.
    lang_stats = None
    str_stats = None
    vocab = loader.get_vocab()
    vocab_size = loader.get_vocab_size()
    while True:

        start = time.time()

        data = loader.get_batch(split, batch_size)
        n = n + batch_size

        images = data['images']
        labels = data['labels']
        masks = data['masks']
        tokens = data['tokens']

        # NOTE(review): `.volatile` is pre-0.4 PyTorch inference mode;
        # modern code would use torch.no_grad().
        images.volatile = True
        labels.volatile = True
        masks.volatile = True
        tokens.volatile = True

        fc_feats, att_feats, bu_feats = compute_cnn_feats(
            caption_model, model_cnn, images)

        if models.has_bu(caption_model):
            bu_feats = data['bus']
            bu_feats.volatile = True

        if is_compute_val_loss == 1:
            loss = compute_loss(crit, model, caption_model, seq_per_img,
                                fc_expander, att_expander, bu_expander,
                                fc_feats, att_feats, bu_feats, labels, masks,
                                tokens)
            loss_sum = loss_sum + loss
            loss_evals = loss_evals + 1
        else:
            loss = 0

        output = compute_output(caption_model, beam_size, model, fc_feats,
                                att_feats, bu_feats)

        seq = output[0]

        # Beam search may return a list of per-step sequences; keep the
        # final one.
        if isinstance(seq, list):
            seq = seq[-1]
        if is_aic_data:
            sents = utils.decode_sequence_aic(vocab, seq)
        else:
            sents = utils.decode_sequence(vocab, seq)

        for k, sent in enumerate(sents):
            if is_aic_data:
                image_id = data['infos'][k]['image_id']
            else:
                image_id = data['infos'][k]['id']
            entry = {'image_id': image_id, 'caption': sent}
            predictions.append(entry)
            if verbose_eval:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']

        if val_images_use != -1:
            ix1 = min(ix1, val_images_use)
        # Drop predictions past the split/requested boundary (the last
        # batch may overshoot it).
        for i in range(n - ix1):
            predictions.pop()
        if verbose_eval:
            span_time = time.time() - start
            left_time = (ix1 - ix0) * span_time / batch_size
            s_left_time = utils.format_time(left_time)
            print('evaluating validation preformance... %d/%d %.3fs left:%s' %
                  (ix0 - 1, ix1, span_time, s_left_time))

        if data['bounds']['wrapped']:
            break
        if n >= val_images_use:
            break

    if lang_eval == 1:
        if is_aic_data:
            lang_stats, str_stats = language_eval_aic(eval_id, predictions,
                                                      aic_caption_path,
                                                      input_anno)
        else:
            lang_stats, str_stats = language_eval(eval_id, predictions,
                                                  coco_caption_path,
                                                  input_anno)

    # Switch back to training mode
    model_cnn.train()
    model.train()

    if is_compute_val_loss == 1:
        final_loss = loss_sum / loss_evals
    else:
        final_loss = 0

    return final_loss, predictions, lang_stats, str_stats
Example #10
0
def compute_loss(crit, model, caption_model, seq_per_img, fc_expander,
                 att_expander, bu_expander, fc_feats, att_feats, bu_feats,
                 labels, masks, tokens):
    """Forward one batch, compute the criterion loss and backpropagate.

    Expands each feature tensor seq_per_img times (so features align
    with the caption batch), dispatches on the caption-model family to
    pick the forward signature, then calls the matching criterion and
    runs loss.backward().

    Args:
        crit: criterion; prob-weight families additionally receive
            prob_w and tokens.
        model: caption model.
        caption_model: model-family name used for dispatch.
        seq_per_img: captions per image; expanders are applied iff > 1.
        fc_expander/att_expander/bu_expander: FeatExpander instances
            (may be None when the corresponding feature is unused).
        fc_feats/att_feats/bu_feats: CNN features for the batch.
        labels, masks, tokens: caption supervision tensors.

    Returns:
        The scalar loss value (loss.data[0], pre-0.4 PyTorch style).
    """

    if models.is_only_fc_feat(caption_model):
        if seq_per_img > 1:
            fc_feats_ext = fc_expander(fc_feats)
        else:
            fc_feats_ext = fc_feats
        batch_outputs = model(fc_feats_ext, labels)
    elif models.is_only_att_feat(caption_model):
        if seq_per_img > 1:
            att_feats_ext = att_expander(att_feats)
        else:
            att_feats_ext = att_feats
        batch_outputs = model(att_feats_ext, labels)
    elif caption_model == "SCST":
        if seq_per_img > 1:
            fc_feats_ext = fc_expander(fc_feats)
            att_feats_ext = att_expander(att_feats)
        else:
            fc_feats_ext = fc_feats
            att_feats_ext = att_feats
        batch_outputs, _ = model(fc_feats_ext, att_feats_ext, labels, "train")
    elif models.is_prob_weight(caption_model):
        if models.has_sub_region_bu(caption_model):
            if seq_per_img > 1:
                fc_feats_ext = fc_expander(fc_feats)
                att_feats_ext = att_expander(att_feats)
                bu_feats_ext = bu_expander(bu_feats)
            else:
                fc_feats_ext = fc_feats
                att_feats_ext = att_feats
                bu_feats_ext = bu_feats

            batch_outputs, prob_w = model(fc_feats_ext, att_feats_ext,
                                          bu_feats_ext, labels)
        else:
            if seq_per_img > 1:
                fc_feats_ext = fc_expander(fc_feats)
                att_feats_ext = att_expander(att_feats)
            else:
                fc_feats_ext = fc_feats
                att_feats_ext = att_feats

            if models.has_bu(caption_model):
                if seq_per_img > 1:
                    bu_feats_ext = bu_expander(bu_feats)
                else:
                    bu_feats_ext = bu_feats
                batch_outputs, prob_w = model(fc_feats_ext, att_feats_ext,
                                              bu_feats_ext, labels)
            else:
                batch_outputs, prob_w = model(fc_feats_ext, att_feats_ext,
                                              labels)
    elif models.is_prob_weight_mul_out(caption_model):
        if seq_per_img > 1:
            fc_feats_ext = fc_expander(fc_feats)
            att_feats_ext = att_expander(att_feats)
        else:
            fc_feats_ext = fc_feats
            att_feats_ext = att_feats

        if models.has_bu(caption_model):
            if seq_per_img > 1:
                bu_feats_ext = bu_expander(bu_feats)
            else:
                bu_feats_ext = bu_feats
            batch_outputs, prob_w = model(fc_feats_ext, att_feats_ext,
                                          bu_feats_ext, labels)
        else:
            batch_outputs, prob_w = model(fc_feats_ext, att_feats_ext, labels)
    else:
        if seq_per_img > 1:
            fc_feats_ext = fc_expander(fc_feats)
            att_feats_ext = att_expander(att_feats)
        else:
            fc_feats_ext = fc_feats
            att_feats_ext = att_feats

        if models.has_bu(caption_model):
            if seq_per_img > 1:
                bu_feats_ext = bu_expander(bu_feats)
            else:
                bu_feats_ext = bu_feats
            batch_outputs = model(fc_feats_ext, att_feats_ext, bu_feats_ext,
                                  labels)
        else:
            batch_outputs = model(fc_feats_ext, att_feats_ext, labels)

    # NOTE(review): this assumes prob_w was bound above whenever one of
    # these predicates is true -- e.g. is_prob_weight("SCST") must be
    # False, otherwise the SCST branch leaves prob_w unbound.  Confirm
    # against the `models` module.
    if models.is_prob_weight(caption_model) or models.is_prob_weight_mul_out(
            caption_model):
        loss = crit(batch_outputs, labels, masks, prob_w, tokens)
    else:
        loss = crit(batch_outputs, labels, masks)
    loss.backward()

    return loss.data[0]
Example #11
0
def main():
    """Evaluate a trained captioning model over a sweep of beam sizes.

    Parses options, restores the saved options from the training infos
    (asserting consistency for everything outside the `ignore` list),
    builds the CNN + caption model in eval mode, picks the matching
    loader, then runs eval_split for beam sizes 1..20, saving each
    result.  `crit` is None, so no validation loss is computed.
    """

    opt = parse_args()

    # make dirs
    print(opt.eval_result_path)
    if not os.path.isdir(opt.eval_result_path):
        os.makedirs(opt.eval_result_path)

    # Load infos
    infos = load_infos(opt)

    # Options in this list may differ from the values stored at training
    # time; every other option must match exactly.
    ignore = [
        "id", "input_json", "input_h5", "input_anno", "images_root",
        "coco_caption_path", "batch_size", "beam_size", "start_from_best",
        "eval_result_path"
    ]
    for k in vars(infos['opt']).keys():
        if k not in ignore:
            if k in vars(opt):
                assert vars(opt)[k] == vars(
                    infos['opt'])[k], k + ' option not consistent'
            else:
                vars(opt).update({k: vars(infos['opt'])[k]
                                  })  # copy over options from model

    # print(opt)

    # Setup the model
    model_cnn = models.setup_cnn(opt)
    model_cnn.cuda()

    model = models.setup(opt)
    model.cuda()

    # Make sure in the evaluation mode
    model_cnn.eval()
    model.eval()

    # Bottom-up / sub-region models need the bottom-up-aware loader.
    if models.has_bu(opt.caption_model) or \
            models.has_sub_regions(opt.caption_model) or \
            models.has_sub_region_bu(opt.caption_model):
        loader = DataLoaderThreadBu(opt)
        print("DataLoaderThreadBu")
    else:
        loader = DataLoaderThreadNew(opt)
        print("DataLoaderThreadNew")

    loader.ix_to_word = infos['vocab']

    eval_kwargs = {'split': opt.val_split, 'dataset': opt.input_json}
    eval_kwargs.update(vars(opt))

    # Sweep beam sizes 1..total_beam, evaluating and saving each run.
    start_beam = 0
    total_beam = 20
    for beam in range(start_beam, total_beam):
        opt.beam_size = beam + 1
        eval_kwargs.update(vars(opt))
        print("beam_size: " + str(opt.beam_size))
        print("start eval ...")
        crit = None
        val_loss, predictions, lang_stats, str_stats = eval_utils.eval_split(
            model_cnn, model, crit, loader, eval_kwargs)
        print("end eval ...")
        msg = "str_stats = {}".format(str_stats)
        print(msg)
        save_result(str(opt.beam_size) + "," + str_stats, predictions, opt)
def main():
    """Evaluate an ensemble of captioning models, locally or for a server.

    Loads one CNN + caption model per id in opt.ids (restoring each
    model's saved options), then either evaluates on the local dataset
    (eval_type == 0) or generates submission captions for each raw image
    folder in opt.datasets (eval_type == 1), saving predictions and
    beam-visualization data.
    """

    opt = parse_args()

    opt.datasets = opt.datasets.split(',')
    opt.ids = opt.ids.split(',')

    # make dirs
    print(opt.output_dir)
    if not os.path.isdir(opt.output_dir):
        os.makedirs(opt.output_dir)

    print(opt.output_beam_dir)
    if not os.path.isdir(opt.output_beam_dir):
        os.makedirs(opt.output_beam_dir)

    # print(opt)

    all_model_cnns = []
    all_models = []

    for i in range(len(opt.ids)):

        # id
        opt.id = opt.ids[i]

        # Load infos
        infos = load_infos(opt)

        # Options in this list keep their command-line values; the rest
        # are overwritten from the checkpoint's stored options.
        ignore = ["id", "batch_size", "beam_size", "start_from_best", "input_json",
                  "input_h5", "input_anno", "images_root", "aic_caption_path", "input_bu"]

        for k in vars(infos['opt']).keys():
            if k not in ignore:
                vars(opt).update({k: vars(infos['opt'])[k]})

        opt.relu_type = 0

        # Setup the model
        model_cnn = models.setup_cnn(opt)
        # model_cnn.cuda()
        model_cnn = nn.DataParallel(model_cnn.cuda())

        model = models.setup(opt)
        model.cuda()

        # Make sure in the evaluation mode
        model_cnn.eval()
        model.eval()

        all_model_cnns.append(model_cnn)
        all_models.append(model)

    if opt.eval_type == 0: # local test

        print('eval local')

        if models.has_bu(opt.caption_model):
            loader = DataLoaderThreadBu(opt)
        else:
            loader = DataLoaderThreadNew(opt)

        # Set sample options
        predictions, lang_stats, str_stats, beam_vis = eval_split(all_model_cnns, all_models, loader, opt, vars(opt))

        save_result(opt.output_dir, str_stats, predictions)

        save_beam_vis_result(opt.output_beam_dir, "eval_beam_vis.json", beam_vis)


    elif opt.eval_type == 1: # server

        print('eval server')

        for dataset in opt.datasets:

            print(os.path.join(opt.image_folder, dataset))

            loader = DataLoaderRaw({'folder_path': os.path.join(opt.image_folder, dataset),
                                    'batch_size': opt.batch_size,
                                    'start': opt.start,
                                    'num': opt.num,
                                    'use_bu_att': opt.use_bu_att,
                                    'input_bu': opt.input_bu,
                                    'bu_size': opt.bu_size,
                                    'bu_feat_size': opt.bu_feat_size})

            # NOTE(review): `infos` here is the one loaded for the LAST
            # model in opt.ids -- confirm all ensemble members share the
            # same vocabulary.
            loader.ix_to_word = infos['vocab']

            # Set sample options
            predictions, lang_stats, str_stats, beam_vis = eval_split(all_model_cnns, all_models, loader, opt, vars(opt))

            path_json = opt.output_dir + '/captions_' + dataset + str(opt.start) + '_ensemble_results.json'

            json.dump(predictions, open(path_json, 'w'))

            save_beam_vis_result(opt.output_beam_dir, dataset + str(opt.start) + "_beam_size_" + str(opt.beam_size) + "_beam_type_" + str(opt.beam_type) + "_eval_beam_vis.json", beam_vis)
def eval_split(all_model_cnns, all_models, loader, opt, eval_kwargs=None):
    """Evaluate an ensemble of caption models on the 'val' split.

    For every batch, runs beam-search sampling over the ensemble (with
    bottom-up features when the caption model uses them), decodes the
    captions, and collects beam-visualization data (predicted ids, beam
    parents, scores).  For local evaluation (opt.eval_type == 0) the
    AIC language metrics are computed; for server mode they are None.

    Args:
        all_model_cnns: list of CNN feature extractors (one per model).
        all_models: list of caption models.
        loader: batch loader for the 'val' split.
        opt: options namespace (eval_type, caption_model, ...).
        eval_kwargs: dict of evaluation options; defaults to {}.

    Returns:
        (predictions, lang_stats, str_stats, beam_vis).
    """
    # Avoid the shared mutable-default pitfall of `eval_kwargs={}`.
    if eval_kwargs is None:
        eval_kwargs = {}

    verbose_eval = eval_kwargs.get('verbose_eval', True)
    batch_size = eval_kwargs.get('batch_size', 1)
    aic_caption_path = eval_kwargs.get('aic_caption_path', 'coco-caption')
    num = eval_kwargs.get('num', 5000)
    # (removed unused local `eval_type`; opt.eval_type is used directly)

    print('start eval ...')

    split = 'val'
    loader.reset_iterator(split)
    n = 0
    predictions = []
    vocab = loader.get_vocab()

    beam_vis = {}
    total_predicted_ids = []
    total_beam_parent_ids = []
    total_scores = []
    total_ids = []
    total_sents = []

    while True:

        start = time.time()

        start_data = time.time()
        data = loader.get_batch(split, batch_size)
        print('data time {:.3f} s'.format(time.time() - start_data))

        n = n + batch_size

        images = data['images']

        # Server mode receives raw numpy images that still need
        # normalization and wrapping.
        if opt.eval_type == 1: # server
            images = torch.from_numpy(images).cuda()
            images = utils.prepro_norm(images, False)
            images = Variable(images, requires_grad=False)

        if models.has_bu(opt.caption_model):
            if opt.eval_type == 0: # local
                bu_feats = data['bus']
            elif opt.eval_type == 1: # server
                bus = torch.from_numpy(data['bus']).cuda().float()
                bu_feats = Variable(bus, requires_grad=False)
            seqs, _, batch_predicted_ids, batch_beam_parent_ids, batch_scores = sample_beam_ensamble_with_bu(all_model_cnns, all_models, images, bu_feats, eval_kwargs)
        else:
            seqs, _, batch_predicted_ids, batch_beam_parent_ids, batch_scores = sample_beam_ensamble(all_model_cnns, all_models, images, eval_kwargs)

        # sents
        sents = utils.decode_sequence_aic(vocab, seqs)

        for k, sent in enumerate(sents):
            print(data['infos'][k])
            if opt.eval_type == 0: # local
                image_id = data['infos'][k]['image_id']
            elif opt.eval_type == 1: # server
                image_id = data['infos'][k]['id']
            entry = {'image_id': image_id, 'caption': sent}

            predictions.append(entry)
            if verbose_eval:
                print('image %s: %s' % (entry['image_id'], entry['caption']))

            total_predicted_ids.append(batch_predicted_ids[k])
            total_beam_parent_ids.append(batch_beam_parent_ids[k])
            total_scores.append(batch_scores[k])
            total_ids.append(image_id)
            total_sents.append(sent)

        ix0 = data['bounds']['it_pos_now']
        ix1 = data['bounds']['it_max']
        if num != -1:
            ix1 = min(ix1, num)

        # Drop predictions past the requested boundary (the last batch
        # may overshoot it).
        for i in range(n - ix1):
            predictions.pop()
        if verbose_eval:
            span_time = time.time() - start
            left_time = (ix1 - ix0) * span_time / batch_size
            s_left_time = format_time(left_time)
            print('evaluating validation preformance... %d/%d %.3fs left:%s' % (ix0, ix1, span_time, s_left_time))

        if data['bounds']['wrapped']:
            break
        # BUG FIX: this guard used to read `if n != -1 and n >= num:`.
        # `n` is a running batch counter and is never -1, so with
        # num == -1 ("evaluate everything") the condition n >= -1 was
        # always true and the loop stopped after the first batch.
        if num != -1 and n >= num:
            break

        print('time {:.3f} s'.format(time.time() - start))

    beam_vis["predicted_ids"] = total_predicted_ids
    beam_vis["beam_parent_ids"] = total_beam_parent_ids
    beam_vis["scores"] = total_scores
    beam_vis["ids"] = total_ids
    beam_vis["vocab"] = vocab
    beam_vis["sents"] = total_sents

    # print(beam_vis)

    if opt.eval_type == 0: # local
        lang_stats, str_stats = language_eval_aic("ensemble", predictions, aic_caption_path, opt.input_anno)
    else:
        lang_stats = None
        str_stats = None

    return predictions, lang_stats, str_stats, beam_vis
Example #14
0
def train(opt):
    """Main training loop for the captioning model.

    Builds the data loader, CNN feature extractor and caption model, restores
    bookkeeping from a previous run, then iterates training batches forever,
    periodically evaluating on the validation set, checkpointing, and
    (optionally) stepping through an automatic learning-rate schedule.

    Args:
        opt: argparse-style namespace with all run options; several fields
            (vocab_size, seq_length, vocab, learning rates, ...) are mutated
            in place as a side effect.

    Relies on module-level helpers defined elsewhere in this file
    (get_loader, get_infos, get_expander, eval_model, train_model,
    get_optimizer, update_gradient, train_cnn) and module globals
    (models, train_utils, utils, logger, notify, tensorboard, trans_client).
    """
    notifier = notify()
    notifier.login()

    # init path
    if not os.path.exists(opt.eval_result_path):
        os.makedirs(opt.eval_result_path)

    # Persist the full run configuration next to the results for reproducibility.
    config_file = os.path.join(opt.eval_result_path, opt.id + '_config.txt')
    with open(config_file, 'w') as f:
        f.write("{}\n".format(json.dumps(vars(opt), sort_keys=True, indent=2)))

    torch.backends.cudnn.benchmark = True

    # BUGFIX: `board` is passed unconditionally to eval_model()/train_model()
    # below, so it must exist even when tensorboard is disabled.
    board = None
    if opt.use_tensorboard:

        if opt.tensorboard_type == 0:
            board = tensorboard.TensorBoard()
            board.start(opt.id, opt.tensorboard_ip, opt.tensorboard_port)
        else:
            board = trans_client.TransClient()
            board.start(opt.id)

    print(opt.cnn_model)

    loader = get_loader()

    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    vocab = loader.get_vocab()
    opt.vocab = vocab
    batch_size = loader.batch_size

    # Restore training bookkeeping (iteration/epoch counters, histories)
    # from a previous run if one exists.
    infos = get_infos()
    infos['vocab'] = vocab

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    finetune_cnn_history = infos.get('finetune_cnn_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None

    model_cnn = models.setup_cnn(opt)
    model_cnn = model_cnn.cuda()
    model_cnn = nn.DataParallel(model_cnn)

    model = models.setup(opt)
    model = model.cuda()

    train_utils.save_model_conf(model_cnn, model, opt)

    update_lr_flag = True

    model_cnn.train()
    model.train()

    fc_expander, att_expander, bu_expander = get_expander()

    optimizer = None
    optimizer_cnn = None
    finetune_cnn_start = False

    early_stop_cnt = 0

    # Shared argument bag handed to eval_model()/train_model().
    params = {}
    params['model'] = model
    params['vocab'] = vocab

    # Criterions are created lazily by train_model():
    # crit_pg, crit_rl, crit_ctc, crit_c, crit_ac, crit
    params['crit_pg'] = None
    params['crit_rl'] = None
    params['crit_ctc'] = None
    params['crit_c'] = None
    params['crit_ac'] = None
    params['crit'] = None

    is_eval_start = opt.is_eval_start

    if opt.use_auto_learning_rate == 1:
        # Multi-stage schedule: each step carries its own learning rates and
        # CNN fine-tune start epoch.
        train_process = train_utils.init_train_process()
        train_process_index = infos.get('train_process_index', 0)
        train_step = train_process[train_process_index]
        optimizer_cnn = None
        optimizer = None
        opt.learning_rate = train_step.learning_rate
        opt.cnn_learning_rate = train_step.cnn_learning_rate
        opt.finetune_cnn_after = train_step.finetune_cnn_after

    while True:

        current_score = None

        # make evaluation on validation set, and save model
        # BUGFIX: dict.has_key() was removed in Python 3; `in` is the
        # equivalent membership test (and also works on Python 2).
        if (iteration > 0 and iteration % opt.save_checkpoint_every == 0 and
                iteration not in val_result_history) or is_eval_start:

            predictions, best_val_score, best_flag, current_score = eval_model(
                model_cnn, model, params, loader, board, iteration, notifier,
                val_result_history, best_val_score)

            infos['best_val_score'] = best_val_score
            infos['val_result_history'] = val_result_history
            train_utils.save_infos(infos, opt)

            if best_flag:
                train_utils.save_best_result(predictions, opt)
                train_utils.save_model_best(model, model_cnn, infos, opt)
                early_stop_cnt = 0
            else:
                early_stop_cnt += 1

            is_eval_start = False

        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            msg = "max epoch"
            logger.info(msg)
            break

        # Auto schedule: when validation stalls (or the score is too low),
        # advance to the next training step and rebuild both models.
        if opt.use_auto_learning_rate == 1 and current_score is not None:
            if early_stop_cnt > opt.auto_early_stop_cnt or current_score < opt.auto_early_stop_score:
                early_stop_cnt = 0
                train_process_index += 1
                msg = opt.id + " early stop " + str(train_process_index)
                logger.info(msg)

                infos['train_process_index'] = train_process_index
                train_utils.save_infos(infos, opt)

                if train_process_index >= len(train_process):
                    # Schedule exhausted: notify and stop training entirely.
                    notifier.send(opt.id + " early stop", msg)
                    logger.info("break")
                    break

                train_step = train_process[train_process_index]
                optimizer_cnn = None
                optimizer = None
                opt.learning_rate = train_step.learning_rate
                opt.cnn_learning_rate = train_step.cnn_learning_rate
                opt.finetune_cnn_after = train_step.finetune_cnn_after
                opt.start_from_best = opt.auto_start_from_best

                del model_cnn
                del model

                torch.cuda.empty_cache()

                model_cnn = models.setup_cnn(opt)
                model_cnn = model_cnn.cuda()
                model_cnn = nn.DataParallel(model_cnn)

                model = models.setup(opt)
                model = model.cuda()

                # BUGFIX: re-point the shared params at the freshly built
                # model; otherwise train_model/eval_model would keep using
                # (and keep alive) the deleted one.
                params['model'] = model

                model_cnn.train()
                model.train()

                update_lr_flag = True

        # start train

        # Update the iteration and epoch
        iteration += 1

        if update_lr_flag:
            # Only fine-tune the CNN once the configured epoch is reached.
            if opt.finetune_cnn_after >= 0 and epoch >= opt.finetune_cnn_after:
                finetune_cnn_start = True
            else:
                finetune_cnn_start = False

            optimizer_cnn = train_utils.get_cnn_optimizer(
                model_cnn, optimizer_cnn, finetune_cnn_start, opt)

            train_utils.update_lr(epoch, optimizer, optimizer_cnn,
                                  finetune_cnn_start, opt)

            update_lr_flag = False

        # Switch from cross-entropy to REINFORCE training at the configured epoch.
        if opt.reinforce_start >= 0 and epoch >= opt.reinforce_start:
            use_reinforce = True
        else:
            use_reinforce = False

        optimizer = get_optimizer(optimizer, epoch, model, model_cnn)

        start_total = time.time()
        start = time.time()

        optimizer.zero_grad()
        if finetune_cnn_start:
            optimizer_cnn.zero_grad()

        # batch data
        data = loader.get_batch('train', batch_size)

        images = data['images']
        bus = None
        if models.has_bu(opt.caption_model):
            bus = data['bus']

        if opt.verbose:
            print('data {:.3f}'.format(time.time() - start))

        start = time.time()

        # Forward pass through the CNN (features expanded seq_per_img times
        # when needed; see train_cnn above).
        fc_feats, att_feats, bu_feats = train_cnn(model_cnn, images, bus,
                                                  fc_expander, att_expander,
                                                  bu_expander, use_reinforce)

        if opt.verbose:
            print('model_cnn {:.3f}'.format(time.time() - start))

        # get input data
        params['fc_feats'] = fc_feats
        params['att_feats'] = att_feats
        params['bu_feats'] = bu_feats

        # get target data
        params['labels'] = data['labels']
        params['masks'] = data['masks']
        params['tokens'] = data['tokens']
        params['gts'] = data['gts']
        params['targets'] = data['targets']

        # Forward/backward through the caption model; also reports whether
        # REINFORCE was actually used and the mean reward (0 otherwise).
        train_loss, reward_mean, use_reinforce = train_model(
            params, iteration, epoch, board)

        # update the gradient
        update_gradient(optimizer, optimizer_cnn, finetune_cnn_start)

        time_batch = time.time() - start_total
        left_time = (opt.save_checkpoint_every -
                     iteration % opt.save_checkpoint_every) * time_batch
        s_left_time = utils.format_time(left_time)
        msg = "id {} iter {} (epoch {}), train_loss = {:.3f}, lr = {} lr_cnn = {} f_cnn = {} rf = {} r = {:.3f} early_stop_cnt = {} time/batch = {:.3f}s time/eval = {}" \
            .format(opt.id, iteration, epoch, train_loss, opt.current_lr, opt.current_cnn_lr, finetune_cnn_start,
                    use_reinforce, reward_mean, early_stop_cnt, time_batch, s_left_time)
        logger.info(msg)

        if opt.use_tensorboard:
            if iteration % opt.tensorboard_for_train_every == 0:
                board.loss_train(train_loss, iteration)

        # The loader wraps around at the end of the split: that marks an epoch.
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            finetune_cnn_history[iteration] = finetune_cnn_start

        # update infos
        infos['iter'] = iteration
        infos['epoch'] = epoch
        infos['iterators'] = loader.iterators
        infos['best_val_score'] = best_val_score
        infos['opt'] = opt
        infos['val_result_history'] = val_result_history
        infos['loss_history'] = loss_history
        infos['lr_history'] = lr_history
        infos['finetune_cnn_history'] = finetune_cnn_history
        if opt.use_auto_learning_rate == 1:
            infos['train_process_index'] = train_process_index

        if opt.save_snapshot_every > 0 and iteration % opt.save_snapshot_every == 0:
            train_utils.save_model(model, model_cnn, infos, opt)

    loader.terminate()