def update_lr(opt, epoch, model, optimizer_G):
    """Apply the per-epoch learning-rate, scheduled-sampling and SCST settings.

    * Learning rate: once ``epoch`` is past ``opt.learning_rate_decay_start``
      (and that option is non-negative), ``opt.current_lr`` becomes
      ``opt.learning_rate * opt.learning_rate_decay_rate ** k``, where ``k``
      counts completed ``opt.learning_rate_decay_every`` periods. This value
      is pushed into ``optimizer_G`` via ``utils.set_lr``. Otherwise
      ``opt.current_lr`` is reset to ``opt.learning_rate`` and the optimizer
      is left untouched.
    * Scheduled sampling: once ``epoch`` is past
      ``opt.scheduled_sampling_start`` (non-negative), ``opt.ss_prob`` and
      ``model.ss_prob`` grow by ``opt.scheduled_sampling_increase_prob`` per
      ``opt.scheduled_sampling_increase_every`` epochs, capped at
      ``opt.scheduled_sampling_max_prob``.
    * Self-critical flag: true when ``opt.self_critical_after`` is not -1 and
      ``epoch`` has reached it.

    Returns:
        ``(opt, sc_flag, update_lr_flag, model, optimizer_G)``, where
        ``update_lr_flag`` is always ``False``.
    """
    decay_start = opt.learning_rate_decay_start
    if 0 <= decay_start < epoch:
        periods = (epoch - decay_start) // opt.learning_rate_decay_every
        opt.current_lr = opt.learning_rate * opt.learning_rate_decay_rate ** periods
        # Only a decayed rate is written back into the optimizer.
        utils.set_lr(optimizer_G, opt.current_lr)
    else:
        opt.current_lr = opt.learning_rate

    ss_start = opt.scheduled_sampling_start
    if 0 <= ss_start < epoch:
        periods = (epoch - ss_start) // opt.scheduled_sampling_increase_every
        opt.ss_prob = min(opt.scheduled_sampling_increase_prob * periods,
                          opt.scheduled_sampling_max_prob)
        model.ss_prob = opt.ss_prob

    sc_flag = bool(opt.self_critical_after != -1
                   and epoch >= opt.self_critical_after)
    return opt, sc_flag, False, model, optimizer_G
# Example #2
def train(opt):
    """Train a (scene-graph) captioning model configured by ``opt``.

    Creates a fresh timestamped checkpoint folder, or reuses
    ``opt.start_from`` when it is set. Then builds the data loader, model and
    optimizer, and loops over training batches:

    * Gradients are accumulated over ``opt.accumulate_number`` micro-batches
      per optimizer step; ``iteration`` counts optimizer steps.
    * From epoch ``opt.self_critical_after`` on (unless it is -1), the loss
      switches to the self-critical reward loss.
    * Every ``opt.losses_log_every`` optimizer steps, training statistics are
      written to TensorBoard and to the in-memory histories.
    * Every ``opt.save_checkpoint_every`` optimizer steps, the model is
      evaluated on the ``val`` split. ``model.pth``, ``optimizer.pth``,
      ``infos.pkl`` and ``histories.pkl`` are written to
      ``opt.checkpoint_path``, plus ``model-best.pth``/``infos-best.pkl``
      when the validation score improves.

    Returns once ``opt.max_epochs`` epochs are done; never returns when
    ``opt.max_epochs == -1``.

    Args:
        opt: Namespace of training options. Mutated in place (``id``,
            ``checkpoint_path``, ``use_att``, ``vocab_size``,
            ``seq_length``, ``current_lr``, ``ss_prob``, ...).
    """
    start_from = vars(opt).get('start_from', None)
    if start_from is not None:
        opt.checkpoint_path = start_from
        opt.id = opt.checkpoint_path.split('/')[-1]
        print('Point to folder: {}'.format(opt.checkpoint_path))
    else:
        opt.id = datetime.datetime.now().strftime(
            '%Y%m%d_%H%M%S') + '_' + opt.caption_model
        opt.checkpoint_path = os.path.join(opt.checkpoint_path, opt.id)

        if not os.path.exists(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        print('Create folder: {}'.format(opt.checkpoint_path))

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box:
        # use_box appends 5 extra values to each attention feature
        # (presumably box geometry -- confirm against the data loader).
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader_UP(opt)
    opt.vocab_size = loader.vocab_size
    use_rela = getattr(opt, 'use_rela', 0)
    if use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length

    try:
        tb_summary_writer = tf and tf.compat.v1.summary.FileWriter(
            opt.checkpoint_path)
    except Exception as err:
        # Fall back to "no writer". This is the same value produced when
        # ``tf`` is None. Dropping into a debugger here blocks unattended
        # runs and left ``tb_summary_writer`` unbound.
        print('Set tensorboard error!', err)
        tb_summary_writer = None

    infos = {}
    histories = {}
    if start_from is not None:
        # Open old infos and check if models are compatible. The pickles are
        # written with 'wb' below, so they must be read back in 'rb' mode.
        with open(os.path.join(opt.checkpoint_path, 'infos.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.checkpoint_path, 'histories.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score. It is compared on every checkpoint, so
    # leaving it unset when load_best_score != 1 raised NameError there.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    # No DataParallel wrapping: dp_model is the model itself.
    dp_model = model

    print('### Model summary below###\n {}\n'.format(str(model)))
    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('model parameter:{}'.format(model_params))

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), opt)

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from the training split
        data = loader.get_batch(opt.train_split)

        torch.cuda.synchronize()
        start = time.time()

        # Defaults for inputs that only some feature configurations provide.
        # isg_data must be bound here too: the plain-feature branch below
        # never sets it, yet the forward pass always passes it.
        fc_feats = None
        att_feats = None
        att_masks = None
        isg_data = None
        ssg_data = None

        if getattr(opt, 'use_ssg', 0) == 1:
            if getattr(opt, 'use_isg', 0) == 1:
                tmp = [
                    data['fc_feats'], data['labels'], data['masks'],
                    data['att_feats'], data['att_masks'],
                    data['isg_rela_matrix'], data['isg_rela_masks'],
                    data['isg_obj'], data['isg_obj_masks'], data['isg_attr'],
                    data['isg_attr_masks'], data['ssg_rela_matrix'],
                    data['ssg_rela_masks'], data['ssg_obj'],
                    data['ssg_obj_masks'], data['ssg_attr'],
                    data['ssg_attr_masks']
                ]

                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, labels, masks, att_feats, att_masks, \
                isg_rela_matrix, isg_rela_masks, isg_obj, isg_obj_masks, isg_attr, isg_attr_masks, \
                ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks = tmp

                # image graph domain
                isg_data = {}
                isg_data['att_feats'] = att_feats
                isg_data['att_masks'] = att_masks

                isg_data['isg_rela_matrix'] = isg_rela_matrix
                isg_data['isg_rela_masks'] = isg_rela_masks
                isg_data['isg_obj'] = isg_obj
                isg_data['isg_obj_masks'] = isg_obj_masks
                isg_data['isg_attr'] = isg_attr
                isg_data['isg_attr_masks'] = isg_attr_masks
                # text graph domain
                ssg_data = {}
                ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
                ssg_data['ssg_rela_masks'] = ssg_rela_masks
                ssg_data['ssg_obj'] = ssg_obj
                ssg_data['ssg_obj_masks'] = ssg_obj_masks
                ssg_data['ssg_attr'] = ssg_attr
                ssg_data['ssg_attr_masks'] = ssg_attr_masks
            else:
                tmp = [
                    data['fc_feats'], data['ssg_rela_matrix'],
                    data['ssg_rela_masks'], data['ssg_obj'],
                    data['ssg_obj_masks'], data['ssg_attr'],
                    data['ssg_attr_masks'], data['labels'], data['masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks, labels, masks = tmp
                ssg_data = {}
                ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
                ssg_data['ssg_rela_masks'] = ssg_rela_masks
                ssg_data['ssg_obj'] = ssg_obj
                ssg_data['ssg_obj_masks'] = ssg_obj_masks
                ssg_data['ssg_attr'] = ssg_attr
                ssg_data['ssg_attr_masks'] = ssg_attr_masks
        else:
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

        if not sc_flag:
            loss = crit(dp_model(fc_feats, labels, isg_data, ssg_data),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   isg_data,
                                                   ssg_data,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, isg_data,
                                              ssg_data, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        accumulate_iter = accumulate_iter + 1
        # Scale so the accumulated gradient is the mean over micro-batches.
        loss = loss / opt.accumulate_number
        loss.backward()
        # Only the micro-batch that completes an accumulation window steps
        # the optimizer and advances ``iteration``.
        stepped = accumulate_iter % opt.accumulate_number == 0
        if stepped:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            # Undo the scaling to report the loss of this micro-batch.
            train_loss = loss.item() * opt.accumulate_number
            end = time.time()

            if not sc_flag:
                print("{}/{}/{}|train_loss={:.3f}|time/batch={:.3f}" \
                      .format(opt.id, iteration, epoch, train_loss, end - start))
            else:
                print("{}/{}/{}|avg_reward={:.3f}|time/batch={:.3f}" \
                      .format(opt.id, iteration, epoch, np.mean(reward[:, 0]), end - start))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary. This and the evaluation below are
        # gated on ``stepped``. ``iteration`` does not change across the
        # micro-batches of an accumulation window, so without the gate the
        # same iteration was logged, evaluated and saved accumulate_number
        # times.
        if stepped and (iteration % opt.losses_log_every == 0) and (iteration != 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if stepped and (iteration % opt.save_checkpoint_every == 0) and (iteration != 0):
            # eval model
            if use_rela:
                # NOTE(review): 'use_real' looks like a typo of 'use_rela'.
                # vars(opt) is merged in below anyway -- confirm whether any
                # consumer reads 'use_real'.
                eval_kwargs = {
                    'split': 'val',
                    'dataset': opt.input_json,
                    'use_real': 1
                }
            else:
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information needed to resume training
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos.pkl'),
                      'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories.pkl'),
                      'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
# Example #3
def train(opt):
    """Train a captioning model, reporting through a ``train.log`` logger.

    Builds the data loader, the model (wrapped in ``torch.nn.DataParallel``)
    and the optimizer. When ``opt.start_from`` is set, it resumes
    ``infos.json`` and, if present, ``optimizer.pth`` from that folder.

    Each loop iteration trains on one ``train`` batch. The loss is either the
    language-model loss or, from epoch ``opt.self_critical_after`` on (unless
    it is -1), the self-critical reward loss. A validation pass followed by a
    checkpoint (``model.pth``, ``optimizer.pth``, ``infos.json``, plus
    ``model-best.pth``/``infos-best.json`` on improvement) runs every
    ``opt.save_checkpoint_every`` iterations, or at each epoch end when that
    option is <= 0.

    Args:
        opt: Namespace of training options; mutated in place (``vocab_size``,
            ``seq_length``, ``current_lr``, ``ss_prob``, ...).
    """
    logger = initialize_logger(os.path.join(opt.checkpoint_path, 'train.log'))
    # Route every print() in this function to the logger.
    print = logger.info

    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Print out the option variables
    print("*" * 20)
    for k, v in opt.__dict__.items():
        print("%r: %r" % (k, v))
    print("*" * 20)

    infos = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        # NOTE(review): no compatibility check is actually performed here,
        # only the load.
        with open(os.path.join(opt.start_from, 'infos.json'), 'r') as f:
            infos = json.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    else:
        best_val_score = None

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    start_time = time.time()
    while True:
        # Per-epoch schedule updates (runs once at start and after each
        # epoch wrap).
        if update_lr_flag:
            # Assign the learning rate
            if 0 <= opt.learning_rate_decay_start < epoch:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                # NOTE(review): set_lr is not called here, so the optimizer
                # keeps whatever lr it already has (e.g. from a loaded
                # optimizer.pth), which may differ from opt.current_lr.
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if 0 <= opt.scheduled_sampling_start < epoch:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer()
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        batch_data = loader.get_batch('train')
        torch.cuda.synchronize()

        # Move numpy inputs to the GPU; entries that are None stay None.
        tmp = [
            batch_data['fc_feats'], batch_data['att_feats'],
            batch_data['labels'], batch_data['masks'], batch_data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            outputs = dp_model(fc_feats, att_feats, labels, att_masks)
            loss = crit(outputs, labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, batch_data,
                                              gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # NOTE(review): train_loss is not read anywhere below.
        train_loss = loss.data
        torch.cuda.synchronize()

        # Update the iteration and epoch
        iteration += 1
        if batch_data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Print train loss or avg reward
        if iteration % opt.losses_print_every == 0:
            if not sc_flag:
                print(
                    "iter {} (epoch {}), loss = {:.3f}, time = {:.3f}".format(
                        iteration, epoch, loss.item(),
                        time.time() - start_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time = {:.3f}".
                      format(iteration, epoch, np.mean(reward[:, 0]),
                             time.time() - start_time))
            start_time = time.time()

        # make evaluation on validation set, and save model
        # (every save_checkpoint_every iterations, or at each epoch wrap when
        # save_checkpoint_every <= 0, since update_lr_flag was just set).
        if (opt.save_checkpoint_every > 0 and iteration % opt.save_checkpoint_every == 0)\
                or (opt.save_checkpoint_every <= 0 and update_lr_flag):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.simple_eval_split(
                dp_model, loader, eval_kwargs)

            # Save model if is improving on validation result
            if not os.path.exists(opt.checkpoint_path):
                os.makedirs(opt.checkpoint_path)

            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False

            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = vars(opt)
            infos['vocab'] = loader.get_vocab()

            with open(os.path.join(opt.checkpoint_path, 'infos.json'),
                      'w') as f:
                json.dump(infos, f, sort_keys=True, indent=4)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos-best.json'),
                          'w') as f:
                    json.dump(infos, f, sort_keys=True, indent=4)

            # Stop if reaching max epochs
            # NOTE(review): this check sits inside the checkpoint branch, so
            # it only runs on checkpoint iterations. With
            # save_checkpoint_every > 0, training can continue past
            # opt.max_epochs until the next checkpoint -- confirm intended.
            if opt.max_epochs != -1 and epoch >= opt.max_epochs:
                break
# Example #4
def train(opt):
    """Train a captioning model with the language-model (cross-entropy) loss.

    Builds the data loader, model and Adam optimizer, optionally resuming
    infos/histories/optimizer state from ``opt.start_from``, then trains one
    ``train`` batch per iteration:

    * Every ``opt.losses_log_every`` iterations, training statistics are
      recorded (and written to TensorBoard when ``tf`` is available).
    * Every ``opt.save_checkpoint_every`` iterations, the model is evaluated
      on the ``val`` split. ``model.pth``, ``optimizer.pth``,
      ``infos_<id>.pkl`` and ``histories_<id>.pkl`` are written to
      ``opt.checkpoint_path``, plus ``model-best.pth``/``infos_<id>-best.pkl``
      when the validation score improves.

    Returns once ``opt.max_epochs`` epochs are done; never returns when
    ``opt.max_epochs == -1``.

    Args:
        opt: Namespace of training options; mutated in place (``use_att``,
            ``vocab_size``, ``seq_length``, ``current_lr``, ``ss_prob``).
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check if models are compatible. The pickles are
        # written with 'wb' below, so they must be read back in 'rb' mode.
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score. It is compared on every checkpoint, so
    # leaving it unset when load_best_score != 1 raised NameError there.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer state when resuming, but only if one was saved.
    # A run resumed from a folder without optimizer.pth used to crash here.
    if vars(opt).get('start_from', None) is not None:
        optimizer_ckpt = os.path.join(opt.start_from, 'optimizer.pth')
        if os.path.isfile(optimizer_ckpt):
            optimizer.load_state_dict(torch.load(optimizer_ckpt))

    while True:
        # Per-epoch schedule updates (runs once at start and after each
        # epoch wrap).
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                    masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # `loss.data[0]` raises on the 0-dim loss tensor in current PyTorch;
        # .item() is the supported way to read a scalar.
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                # lang_stats can be None (e.g. without language evaluation);
                # iterating it unguarded crashed the run.
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information needed to resume training
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'),
                    'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
# Example #5
def train(opt):
    """Pre-train the ``FillInCharacter`` generator, evaluating once per epoch.

    Creates ``opt.checkpoint_path`` (dumping the options to ``config.json``),
    an optional tensorboard writer, the data loader and the generator with its
    optimizer.  It optionally resumes from ``opt.start_from``, then trains
    until ``g_epoch`` reaches ``opt.pre_nepoch``.

    Each loop iteration makes one ``train_generator`` call and advances
    ``g_iter``.  When ``train_generator`` returns a truthy ``wrapped`` flag
    (presumably the end of a pass over the train split), the function:

    * evaluates the model on the val split;
    * saves the best model (``gen_best.pth`` / ``gen_optimizer_best.pth``);
    * rewrites ``infos.pkl`` and ``histories.pkl``;
    * writes numbered/latest checkpoints every ``opt.save_checkpoint_every``
      epochs;
    * advances ``g_epoch``.

    Args:
        opt: options namespace.  It is mutated in place (``vocab_size``,
            ``vocab``, ``seq_length``, ``current_lr``, ``ss_prob``, ...).
    """
    if not os.path.exists(opt.checkpoint_path):
        os.makedirs(opt.checkpoint_path)

    # Record the exact options used for this run.
    with open(os.path.join(opt.checkpoint_path, 'config.json'), 'w') as f:
        json.dump(vars(opt), f, indent=4)

    writer = None
    if tb is not None:
        import shutil
        now = datetime.now()
        if opt.reset_tensorboard:
            # Delete earlier tensorboard run directories.  Match on the entry
            # name only.  Testing the full path would also match when
            # ``opt.checkpoint_path`` itself contains "tb_", which would wipe
            # every sub-directory (including saved models).
            for name in os.listdir(opt.checkpoint_path):
                d = os.path.join(opt.checkpoint_path, name)
                if os.path.isdir(d) and 'tb_' in name:
                    shutil.rmtree(d)
                    print('remove', d)
        logdir = os.path.join(opt.checkpoint_path,
                              'tb_' + now.strftime("%Y%m%d-%H%M%S") + "/")
        writer = tb.SummaryWriter(logdir)

    # Load iterators and copy dataset properties onto opt.
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.vocab = loader.get_vocab()
    opt.blank_token = loader.get_blank_token()
    opt.seq_length = loader.seq_length

    opt.unique_characters = loader.unique_characters
    opt.max_characters = loader.max_characters
    if opt.glove is not None:
        opt.glove_npy = loader.build_glove(opt.glove)
    else:
        opt.glove_npy = None

    # Set up the generator (wrapped in DataParallel when several GPUs exist).
    gen_model = FillInCharacter(opt)
    gen_model = gen_model.cuda()

    if torch.cuda.device_count() > 1:
        gen_model = nn.DataParallel(gen_model)
    gen_model.train()
    gen_optimizer = utils.build_optimizer(gen_model.parameters(), opt)

    # Iteration / epoch counters.
    g_iter = 0
    g_epoch = 0
    update_lr_flag = True

    # Load from checkpoint path
    infos = {'opt': opt}
    histories = {}
    infos['vocab'] = loader.get_vocab()
    if opt.start_from is not None:
        # Open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos.pkl'), 'rb') as f:
            infos = pickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # Load train/val histories
        with open(os.path.join(opt.start_from, 'histories.pkl'), 'rb') as f:
            histories = pickle.load(f)

        # Load generator.  ``opt.start_epoch`` is used as a string: an epoch
        # number, "best", or a value containing "latest".
        start_epoch = opt.start_epoch
        g_model_path = os.path.join(opt.start_from, "gen_%s.pth" % start_epoch)
        g_optimizer_path = os.path.join(opt.start_from,
                                        "gen_optimizer_%s.pth" % start_epoch)
        assert os.path.isfile(g_model_path) and os.path.isfile(
            g_optimizer_path)
        gen_model.load_state_dict(torch.load(g_model_path))
        gen_optimizer.load_state_dict(torch.load(g_optimizer_path))
        if "latest" not in start_epoch and "best" != start_epoch:
            g_epoch = int(start_epoch) + 1
            g_iter = (g_epoch) * loader.split_size['train'] // opt.batch_size
        elif start_epoch == "best":
            g_epoch = infos['g_epoch_' + start_epoch] + 1
            g_iter = (g_epoch) * loader.split_size['train'] // opt.batch_size
        else:
            g_epoch = infos['g_epoch_' + start_epoch] + 1
            g_iter = infos['g_iter_' + start_epoch]
        print('loaded %s (epoch: %d iter: %d)' %
              (g_model_path, g_epoch, g_iter))
    infos['opt'] = opt
    loader.iterators = infos.get('g_iterators', loader.iterators)

    # misc
    best_val_score = infos.get('g_best_score', None)
    opt.seq_length = loader.seq_length
    opt.video = 1
    g_val_result_history = histories.get('g_val_result_history', {})
    g_loss_history = histories.get('g_loss_history', {})

    # ---- START TRAINING ----
    while g_epoch < opt.pre_nepoch:
        # Recomputed once per epoch.
        if update_lr_flag:
            # Assign the learning rate for generator (step decay).
            if g_epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (g_epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(gen_optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if g_epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (g_epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                gen_model.ss_prob = opt.ss_prob

            update_lr_flag = False

        # ---- TRAIN GENERATOR ----
        gen_model.train()
        start = time.time()
        gen_loss, wrapped, sent_num = train_generator(gen_model, gen_optimizer,
                                                      loader, opt.grad_clip)
        end = time.time()

        # Print Info
        if g_iter % opt.losses_print_every == 0:
            print("g_iter {} (g_epoch {}), gen_loss = {:.3f}, time/batch = {:.3f}"
                  .format(g_iter, g_epoch, gen_loss, end - start))

        # Log Losses
        if g_iter % opt.losses_log_every == 0:
            loss_history = {'g_loss': gen_loss, 'g_epoch': g_epoch}
            g_loss_history[g_iter] = loss_history
            log_metrics(writer, g_iter, loss_history)

        # Update the iteration
        g_iter += 1

        #########################
        # Evaluate & Save Model #
        #########################
        # Only at epoch boundaries.  The resume logic above assumes
        # ``split_size // batch_size`` iterations per epoch, so evaluating and
        # bumping ``g_epoch`` after every batch would be inconsistent with it.
        if wrapped:
            # evaluate model on dev set
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'sample_max': 1,
                'eval_accuracy': opt.eval_accuracy,
                'id': opt.val_id,
                'val_videos_use': opt.val_videos_use,
                'remove': 1
            }  # remove generated caption
            val_loss, predictions, accuracy = eval_split(
                gen_model, loader, eval_kwargs=eval_kwargs)
            if opt.eval_accuracy == 1:
                current_score = accuracy[
                    'Class Accuracy'] if 'Class Accuracy' in accuracy else accuracy[
                        'Instance Accuracy']
            else:
                current_score = -val_loss
            g_val_result_history[g_epoch] = {
                'g_val_loss': val_loss,
                'g_val_score': current_score
            }
            print('validation:', g_val_result_history[g_epoch])

            # Save the best generator model
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'gen_best.pth')
                torch.save(
                    gen_optimizer.state_dict(),
                    os.path.join(opt.checkpoint_path,
                                 'gen_optimizer_best.pth'))
                infos['g_epoch_best'] = g_epoch
                infos['g_iter_best'] = g_iter
                infos['g_best_score'] = best_val_score
                torch.save(gen_model.state_dict(), checkpoint_path)
                print("best fill in model saved to {} with score {}".format(
                    checkpoint_path, current_score))

            # Dump miscellaneous information and save
            infos['g_epoch_latest'] = g_epoch
            infos['g_iter_latest'] = g_iter
            infos['g_iterators'] = loader.iterators
            histories['g_val_result_history'] = g_val_result_history
            histories['g_loss_history'] = g_loss_history
            with open(os.path.join(opt.checkpoint_path, 'infos.pkl'),
                      'wb') as f:
                pickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories.pkl'),
                      'wb') as f:
                pickle.dump(histories, f)
            log_metrics(writer, g_iter, g_val_result_history[g_epoch])

            # save the latest model
            if opt.save_checkpoint_every > 0 and g_epoch % opt.save_checkpoint_every == 0:
                torch.save(
                    gen_model.state_dict(),
                    os.path.join(opt.checkpoint_path, 'gen_%d.pth' % g_epoch))
                torch.save(gen_model.state_dict(),
                           os.path.join(opt.checkpoint_path, 'gen_latest.pth'))
                torch.save(
                    gen_optimizer.state_dict(),
                    os.path.join(opt.checkpoint_path,
                                 'gen_optimizer_%d.pth' % g_epoch))
                torch.save(
                    gen_optimizer.state_dict(),
                    os.path.join(opt.checkpoint_path,
                                 'gen_optimizer_latest.pth'))
                print("fill in model saved to {} at epoch {}".format(
                    opt.checkpoint_path, g_epoch))

            # update epoch and lr
            g_epoch += 1
            update_lr_flag = True

    # Flush and release the tensorboard event file.
    if writer is not None:
        writer.close()
Beispiel #6
0
def train(opt):
    """Alternately train a story generator and a ``RewardModel`` discriminator.

    ``Flag(D_iters, G_iters, always)`` decides on every batch whether the
    discriminator phase (``flag.flag == "Disc"``) or the generator phase runs.

    * Disc phase: the generator samples stories in eval mode (sampled or
      greedy, per ``opt.decoding_method_DISC``).  The discriminator is
      updated with ``-sum(score(gt)) + sum(score(generated))``.
    * Generator phase: the reward is the discriminator score minus
      ``0.001 *`` the length-normalised sample log-probability.  The
      ``rl_crit`` loss is mixed with the teacher-forcing ``crit`` loss using
      ``opt.rl_weight``.

    Every ``opt.save_checkpoint_every`` iterations:

    * if ``opt.always`` is None, the generator is evaluated and checkpointed
      through ``logger``.  After 10 evaluations in a row that do not reach
      ``logger.best_val_score``, the learning rate is halved and
      ``model-best.pth`` is reloaded;
    * otherwise the discriminator is saved to ``disc-model.pth``.

    Args:
        opt: options namespace.  ``vocab_size``, ``seq_length`` and (on
            plateaus) ``learning_rate`` are written back to it.

    Raises:
        ValueError: if a discriminator step is reached with an unsupported
            ``opt.decoding_method_DISC``.
    """
    logger = Logger(opt)
    flag = Flag(D_iters=opt.D_iter, G_iters=opt.G_iter, always=opt.always)
    ################### set up dataset and dataloader ########################
    dataset = VISTDataset(opt)
    opt.vocab_size = dataset.get_vocab_size()
    opt.seq_length = dataset.get_story_length()

    dataset.set_option(data_type={
        'whole_story': False,
        'split_story': True,
        'caption': False
    })

    # The dataset object is switched between splits before each loader is built.
    dataset.train()
    train_loader = DataLoader(dataset,
                              batch_size=opt.batch_size,
                              shuffle=opt.shuffle,
                              num_workers=opt.workers)
    dataset.val()
    val_loader = DataLoader(dataset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            num_workers=opt.workers)

    ##################### set up model, criterion and optimizer ######
    # Number of consecutive evaluations that did not reach the best score.
    bad_valid = 0

    # set up evaluator
    evaluator = Evaluator(opt, 'val')

    # set up criterion
    crit = criterion.LanguageModelCriterion()
    rl_crit = criterion.ReinforceCriterion(opt, dataset)

    # set up generator and discriminator (reward model)
    model = models.setup(opt)
    model.cuda()
    disc_opt = copy.copy(opt)
    disc_opt.model = 'RewardModel'
    disc = models.setup(disc_opt)
    if os.path.exists(os.path.join(logger.log_dir, 'disc-model.pth')):
        logging.info("loading pretrained RewardModel")
        disc.load_state_dict(
            torch.load(os.path.join(logger.log_dir, 'disc-model.pth')))
    disc.cuda()

    # set up optimizer
    optimizer = setup_optimizer(opt, model)
    disc_optimizer = setup_optimizer(opt, disc)

    dataset.train()
    model.train()
    disc.train()
    ############################## training ##################################
    for epoch in range(logger.epoch_start, opt.max_epochs):
        start = time.time()
        for batch_idx, batch in enumerate(train_loader):
            logger.iteration += 1
            torch.cuda.synchronize()

            feature_fc = Variable(batch['feature_fc']).cuda()
            target = Variable(batch['split_story']).cuda()
            index = batch['index']

            optimizer.zero_grad()
            disc_optimizer.zero_grad()

            if flag.flag == "Disc":
                model.eval()
                disc.train()
                if opt.decoding_method_DISC == 'sample':
                    seq, seq_log_probs, baseline = model.sample(
                        feature_fc,
                        sample_max=False,
                        rl_training=True,
                        pad=True)
                elif opt.decoding_method_DISC == 'greedy':
                    seq, seq_log_probs, baseline = model.sample(
                        feature_fc,
                        sample_max=True,
                        rl_training=True,
                        pad=True)
                else:
                    # Without this, ``seq`` would be unbound or, worse, silently
                    # left over from the previous batch.
                    raise ValueError(
                        "unsupported decoding_method_DISC: {!r}".format(
                            opt.decoding_method_DISC))
            else:
                model.train()
                disc.eval()
                seq, seq_log_probs, baseline = model.sample(feature_fc,
                                                            sample_max=False,
                                                            rl_training=True,
                                                            pad=True)

            seq = Variable(seq).cuda()
            # Token mask shifted right by one with a leading 1, so the step
            # that emits the first padding token still counts.
            mask = (seq > 0).float()
            mask = to_contiguous(
                torch.cat([
                    Variable(
                        mask.data.new(mask.size(0), mask.size(1), 1).fill_(1)),
                    mask[:, :, :-1]
                ], 2))
            normed_seq_log_probs = (seq_log_probs *
                                    mask).sum(-1) / mask.sum(-1)

            gen_score = disc(seq.view(-1, seq.size(2)),
                             feature_fc.view(-1, feature_fc.size(2)))

            if flag.flag == "Disc":
                gt_score = disc(target.view(-1, target.size(2)),
                                feature_fc.view(-1, feature_fc.size(2)))
                # Push ground-truth scores up and generated scores down.
                loss = -torch.sum(gt_score) + torch.sum(gen_score)

                avg_pos_score = torch.mean(gt_score)
                avg_neg_score = torch.mean(gen_score)

                if logger.iteration % 5 == 0:
                    logging.info("pos reward {} neg reward {}".format(
                        avg_pos_score.data[0], avg_neg_score.data[0]))
                    print(
                        "PREDICTION: ",
                        utils.decode_story(dataset.get_vocab(),
                                           seq[:1].data)[0])
                    print(
                        "GROUND TRUTH: ",
                        utils.decode_story(dataset.get_vocab(),
                                           target[:1].data)[0])
            else:
                # Discriminator score minus a small length-normalised
                # log-probability term; .data detaches it from the graph.
                rewards = Variable(gen_score.data -
                                   0.001 * normed_seq_log_probs.data)
                loss, avg_score = rl_crit(seq.data, seq_log_probs, baseline,
                                          index, rewards)
                avg_pos_score = torch.mean(gen_score)
                logging.info("average reward: {} average IRL score: {}".format(
                    avg_score.data[0], avg_pos_score.data[0]))

            if flag.flag == "Disc":
                loss.backward()
                nn.utils.clip_grad_norm(disc.parameters(),
                                        opt.grad_clip,
                                        norm_type=2)
                disc_optimizer.step()
            else:
                tf_loss = crit(model(feature_fc, target), target)
                print("rl_loss / tf_loss = ", loss.data[0] / tf_loss.data[0])
                loss = opt.rl_weight * loss + (1 - opt.rl_weight) * tf_loss
                loss.backward()
                nn.utils.clip_grad_norm(model.parameters(),
                                        opt.grad_clip,
                                        norm_type=2)
                optimizer.step()

            train_loss = loss.data[0]
            torch.cuda.synchronize()

            # Write the training loss summary
            if logger.iteration % opt.losses_log_every == 0:
                logger.log_training(epoch, batch_idx, train_loss,
                                    opt.learning_rate, model.ss_prob)
                logging.info(
                    "Epoch {} Train {} - Iter {} / {}, loss = {:.5f}, time used = {:.3f}s"
                    .format(epoch, flag.flag, batch_idx, len(train_loader),
                            train_loss,
                            time.time() - start))
                start = time.time()

            if logger.iteration % opt.save_checkpoint_every == 0:
                if opt.always is None:
                    # Evaluate on validation dataset and save model
                    val_loss, predictions, metrics = evaluator.eval_story(
                        model, crit, dataset, val_loader, opt)
                    if opt.metric == 'XE':
                        score = -val_loss
                    else:
                        score = metrics[opt.metric]
                    logger.log_checkpoint(epoch, val_loss, metrics,
                                          predictions, opt, model, dataset,
                                          optimizer)
                    # halve the learning rate if not improving for a long time
                    if logger.best_val_score > score:
                        bad_valid += 1
                        # Report the streak before a possible reset below.
                        logging.info("bad valid : {}".format(bad_valid))
                        if bad_valid >= 10:
                            opt.learning_rate = opt.learning_rate / 2.0
                            logging.info("halve learning rate to {}".format(
                                opt.learning_rate))
                            checkpoint_path = os.path.join(
                                logger.log_dir, 'model-best.pth')
                            model.load_state_dict(torch.load(checkpoint_path))
                            utils.set_lr(
                                optimizer,
                                opt.learning_rate)  # set the decayed rate
                            bad_valid = 0
                    else:
                        logging.info("achieving best {} score: {}".format(
                            opt.metric, score))
                        bad_valid = 0
                else:
                    torch.save(disc.state_dict(),
                               os.path.join(logger.log_dir, 'disc-model.pth'))
            flag.inc()
def train(opt):
    """Jointly fine-tune a mesh Transformer captioner and a ``Seq2Seq`` model.

    Both models are initialised from hard-coded checkpoints
    (``./log_cvpr_mesh/all2model20000.pth`` and
    ``log_cvpr/all2model16000.pth``) and optimised by one Adam optimizer.
    Each step:

    * computes a teacher-forcing cross-entropy loss for the Transformer;
    * samples sequences from ``lang_model`` and re-scores them with the
      Transformer;
    * adds an ``rl_crit`` loss driven by ``get_self_critical_reward`` on the
      ratio of the two models' sample probabilities.

    Both models (and the optimizer) are saved every
    ``opt.save_checkpoint_every`` iterations.  Training stops once ``epoch``
    reaches ``opt.max_epochs``; a missing attribute or -1 means no limit.

    Args:
        opt: options namespace.  ``vocab_size``, ``seq_length`` and
            ``current_lr`` are written back to it.
    """
    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # NOTE(review): no tensorboard writer is created in this function;
    # ``tb_summary_writer`` used below must exist at module scope — confirm.

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}

    if opt.start_from is not None:
        # Pickles are binary: text mode breaks unpickling under Python 3.
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        histories_path = os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create models.  The hard-coded sizes (8668 / 8667 / 180) must match the
    # vocabulary and sequence length of the pretrained checkpoints.
    encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                     attention_module_kwargs={'m': 40})
    decoder = MeshedDecoder(8668, 180, 3, 0)
    models = Transformer(8667, encoder, decoder)
    model = models.cuda()
    lang_model = Seq2Seq().cuda()
    model.load_state_dict(torch.load('./log_cvpr_mesh/all2model20000.pth'))
    lang_model.load_state_dict(torch.load('log_cvpr/all2model16000.pth'), strict=False)
    optimizer = utils.build_optimizer_adam(list(models.parameters()) + list(lang_model.parameters()), opt)

    update_lr_flag = True
    max_epochs = getattr(opt, 'max_epochs', -1)

    while True:
        # Update learning rate / self-critical flag once per epoch
        if update_lr_flag:
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Scheduled sampling is intentionally not applied to this model.

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data (None entries are passed through unchanged)
        torch.cuda.synchronize()
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['dist'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, attmasks = tmp
        labels = labels.long()
        batchsize = fc_feats.size(0)
        labels_decode = labels.view(-1, 180)
        captions = utils.decode_sequence(loader.get_vocab(), labels_decode, None)
        captions_all = [
            caption.replace('<start>', '').replace(' ,', '').replace('  ', ' ')
            for caption in captions
        ]

        # Views of the label tensor: (batch, -1) and its transpose.
        labels_bt = labels.view(batchsize, -1)
        labels_tb = labels_bt.transpose(1, 0)

        # ---- Forward pass and loss ----
        model.train()
        optimizer.zero_grad()

        # Teacher-forcing cross-entropy: predict token t+1 from prefix <= t.
        wordact, _ = model(att_feats, labels_bt)
        wordact_t = wordact[:, :-1, :]
        wordact_t = wordact_t.contiguous().view(wordact_t.size(0) * wordact_t.size(1), -1)
        wordclass_v = labels_bt[:, 1:]
        wordclass_t = wordclass_v.contiguous().view(
            wordclass_v.size(0) * wordclass_v.size(1), -1)
        loss_xe = F.cross_entropy(wordact_t, wordclass_t.contiguous().view(-1))

        # Sample from the language model, then re-score the samples with the
        # captioner (first token forced to id 8667).
        outcap, sampled_ids, sample_logprobs = lang_model.sample(labels_tb, labels_tb, fc_feats, loader.get_vocab())
        sampled_ids[:, 0] = 8667
        logprobs_input, _ = model(att_feats, sampled_ids.long().cuda())
        log_probs = F.log_softmax(logprobs_input[:, :-1, :], -1)

        sample_logprobs_true = log_probs.gather(2, sampled_ids[:, 1:].cuda().long().unsqueeze(2))

        with torch.no_grad():
            reward, cider_sample, cider_greedy = get_self_critical_reward(batchsize, lang_model, labels_tb, fc_feats, outcap,
                                                                          captions_all, loader,
                                                                          180)

        print(np.mean(cider_greedy))
        # NOTE(review): ``rl_crit`` is not defined in this function and must
        # come from module scope — confirm.
        prob_ratio = torch.exp(sample_logprobs_true.squeeze()) / torch.exp(sample_logprobs[:, 1:]).cuda().detach()
        loss_rl1 = rl_crit(prob_ratio, sampled_ids[:, 1:].cpu(), torch.from_numpy(reward).float().cuda())

        train_loss = loss_xe + loss_rl1
        train_loss.backward()
        optimizer.step()
        total_time = time.time() - start
        # Plain float for logging/history: storing the tensor itself would
        # keep every step's autograd graph alive.
        train_loss_value = train_loss.item()

        if iteration % opt.print_freq == 1:
            print('Read data:', time.time() - start)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}"
                      .format(iteration, epoch, loss_xe.item(), data_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, np.mean(reward[:, 0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss_value, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss_value if not sc_flag else np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr

        # Save both models and the shared optimizer
        if iteration % opt.save_checkpoint_every == 0:
            checkpoint_path = os.path.join(opt.checkpoint_path, 'all2model{:05d}.pth'.format(iteration))
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'lang_model{:05d}.pth'.format(iteration))
            torch.save(lang_model.state_dict(), checkpoint_path)
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

        # Stop if reaching max epochs
        if max_epochs != -1 and epoch >= max_epochs:
            break
Beispiel #8
0
def train(opt):
    """Pre-train the ``MultiModalGenerator`` and checkpoint it each epoch.

    Creates ``opt.checkpoint_path`` (including missing parents), dumps the
    options to ``config.json``, builds the loader and the generator with its
    optimizer, and optionally resumes from ``opt.g_start_from``.  It then
    loops indefinitely.

    Each iteration makes one ``train_generator`` call.  When that call
    returns a truthy ``wrapped`` flag (presumably the end of a pass over the
    train split), the function:

    * evaluates the model on the val split (METEOR when
      ``opt.language_eval == 1``, otherwise ``-val_loss``);
    * saves the best model;
    * rewrites ``infos.pkl`` and ``histories.pkl``;
    * writes numbered/latest checkpoints every ``opt.save_checkpoint_every``
      epochs.

    Args:
        opt: options namespace.  It is mutated in place (``vocab_size``,
            ``sos_token``, ``seq_length``, ``current_lr``, ``ss_prob``, ...).
    """
    # makedirs also creates missing parent directories (os.mkdir would fail).
    if not os.path.exists(opt.checkpoint_path):
        os.makedirs(opt.checkpoint_path)

    with open(os.path.join(opt.checkpoint_path, 'config.json'), 'w') as f:
        json.dump(vars(opt), f)

    # Load iterators
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.sos_token = loader.get_sos_token()
    opt.seq_length = loader.seq_length
    opt.video = 1
    if opt.glove is not None:
        opt.glove_npy = loader.build_glove(opt.glove)
    else:
        opt.glove_npy = None

    # set up models
    gen_model = MultiModalGenerator(opt)
    gen_model = gen_model.cuda()
    gen_model.train()
    gen_optimizer = utils.build_optimizer(gen_model.parameters(), opt)

    # loss functions
    crit = utils.LanguageModelCriterion()
    gan_crit = nn.BCELoss().cuda()

    # keep track of iteration
    g_iter = 0
    g_epoch = 0
    dis_flag = False  # never set to True here, so only the generator trains
    update_lr_flag = True

    # Load from checkpoint path
    infos = {'opt': opt}
    histories = {}
    infos['vocab'] = loader.get_vocab()
    if opt.g_start_from is not None:
        # Open old infos and check if models are compatible.  Pickles are
        # binary (they are written with 'wb' below), so read with 'rb'.
        with open(os.path.join(opt.g_start_from, 'infos.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # Load train/val histories
        with open(os.path.join(opt.g_start_from, 'histories.pkl'), 'rb') as f:
            histories = cPickle.load(f)

        # Load generator.  ``opt.g_start_epoch`` is used as a string: an epoch
        # number, "best", or a value containing "latest".
        g_start_epoch = opt.g_start_epoch
        g_model_path = os.path.join(opt.g_start_from, "gen_%s.pth" % g_start_epoch)
        g_optimizer_path = os.path.join(opt.g_start_from, "gen_optimizer_%s.pth" % g_start_epoch)
        assert os.path.isfile(g_model_path) and os.path.isfile(g_optimizer_path)
        gen_model.load_state_dict(torch.load(g_model_path))
        gen_optimizer.load_state_dict(torch.load(g_optimizer_path))
        if "latest" not in g_start_epoch and "best" != g_start_epoch:
            g_epoch = int(g_start_epoch) + 1
            g_iter = (g_epoch) * loader.split_size['train'] // opt.batch_size
        elif g_start_epoch == "best":
            g_epoch = infos['g_epoch_' + g_start_epoch] + 1
            g_iter = (g_epoch) * loader.split_size['train'] // opt.batch_size
        else:
            g_epoch = infos['g_epoch_' + g_start_epoch] + 1
            g_iter = infos['g_iter_' + g_start_epoch]
        print('loaded %s (epoch: %d iter: %d)' % (g_model_path, g_epoch, g_iter))

    infos['opt'] = opt
    loader.iterators = infos.get('g_iterators', loader.iterators)

    # misc
    best_val_score = infos.get('g_best_score', None)
    opt.seq_length = loader.seq_length
    opt.video = 1
    g_val_result_history = histories.get('g_val_result_history', {})
    g_loss_history = histories.get('g_loss_history', {})

    # ---- START TRAINING ----
    while True:
        gc.collect()
        # Recomputed once per epoch.
        if update_lr_flag:
            # Assign the learning rate for generator (step decay).
            if g_epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (g_epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(gen_optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if g_epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (g_epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                gen_model.ss_prob = opt.ss_prob

            # Start using previous sentence as context for generator
            if opt.g_context_epoch >= 0 and g_epoch >= opt.g_context_epoch:
                gen_model.use_context()

            update_lr_flag = False

        # ---- TRAIN GENERATOR ----
        if not dis_flag:
            gen_model.train()

            start = time.time()
            gen_loss, wrapped, sent_num = train_generator(gen_model, gen_optimizer, crit, loader)
            end = time.time()

            # Print Info
            if g_iter % opt.losses_print_every == 0:
                print("g_iter {} (g_epoch {}), gen_loss = {:.3f}, time/batch = {:.3f}, num_sent = {} {}"
                      .format(g_iter, g_epoch, gen_loss, end - start, sum(sent_num), sent_num))

            # Log Losses
            if g_iter % opt.losses_log_every == 0:
                g_loss_history[g_iter] = {'g_loss': gen_loss, 'g_epoch': g_epoch}

            # Update the iteration
            g_iter += 1

            #########################
            # Evaluate & Save Model #
            #########################
            if wrapped:
                # evaluate model on dev set
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json,
                               'sample_max': 1,
                               'language_eval': opt.language_eval,
                               'id': opt.val_id,
                               'val_videos_use': opt.val_videos_use,
                               'remove': 1}  # remove generated caption

                val_loss, predictions, lang_stats, _ = eval_split(gen_model, crit, loader, eval_kwargs=eval_kwargs)

                if opt.language_eval == 1:
                    current_score = lang_stats['METEOR']
                else:
                    current_score = - val_loss
                g_val_result_history[g_epoch] = {'g_loss': val_loss, 'g_score': current_score, 'lang_stats': lang_stats}

                # Save the best generator model
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'gen_best.pth')
                    torch.save(gen_optimizer.state_dict(), os.path.join(opt.checkpoint_path, 'gen_optimizer_best.pth'))
                    infos['g_epoch_best'] = g_epoch
                    infos['g_iter_best'] = g_iter
                    infos['g_best_score'] = best_val_score
                    torch.save(gen_model.state_dict(), checkpoint_path)
                    print("best generator saved to {}".format(checkpoint_path))

                # Dump miscellaneous information and save
                infos['g_epoch_latest'] = g_epoch
                infos['g_iter_latest'] = g_iter
                infos['g_iterators'] = loader.iterators
                histories['g_val_result_history'] = g_val_result_history
                histories['g_loss_history'] = g_loss_history
                with open(os.path.join(opt.checkpoint_path, 'infos.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                # save the latest model
                if opt.save_checkpoint_every > 0 and g_epoch % opt.save_checkpoint_every == 0:
                    torch.save(gen_model.state_dict(), os.path.join(opt.checkpoint_path, 'gen_%d.pth' % g_epoch))
                    torch.save(gen_model.state_dict(), os.path.join(opt.checkpoint_path, 'gen_latest.pth'))
                    torch.save(gen_optimizer.state_dict(), os.path.join(opt.checkpoint_path, 'gen_optimizer_%d.pth' % g_epoch))
                    torch.save(gen_optimizer.state_dict(), os.path.join(opt.checkpoint_path, 'gen_optimizer_latest.pth'))
                    print("generator model saved to {} at epoch {}".format(opt.checkpoint_path, g_epoch))

                # update epoch and lr
                g_epoch += 1
                update_lr_flag = True
Beispiel #9
0
def train(rank, model, opt, optimizer=None):
    """Train a caption model in a single worker identified by ``rank``.

    Loops over the 'train' split, evaluating on 'val' and writing per-rank
    checkpoints every ``opt.save_checkpoint_every`` iterations. The loop ends
    when ``epoch >= opt.max_epochs`` (``-1`` disables the limit); if the
    validation score has not improved for ``opt.num_eval_no_improve``
    consecutive evaluations the process exits via ``sys.exit()``.

    Args:
        rank: worker index. Offsets the random seed and is embedded in every
            file name written to ``opt.checkpoint_path``.
        model: caption model to train. Its ``ss_prob`` attribute is read for
            logging and overwritten when scheduled sampling is active.
        opt: options namespace. Mutated in place (``vocab_size``,
            ``seq_length``, ``current_lr`` and possibly ``ss_prob``).
        optimizer: optional pre-built optimizer. When ``None``, one is built
            from ``opt.optim``. If ``opt.start_from`` is set, its state is
            also restored from ``optimizer_<load_model_id>.pth``.

    Raises:
        Exception: if ``opt.caption_model`` or ``opt.optim`` is unsupported.
    """
    torch.manual_seed(opt.seed + rank)
    if opt.use_cuda:
        torch.cuda.manual_seed(opt.seed + rank)

    def to_variable(array):
        # Wrap a numpy array as a non-trainable Variable, moved to the GPU
        # when CUDA is enabled.
        var = Variable(torch.from_numpy(array), requires_grad=False)
        return var.cuda() if opt.use_cuda else var

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = {}
    if opt.start_from is not None:
        # Open the old infos and check that the saved model is compatible.
        infos_path = os.path.join(opt.start_from,
                                  'infos_' + opt.load_model_id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(
                opt
            )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

    # Resume counters/histories from infos (empty dict -> fresh start).
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    ss_prob_history = infos.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_image_id = infos.get('split_image_id', loader.split_image_id)
    best_val_score = 0
    if opt.load_best_score == 1:
        # May be None; the comparison below treats None as "no best yet".
        best_val_score = infos.get('best_val_score', None)

    update_lr_flag = True
    if opt.caption_model == 'show_tell':
        crit = utils.LanguageModelCriterion(opt)
    elif opt.caption_model == 'review_net':
        crit = utils.ReviewNetCriterion(opt)
    elif opt.caption_model == 'recurrent_fusion_model':
        crit = utils.ReviewNetEnsembleCriterion(opt)
    else:
        raise Exception("caption_model not supported: {}".format(
            opt.caption_model))

    if optimizer is None:
        if opt.optim == 'adam':
            optimizer = optim.Adam(model.parameters(),
                                   lr=opt.optim_lr,
                                   betas=(opt.optim_adam_beta1,
                                          opt.optim_adam_beta2),
                                   weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'rmsprop':
            # Same weight-decay option as every other optimizer branch.
            optimizer = optim.RMSprop(model.parameters(),
                                      lr=opt.optim_lr,
                                      momentum=opt.optim_momentum,
                                      alpha=opt.optim_rmsprop_alpha,
                                      weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'sgd':
            optimizer = optim.SGD(model.parameters(),
                                  lr=opt.optim_lr,
                                  momentum=opt.optim_momentum,
                                  weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'adagrad':
            optimizer = optim.Adagrad(model.parameters(),
                                      lr=opt.optim_lr,
                                      lr_decay=opt.optim_lr_decay,
                                      weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'adadelta':
            optimizer = optim.Adadelta(model.parameters(),
                                       rho=opt.optim_rho,
                                       eps=opt.optim_epsilon,
                                       lr=opt.optim_lr,
                                       weight_decay=opt.optim_weight_decay)
        else:
            # Report the option that was actually rejected.
            raise Exception("optim not supported: {}".format(opt.optim))

        # Restore the optimizer state when resuming.
        if vars(opt).get('start_from', None) is not None:
            optimizer.load_state_dict(
                torch.load(
                    os.path.join(opt.start_from,
                                 'optimizer_' + opt.load_model_id + '.pth')))

    num_period_best = 0
    current_score = 0
    while True:
        if update_lr_flag:
            # Step-decay the learning rate once past learning_rate_decay_start.
            if epoch > opt.learning_rate_decay_start >= 0:
                frac = ((epoch - opt.learning_rate_decay_start)
                        // opt.learning_rate_decay_every)
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.optim_lr * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.optim_lr
            # Raise the scheduled-sampling probability, capped at the max.
            if epoch > opt.scheduled_sampling_start >= 0:
                frac = ((epoch - opt.scheduled_sampling_start)
                        // opt.scheduled_sampling_increase_every)
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from the train split.
        data = loader.get_batch('train')

        if opt.use_cuda:
            torch.cuda.synchronize()

        if opt.feature_type == 'feat_array':
            # Parallel lists of feature arrays (equal length asserted);
            # elements are converted in place.
            fc_feat_array = data['fc_feats_array']
            att_feat_array = data['att_feats_array']
            assert (len(fc_feat_array) == len(att_feat_array))
            for feat_id in range(len(fc_feat_array)):
                fc_feat_array[feat_id] = to_variable(fc_feat_array[feat_id])
                att_feat_array[feat_id] = to_variable(att_feat_array[feat_id])
            labels, masks, top_words = [
                to_variable(data[key])
                for key in ('labels', 'masks', 'top_words')
            ]
        else:
            fc_feats, att_feats, labels, masks, top_words = [
                to_variable(data[key])
                for key in ('fc_feats', 'att_feats', 'labels', 'masks',
                            'top_words')
            ]

        optimizer.zero_grad()

        # labels/masks are shifted by one step: position 0 is not scored.
        if opt.caption_model == 'show_tell':
            log_prob = model(fc_feats, att_feats, labels)
            loss = crit(log_prob, labels[:, 1:], masks[:, 1:])
        elif opt.caption_model == 'review_net':
            log_prob, top_pred = model(fc_feats, att_feats, labels)
            loss = crit(log_prob, labels[:, 1:], masks[:, 1:], top_pred,
                        top_words, opt.reason_weight)
        elif opt.caption_model == 'recurrent_fusion_model':
            log_prob, top_pred = model(fc_feat_array, att_feat_array, labels)
            loss = crit(log_prob, labels[:, 1:], masks[:, 1:], top_pred,
                        top_words, opt.reason_weight)
        else:
            raise Exception("caption_model not supported: {}".format(
                opt.caption_model))

        loss.backward()

        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # .item() works for 0-dim tensors, where loss.data[0] raises.
        train_loss = loss.item()
        if opt.use_cuda:
            torch.cuda.synchronize()
        end = time.time()

        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary.
        if iteration % opt.losses_log_every == 0:
            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model.
        if iteration % opt.save_checkpoint_every == 0:
            eval_kwargs = {
                'eval_split': 'val',
                'dataset': opt.input_json,
                'caption_model': opt.caption_model,
                'reason_weight': opt.reason_weight,
                'guiding_l1_penality': opt.guiding_l1_penality,
                'use_cuda': opt.use_cuda,
                'feature_type': opt.feature_type,
                'rank': rank,
                'val_images_use': opt.val_images_use,
                'language_eval': 1
            }
            # Options override the defaults above, except the split.
            eval_kwargs.update(vars(opt))
            eval_kwargs['eval_split'] = 'val'
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Higher is better: CIDEr, or negative validation loss.
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
                num_period_best = 1
            else:
                num_period_best = num_period_best + 1

            # Dump miscellaneous information.
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_image_id'] = loader.split_image_id
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['val_result_history'] = val_result_history
            infos['loss_history'] = loss_history
            infos['lr_history'] = lr_history
            infos['ss_prob_history'] = ss_prob_history
            infos['vocab'] = loader.get_vocab()
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '_' + str(rank) + '.pkl'),
                    'wb') as f:
                cPickle.dump(infos, f)

            if best_flag:
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'model_' + opt.id + '_' + str(rank) + '-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                optimizer_path = os.path.join(
                    opt.checkpoint_path,
                    'optimizer_' + opt.id + '_' + str(rank) + '-best.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(
                            opt.checkpoint_path,
                            'infos_' + opt.id + '_' + str(rank) + '-best.pkl'),
                        'wb') as f:
                    cPickle.dump(infos, f)

            # Early stopping: terminates the whole process.
            if num_period_best >= opt.num_eval_no_improve:
                print('no improvement, exit')
                sys.exit()

        print(
            "rank {}, iter {}, (epoch {}), train loss: {}, learning rate: {}, current cider: {:.3f}, best cider: {:.3f}, time: {:.3f}"
            .format(rank, iteration, epoch, train_loss, opt.current_lr,
                    current_score, best_val_score, (end - start)))
        iteration += 1
        # Stop if reaching max epochs.
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Beispiel #10
0
def train(opt):
    """Train a caption model with gradient accumulation and optional SCST.

    Sets up ``opt.checkpoint_path_p``, reusing ``opt.start_from_en`` when it
    is given and otherwise creating a timestamped folder. It can also copy
    options, vocabulary and histories from pretrained files. It then trains
    on ``opt.train_split``, accumulating gradients over
    ``opt.accumulate_number`` batches per optimizer step. Once
    ``epoch >= opt.self_critical_after`` (and that option is not -1), the
    loss switches from cross-entropy to the self-critical reward loss.

    Every ``opt.save_checkpoint_every`` optimizer steps, the model is
    evaluated on 'val'. ``model.pth``, ``optimizer.pth``, ``infos.pkl`` and
    ``histories.pkl`` are then written to ``opt.checkpoint_path_p``.
    ``model-best.pth`` and ``infos-best.pkl`` are also written whenever the
    validation score improves. The score is CIDEr if
    ``opt.language_eval == 1``, otherwise the negative validation loss.
    Returns once ``epoch >= opt.max_epochs`` (``-1`` disables the limit).

    Args:
        opt: options namespace; mutated in place (checkpoint path/id,
            vocabulary info, ``current_lr``, ``ss_prob``, ...).
    """
    # Resolve the checkpoint folder: reuse an existing run or create a new
    # timestamped one.
    if vars(opt).get('start_from_en', None) is not None:
        opt.checkpoint_path_p = opt.start_from_en
        opt.id_p = opt.checkpoint_path_p.split('/')[-1]
        print('Point to folder: {}'.format(opt.checkpoint_path_p))
    else:
        opt.id_p = datetime.datetime.now().strftime(
            '%Y%m%d_%H%M%S') + '_' + opt.caption_model
        opt.checkpoint_path_p = os.path.join(opt.checkpoint_path_p, opt.id_p)

        if not os.path.exists(opt.checkpoint_path_p):
            os.makedirs(opt.checkpoint_path_p)
        print('Create folder: {}'.format(opt.checkpoint_path_p))

    try:
        # ``tf and ...`` short-circuits when the module-level ``tf`` is falsy
        # (presumably None when TensorFlow is unavailable).
        tb_summary_writer = tf and tf.compat.v1.summary.FileWriter(
            opt.checkpoint_path_p)
    except Exception:
        # NOTE(review): drops into the debugger; tb_summary_writer stays
        # unbound if execution is simply continued from pdb.
        print('Set tensorboard error!')
        pdb.set_trace()

    infos = {}
    histories = {}
    if (vars(opt).get('start_from_en', None) is not None
            or opt.use_pretrained_setting == 1):
        # NOTE(review): the infos path is hard-coded rather than derived from
        # opt.checkpoint_path_p; confirm this is intended.
        opt.infos_path = 'data/fc/infos.pkl'
        # Pickles are binary: they must be read in 'rb' mode.
        with open(opt.infos_path, 'rb') as f:
            infos = cPickle.load(f)

        # Fill data-related options left empty on the command line.
        if len(opt.input_fc_dir) == 0:
            opt.input_fc_dir = infos['opt'].input_fc_dir
            opt.input_att_dir = infos['opt'].input_att_dir
            opt.input_box_dir = infos['opt'].input_box_dir
        if len(opt.input_json) == 0:
            opt.input_json = infos['opt'].input_json
        if opt.batch_size == 0:
            opt.batch_size = infos['opt'].batch_size
        if len(opt.id) == 0:
            opt.id = infos['opt'].id

        # Copy every option that already exists on ``opt`` from the
        # pretrained options, except 'model' and those in ``ignore``.
        ignore = [
            'checkpoint_path', "use_gfc", "use_isg", "ssg_dict_path",
            "input_json", "input_label_h5", "id", "batch_size", "start_from",
            "language_eval", "use_rela", "input_ssg_dir", "ssg_dict_path",
            "input_rela_dir", "use_spectral_norm", "beam_size", 'gpu',
            'caption_model', 'self_critical_after', 'save_checkpoint_every'
        ]
        saved_opt = vars(infos['opt'])
        current_opt = vars(opt)
        for k in saved_opt:
            if k == 'model' or k in ignore or k not in current_opt:
                continue
            if not current_opt[k] == saved_opt[k]:
                print(k +
                      ' option not consistent, copyed from pretrained model')
            current_opt[k] = saved_opt[k]

        vocab = infos['vocab']  # ix -> word mapping
        opt.vocab = vocab
        opt.vocab_size = len(vocab)
        # NOTE(review): unconditionally overrides input_fc_dir set above.
        opt.input_fc_dir = 'data/cocobu_fc'

        histories_path = os.path.join(opt.checkpoint_path_p, 'histories.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    # Create the Data Loader instance.
    loader = DataLoader_UP(opt)
    if opt.use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length
    use_rela = getattr(opt, 'use_rela', 0)
    # A pretrained model's vocab may differ from the one in input_json, so
    # prefer the vocab stored in infos when it is available.
    try:
        loader.ix_to_word = infos['vocab']
    except KeyError:  # training from scratch: infos has no vocab
        with open(opt.input_json) as f:
            infos = json.load(f)
        opt.ix_to_word = infos['ix_to_word']
        opt.vocab_size = len(opt.ix_to_word)

    # Counters always restart at zero; resuming them from infos is disabled.
    iteration = 0
    epoch = 0

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score; None means "no best score yet".
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Setup the model (caption_model_zh takes precedence when present).
    opt.caption_model = getattr(opt, 'caption_model_zh', opt.caption_model)
    model = models.setup(opt).cuda()
    dp_model = model

    update_lr_flag = True
    # Assure in training mode.
    dp_model.train()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), opt)

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])

    isg_keys = ('isg_rela_matrix', 'isg_rela_masks', 'isg_obj',
                'isg_obj_masks', 'isg_attr', 'isg_attr_masks')
    ssg_keys = ('ssg_rela_matrix', 'ssg_rela_masks', 'ssg_obj',
                'ssg_obj_masks', 'ssg_attr', 'ssg_attr_masks')

    def to_cuda(array):
        # Move a numpy array to the GPU; None entries pass through unchanged.
        return array if array is None else torch.from_numpy(array).cuda()

    while True:
        if update_lr_flag:
            # Step-decay the learning rate once past learning_rate_decay_start.
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = ((epoch - opt.learning_rate_decay_start)
                        // opt.learning_rate_decay_every)
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Raise the scheduled-sampling probability, capped at the max.
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = ((epoch - opt.scheduled_sampling_start)
                        // opt.scheduled_sampling_increase_every)
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # Switch to self-critical training once the epoch threshold hits.
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from the train split.
        data = loader.get_batch(opt.train_split)

        torch.cuda.synchronize()
        start = time.time()

        fc_feats = None
        att_feats = None
        att_masks = None
        ssg_data = None
        isg_data = None  # read by the self-critical branch on every path

        if getattr(opt, 'use_ssg', 0) == 1:
            fc_feats = to_cuda(data['fc_feats'])
            labels = to_cuda(data['labels'])
            masks = to_cuda(data['masks'])
            if getattr(opt, 'use_isg', 0) == 1:
                att_feats = to_cuda(data['att_feats'])
                att_masks = to_cuda(data['att_masks'])
                # Image graph domain.
                isg_data = {'att_feats': att_feats, 'att_masks': att_masks}
                isg_data.update((k, to_cuda(data[k])) for k in isg_keys)
            # Text graph domain.
            ssg_data = {k: to_cuda(data[k]) for k in ssg_keys}
        else:
            fc_feats, att_feats, labels, masks, att_masks = [
                to_cuda(data[k]) for k in ('fc_feats', 'att_feats', 'labels',
                                           'masks', 'att_masks')
            ]

        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            # NOTE(review): this call uses a different model signature than
            # the cross-entropy branch above; confirm it matches the model.
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   isg_data,
                                                   ssg_data,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, isg_data,
                                              ssg_data, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # Gradient accumulation: scale so the summed gradient is an average.
        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        stepped = accumulate_iter % opt.accumulate_number == 0
        if stepped:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            end = time.time()

            if not sc_flag:
                print("{}/{}/{}|train_loss={:.3f}|time/batch={:.3f}" \
                      .format(opt.id_p, iteration, epoch, train_loss, end - start))
            else:
                print("{}/{}/{}|avg_reward={:.3f}|time/batch={:.3f}" \
                      .format(opt.id_p, iteration, epoch, np.mean(reward[:, 0]), end - start))

        torch.cuda.synchronize()

        # Update the epoch when the loader wraps around.
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Log and checkpoint only right after an optimizer step; otherwise
        # the same `iteration` would be re-logged/re-evaluated on every
        # accumulation sub-batch.
        if stepped and (iteration % opt.losses_log_every == 0) and (iteration != 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model.
        if stepped and (iteration % opt.save_checkpoint_every == 0) and (iteration != 0):
            if use_rela:
                # NOTE(review): key is 'use_real' (not 'use_rela'); confirm.
                eval_kwargs = {
                    'split': 'val',
                    'dataset': opt.input_json,
                    'use_real': 1
                }
            else:
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split_fc(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary.
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Higher is better: CIDEr, or negative validation loss.
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # Always save the latest model and optimizer.
            checkpoint_path_p = os.path.join(opt.checkpoint_path_p,
                                             'model.pth')
            torch.save(model.state_dict(), checkpoint_path_p)
            print("model saved to {}".format(checkpoint_path_p))
            optimizer_path = os.path.join(opt.checkpoint_path_p,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information.
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path_p, 'infos.pkl'),
                      'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path_p, 'histories.pkl'),
                      'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path_p = os.path.join(opt.checkpoint_path_p,
                                                 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path_p)
                print("model saved to {}".format(checkpoint_path_p))
                with open(
                        os.path.join(opt.checkpoint_path_p,
                                     'infos-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs.
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Beispiel #11
0
def train(opt):
    """Train a captioning model described by ``opt``, checkpointing as it goes.

    Prints a summary of the configuration, derives the checkpoint directory
    name from batch size / warmup / GPU count, optionally resumes from
    ``opt.start_from`` (infos, histories and optimizer state), then runs an
    epoch loop that:

    * recomputes the step-decayed learning rate, the scheduled-sampling
      probability and the self-critical flag at the start of every epoch;
    * performs (optionally gradient-accumulated, ``opt.acc_steps``) updates
      through ``LossWrapper1`` wrapped in ``DataParallel``;
    * writes TensorBoard summaries when ``opt.write_summary`` is set;
    * every ``opt.save_checkpoint_every`` iterations (once ``epoch > 3`` and
      ``eval_`` is truthy) evaluates the model and saves
      model/optimizer/infos/histories checkpoints, plus a ``*_best`` copy when
      the validation score improves.

    Training stops once ``epoch >= opt.max_epochs`` (unless ``max_epochs`` is
    -1). A ``RuntimeError`` or ``KeyboardInterrupt`` raised inside the loop is
    caught: an ``*_interrupt`` checkpoint is written and the traceback is
    printed instead of being re-raised.

    Args:
        opt: argparse-style namespace of options. It is mutated in place
            (e.g. ``checkpoint_path``, ``vocab_size``, ``current_lr``).
    """
    print("=================Training Information==============")
    print("start from {}".format(opt.start_from))
    print("box from {}".format(opt.input_box_dir))
    print("attributes from {}".format(opt.input_att_dir))
    print("features from {}".format(opt.input_fc_dir))
    print("batch size ={}".format(opt.batch_size))
    print("#GPU={}".format(torch.cuda.device_count()))
    print("Caption model {}".format(opt.caption_model))
    print("refine aoa {}".format(opt.refine_aoa))
    print("Number of aoa module {}".format(opt.aoa_num))
    print("Self Critic After  {}".format(opt.self_critical_after))
    print("learning_rate_decay_every {}".format(opt.learning_rate_decay_every))

    # Optionally fine-tune on extra data (val and/or test) for challenge
    # submissions; not used in the regular setting.
    if opt.use_val or opt.use_test:
        print("+++++++++++It is a refining training+++++++++++++++")
        print("===========Val is {} used for training ===========".format(
            '' if opt.use_val else 'not'))
        print("===========Test is {} used for training ===========".format(
            '' if opt.use_test else 'not'))
    print("=====================================================")

    # Encode batch size / warmup / multi-GPU into the checkpoint directory
    # name; an existing '_rl' suffix is kept at the very end.
    checkpoint_path_suffix = "_bs{}".format(opt.batch_size)
    if opt.use_warmup:
        checkpoint_path_suffix += "_warmup"
    if torch.cuda.device_count() > 1:
        checkpoint_path_suffix += "_gpu{}".format(torch.cuda.device_count())

    if opt.checkpoint_path.endswith('_rl'):
        opt.checkpoint_path = (opt.checkpoint_path[:-3] +
                               checkpoint_path_suffix + '_rl')
    else:
        opt.checkpoint_path += checkpoint_path_suffix
    print("Save model to {}".format(opt.checkpoint_path))

    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    acc_steps = getattr(opt, 'acc_steps', 1)
    name_append = opt.name_append
    if len(name_append) > 0 and name_append[0] != '-':
        name_append = '_' + name_append

    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    # Number of batches in the train split; clamped to 1 so the modulo in the
    # logging check below can never divide by zero on a tiny split.
    opt.losses_log_every = max(
        1, len(loader.split_ix['train']) // opt.batch_size)
    print("Evaluate on each {} iterations".format(opt.losses_log_every))
    # Bound unconditionally so the final close() cannot raise NameError when
    # summaries are disabled.
    tb_summary_writer = None
    if opt.write_summary:
        print("write summary to {}".format(opt.checkpoint_path))
        tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}

    # load checkpoint
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos_path = os.path.join(opt.start_from,
                                  'infos' + name_append + '.pkl')
        print("Load model information {}".format(infos_path))
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]

        # this sanity check may not work well, and comment it if necessary
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories' + name_append + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:  # start from scratch
        print("==============================================")
        print("Initialize training process from all begining")
        print("==============================================")
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()

    infos['opt'] = opt
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    print("==================start from {} iterations -- {} epoch".format(
        iteration, epoch))
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    # Resume the first pass from the saved position in the train split.
    start_Img_idx = loader.iterators['train']
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Defaults so the epoch-end print and the best-score comparison never hit
    # a NameError when the best score is not restored from the checkpoint.
    best_val_score = None
    best_epoch = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_epoch = infos.get('best_epoch', None)
        print("========best history val cider score: {} in epoch {}=======".
              format(best_val_score, best_epoch))

    # Sanity check: a numeric name_append (< 100) is treated as a saved epoch
    # index and must be exactly one ahead of the restored epoch counter.
    if opt.name_append.isdigit() and int(opt.name_append) < 100:
        assert int(opt.name_append) - epoch == 1, \
            "dismatch in the model index and the real epoch number"
        epoch += 1
    # opt.vocab is only attached for models.setup and removed right after.
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab

    if torch.cuda.device_count() > 1:
        dp_model = torch.nn.DataParallel(model)
    else:
        dp_model = model
    lw_model = LossWrapper1(model, opt)  # wrap loss into model
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'aoa'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from,
                                      'optimizer' + name_append + '.pth')
        if os.path.isfile(optimizer_path):
            print("Loading optimizer............")
            optimizer.load_state_dict(torch.load(optimizer_path))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        """Write model/optimizer/infos (and, if given, histories) files.

        Files are written to ``opt.checkpoint_path`` (created if missing) as
        ``model{_append}.pth``, ``optimizer{_append}.pth``,
        ``infos{_append}.pkl`` and ``histories{_append}.pkl``.
        """
        if len(append) > 0:
            append = '_' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("Save model state to {}".format(checkpoint_path))

        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        print("Save model optimizer to {}".format(optimizer_path))

        infos_path = os.path.join(opt.checkpoint_path,
                                  'infos%s.pkl' % (append))
        with open(infos_path, 'wb') as f:
            utils.pickle_dump(infos, f)
            print("Save training information to {}".format(infos_path))

        if histories:
            histories_path = os.path.join(opt.checkpoint_path,
                                          'histories%s.pkl' % (append))
            with open(histories_path, 'wb') as f:
                utils.pickle_dump(histories, f)
                print("Save training historyes to {}".format(histories_path))

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Step decay: multiply by decay_rate once per
                    # decay_every epochs after decay_start.
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor * opt.refine_lr_decay
                    else:
                        opt.current_lr = opt.learning_rate
                    infos['current_lr'] = opt.current_lr
                    print("Current Learning Rate is: {}".format(
                        opt.current_lr))
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False
            print("{}th Epoch Training starts now!".format(epoch))
            with tqdm(total=len(loader.split_ix['train']),
                      initial=start_Img_idx) as pbar:
                for _ in range(start_Img_idx, len(loader.split_ix['train']),
                               opt.batch_size):
                    start = time.time()
                    # Linear learning-rate warmup over the first
                    # noamopt_warmup iterations.
                    if opt.use_warmup == 1 and iteration < opt.noamopt_warmup:
                        opt.current_lr = opt.learning_rate * (
                            iteration + 1) / opt.noamopt_warmup
                        utils.set_lr(optimizer, opt.current_lr)
                    # Load data from train split (0)
                    data = loader.get_batch('train')

                    # Gradient accumulation: clear grads only at the start of
                    # every group of acc_steps iterations.
                    if iteration % acc_steps == 0:
                        optimizer.zero_grad()

                    torch.cuda.synchronize()
                    start = time.time()
                    tmp = [
                        data['fc_feats'], data['att_feats'],
                        data['flag_feats'], data['labels'], data['masks'],
                        data['att_masks']
                    ]
                    tmp = [_ if _ is None else _.cuda() for _ in tmp]
                    fc_feats, att_feats, flag_feats, labels, masks, att_masks = tmp

                    model_out = dp_lw_model(fc_feats, att_feats, flag_feats,
                                            labels, masks, att_masks,
                                            data['gts'],
                                            torch.arange(0, len(data['gts'])),
                                            sc_flag)

                    loss = model_out['loss'].mean()
                    # Scale so accumulated gradients average over acc_steps.
                    loss_sp = loss / acc_steps

                    loss_sp.backward()
                    if (iteration + 1) % acc_steps == 0:
                        utils.clip_gradient(optimizer, opt.grad_clip)
                        optimizer.step()
                    torch.cuda.synchronize()
                    train_loss = loss.item()
                    end = time.time()
                    if not sc_flag:
                        pbar.set_description(
                            "iter {:8} (epoch {:2}), train_loss = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch, train_loss, end - start))
                    else:
                        pbar.set_description(
                            "iter {:8} (epoch {:2}), avg_reward = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch,
                                    model_out['reward'].mean(), end - start))

                    # Update the iteration and epoch
                    iteration += 1
                    pbar.update(opt.batch_size)
                    if data['bounds']['wrapped']:
                        epoch += 1
                        epoch_done = True
                        # save after each epoch
                        save_checkpoint(model, infos, optimizer)
                        if epoch > 15:  # To save memory, you can comment this part
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append=str(epoch))
                        print(
                            "====================================================="
                        )
                        print(
                            "======Best Cider = {} in epoch {}: iter {}!======"
                            .format(best_val_score, best_epoch,
                                    infos.get('best_itr', None)))
                        print(
                            "====================================================="
                        )

                    # Write training history into summary
                    if (iteration % opt.losses_log_every
                            == 0) and opt.write_summary:
                        add_summary_value(tb_summary_writer, 'loss/train_loss',
                                          train_loss, iteration)
                        if opt.noamopt:
                            opt.current_lr = optimizer.rate()
                        elif opt.reduce_on_plateau:
                            opt.current_lr = optimizer.current_lr
                        add_summary_value(tb_summary_writer,
                                          'hyperparam/learning_rate',
                                          opt.current_lr, iteration)
                        add_summary_value(
                            tb_summary_writer,
                            'hyperparam/scheduled_sampling_prob',
                            model.ss_prob, iteration)
                        if sc_flag:
                            add_summary_value(tb_summary_writer, 'avg_reward',
                                              model_out['reward'].mean(),
                                              iteration)

                        loss_history[
                            iteration] = train_loss if not sc_flag else model_out[
                                'reward'].mean()
                        lr_history[iteration] = opt.current_lr
                        ss_prob_history[iteration] = model.ss_prob

                    # update infos
                    infos['iter'] = iteration
                    infos['epoch'] = epoch
                    infos['iterators'] = loader.iterators
                    infos['split_ix'] = loader.split_ix

                    # make evaluation on validation set, and save model
                    # unnecessary to eval from the beginning.
                    # NOTE(review): eval_ is not defined in this function --
                    # presumably a module-level flag; TODO confirm.
                    if (iteration % opt.save_checkpoint_every
                            == 0) and eval_ and epoch > 3:
                        # eval model
                        model_path = os.path.join(
                            opt.checkpoint_path,
                            'model_itr%s.pth' % (iteration))
                        if opt.use_val and not opt.use_test:
                            val_split = 'test'
                        if not opt.use_val:
                            val_split = 'val'
                        # NOTE(review): when both use_val and use_test are
                        # set, neither branch assigns val_split and the dict
                        # below raises UnboundLocalError -- confirm the
                        # intended evaluation split for that setting.

                        eval_kwargs = {
                            'split': val_split,
                            'dataset': opt.input_json,
                            'model': model_path
                        }
                        eval_kwargs.update(vars(opt))
                        val_loss, predictions, lang_stats = eval_utils.eval_split(
                            dp_model, lw_model.crit, loader, eval_kwargs)

                        if opt.reduce_on_plateau:
                            if 'CIDEr' in lang_stats:
                                optimizer.scheduler_step(-lang_stats['CIDEr'])
                            else:
                                optimizer.scheduler_step(val_loss)

                        # Write validation result into summary
                        if opt.write_summary:
                            add_summary_value(tb_summary_writer,
                                              'loss/validation loss', val_loss,
                                              iteration)

                            if lang_stats is not None:
                                # All Bleu-n scores go into one grouped plot.
                                bleu_dict = {
                                    k: v
                                    for k, v in lang_stats.items()
                                    if 'Bleu' in k
                                }
                                if len(bleu_dict) > 0:
                                    tb_summary_writer.add_scalars(
                                        'val/Bleu', bleu_dict, epoch)

                                for k, v in lang_stats.items():
                                    if 'Bleu' not in k:
                                        add_summary_value(
                                            tb_summary_writer, 'val/' + k, v,
                                            iteration)
                        val_result_history[iteration] = {
                            'loss': val_loss,
                            'lang_stats': lang_stats,
                            'predictions': predictions
                        }

                        # Save model if is improving on validation result
                        if opt.language_eval == 1:
                            current_score = lang_stats['CIDEr']
                        else:
                            current_score = -val_loss

                        best_flag = False

                        if best_val_score is None or current_score > best_val_score:
                            best_val_score = current_score
                            infos['best_epoch'] = epoch
                            infos['best_itr'] = iteration
                            best_flag = True

                        # Dump miscalleous informations
                        infos['best_val_score'] = best_val_score
                        histories['val_result_history'] = val_result_history
                        histories['loss_history'] = loss_history
                        histories['lr_history'] = lr_history
                        histories['ss_prob_history'] = ss_prob_history

                        save_checkpoint(model, infos, optimizer, histories)
                        if opt.save_history_ckpt:
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append=str(iteration))

                        if best_flag:
                            best_epoch = epoch
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append='best')
                            print(
                                "update best model at {} iteration--{} epoch".
                                format(iteration, epoch))
            # Later passes start from the beginning of the train split. Reset
            # after the loop so an empty resumed range cannot loop forever.
            start_Img_idx = 0
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                print("epoch {} break all".format(epoch))
                save_checkpoint(model, infos, optimizer)
                if tb_summary_writer:
                    tb_summary_writer.close()
                print("============{} Training Done !==============".format(
                    'Refine' if opt.use_test or opt.use_val else ''))
                break
    except (RuntimeError, KeyboardInterrupt):  # KeyboardInterrupt
        # Best-effort: persist progress and report instead of re-raising.
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer, append='interrupt')
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
def train(opt):
    """Train a captioning model, evaluating and checkpointing after each epoch.

    Resumes from ``opt.start_from`` when set (``infos_<id>.pkl``,
    ``histories_<id>.pkl`` and ``optimizer.pth``), then loops over training
    batches:

    * at the start of each epoch, recomputes the step-decayed learning rate,
      the scheduled-sampling probability and whether self-critical (RL)
      training is active;
    * takes one optimizer step per batch with either the cross-entropy /
      label-smoothing criterion or the self-critical reward criterion;
    * every ``opt.losses_log_every`` iterations writes loss / lr / ss_prob
      summaries and records them in the histories;
    * whenever the loader wraps around the train split, increments ``epoch``,
      evaluates on ``'val'``, appends a line to ``train_log_<id>.txt`` (when
      ``opt.language_eval == 1``) and saves ``model.pth``, ``optimizer.pth``,
      ``infos_<id>.pkl`` and ``histories_<id>.pkl``. It also saves
      ``model-best.pth`` / ``infos_<id>-best.pkl`` when the score improves,
      and per-iteration copies when ``opt.save_history_ckpt`` is set.

    Stops once ``epoch >= opt.max_epochs`` (unless ``max_epochs`` is -1).

    Args:
        opt: argparse-style namespace of options. It is mutated in place
            (e.g. ``vocab_size``, ``current_lr``, ``ss_prob``).
    """
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Default so the best-score comparison never raises NameError when the
    # best score is not restored from the checkpoint.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    epoch_done = True
    # Assure in training mode
    dp_model.train()

    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from, 'optimizer.pth')
        if os.path.isfile(optimizer_path):
            optimizer.load_state_dict(torch.load(optimizer_path))

    # Running sum / count of train losses, used for the mean loss written to
    # the text log at the end of each epoch.
    total_loss = 0
    times = 0
    while True:
        if epoch_done:
            if not opt.noamopt and not opt.reduce_on_plateau:
                # Step decay: multiply by decay_rate once per decay_every
                # epochs after decay_start.
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            epoch_done = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        # Convert numpy arrays to CUDA tensors; entries that are None stay None.
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        times += 1

        optimizer.zero_grad()
        if not sc_flag:
            # Targets and masks skip column 0 of labels/masks.
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        total_loss = total_loss + train_loss
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                  .format(iteration, epoch, np.mean(reward[:, 0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            # epoch itself is incremented below, right before evaluation
            epoch_done = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            if opt.noamopt:
                opt.current_lr = optimizer.rate()
            elif opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model once per pass
        # over the train split.
        if data['bounds']['wrapped']:
            epoch += 1
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': False
            }
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            if opt.reduce_on_plateau:
                if 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                # Append one line per epoch with the mean train loss and the
                # full language-metric dict.
                with open('train_log_%s.txt' % opt.id, 'a') as log_file:
                    log_file.write(
                        'Epoch {}: | Date: {} | TrainLoss: {} | ValLoss: {} | Score: {}'
                        .format(epoch, str(datetime.now()),
                                str(total_loss / times), str(val_loss),
                                str(lang_stats)))
                    log_file.write('\n')
                print('-------------------wrote to log file')
                total_loss = 0
                times = 0
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            if not os.path.isdir(opt.checkpoint_path):
                os.makedirs(opt.checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            if opt.save_history_ckpt:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-%d.pth' % (iteration))
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscalleous informations
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                utils.pickle_dump(infos, f)
            if opt.save_history_ckpt:
                # Use the same serializer as every other infos dump so the
                # file can be read back with utils.pickle_load.
                with open(
                        os.path.join(
                            opt.checkpoint_path,
                            'infos_' + opt.id + '-%d.pkl' % (iteration)),
                        'wb') as f:
                    utils.pickle_dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                utils.pickle_dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '-best.pkl'),
                        'wb') as f:
                    utils.pickle_dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Beispiel #13
0
def train(opt):
    """Set up data, model and optimizer, then enter the epoch-level training loop.

    Creates the result folder, a logger and a tensorboardX writer. When
    ``opt.start_from`` is set, it restores saved options from ``info.json`` and
    model/optimizer state from the checkpoint selected by
    ``opt.start_from_mode``. When ``opt.pretrain`` is set instead, it loads
    pre-trained weights from ``opt.pretrain_path``.

    Args:
        opt: argparse-style option namespace. It is mutated in place
            (``current_lr``, ``vocab_size``, ``feature_dim`` and, when resuming,
            every saved option not listed in ``exclude_opt``).

    Raises:
        ValueError: when resuming with an unknown ``opt.start_from_mode``.
    """
    set_seed(opt.seed)
    save_folder = build_floder(opt)  # e.g. './save/debug_2020-10-26_08-53-55'; creates the result folder
    logger = create_logger(save_folder, 'train.log')  # create the logger
    tf_writer = SummaryWriter(os.path.join(save_folder, 'tf_summary'))  # tensorboardX

    if not opt.start_from:
        backup_envir(save_folder)  # back up the environment into save_folder
        logger.info('backup evironment completed !')

    saved_info = {'best': {}, 'last': {}, 'history': {}, 'eval_history': {}}

    # Checkpoint file restored for each supported start_from_mode.
    resume_ckpt_names = {
        'best': 'model-best-CE.pth',
        'best-RL': 'model-best-RL.pth',
        'last': 'model-last.pth',
    }

    # continue training
    if opt.start_from:
        # Validate the mode before any expensive work. Previously an unknown
        # mode left `model_pth` unbound further down.
        if opt.start_from_mode not in resume_ckpt_names:
            raise ValueError('Unknown start_from_mode: {!r} (expected one of {})'.format(
                opt.start_from_mode, sorted(resume_ckpt_names)))
        opt.pretrain = False
        infos_path = os.path.join(save_folder, 'info.json')
        with open(infos_path) as f:
            logger.info('Load info from {}'.format(infos_path))
            saved_info = json.load(f)
            prev_opt = saved_info[opt.start_from_mode[:4]]['opt']

            # Saved options override the command line, except for these.
            exclude_opt = ['start_from', 'start_from_mode', 'pretrain']
            for opt_name in prev_opt.keys():
                if opt_name not in exclude_opt:
                    vars(opt).update({opt_name: prev_opt.get(opt_name)})
                # After the update only the excluded options can still differ.
                if prev_opt.get(opt_name) != vars(opt).get(opt_name):
                    logger.info('Change opt {} : {} --> {}'.format(opt_name, prev_opt.get(opt_name),
                                                                   vars(opt).get(opt_name)))
        opt.feature_dim = opt.raw_feature_dim

    train_dataset = PropSeqDataset(opt.train_caption_file,
                                   opt.visual_feature_folder,
                                   opt.dict_file, True, opt.train_proposal_type,
                                   logger, opt)

    val_dataset = PropSeqDataset(opt.val_caption_file,
                                 opt.visual_feature_folder,
                                 opt.dict_file, False, 'gt',
                                 logger, opt)

    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size,
                              shuffle=True, num_workers=opt.nthreads, collate_fn=collate_fn)

    val_loader = DataLoader(val_dataset, batch_size=opt.batch_size,
                            shuffle=False, num_workers=opt.nthreads, collate_fn=collate_fn)

    epoch = saved_info[opt.start_from_mode[:4]].get('epoch', 0)
    iteration = saved_info[opt.start_from_mode[:4]].get('iter', 0)
    best_val_score = saved_info[opt.start_from_mode[:4]].get('best_val_score', -1e5)
    val_result_history = saved_info['history'].get('val_result_history', {})
    loss_history = saved_info['history'].get('loss_history', {})
    lr_history = saved_info['history'].get('lr_history', {})
    opt.current_lr = vars(opt).get('current_lr', opt.lr)
    opt.vocab_size = train_loader.dataset.vocab_size

    # Build model
    model = EncoderDecoder(opt)  # the core model
    model.train()

    # Load checkpoints onto the CPU when no GPU is present.
    # None keeps torch.load's default device mapping.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')

    # Recover the parameters
    if opt.start_from and (not opt.pretrain):
        model_pth = torch.load(os.path.join(save_folder, resume_ckpt_names[opt.start_from_mode]),
                               map_location=map_location)
        logger.info('Loading pth from {}, iteration:{}'.format(save_folder, iteration))
        model.load_state_dict(model_pth['model'])

    # Load the pre-trained model
    if opt.pretrain and (not opt.start_from):
        logger.info('Load pre-trained parameters from {}'.format(opt.pretrain_path))
        model_pth = torch.load(opt.pretrain_path, map_location=map_location)
        model.load_state_dict(model_pth['model'])

    if torch.cuda.is_available():
        model.cuda()

    if opt.optimizer_type == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)

    if opt.start_from:
        optimizer.load_state_dict(model_pth['optimizer'])

    # print the args for debugging
    print_opt(opt, model, logger)
    print_alert_message('Strat training !', logger)

    loss_sum = np.zeros(3)  # (3,)  3 for loss, sample_score, greedy_score
    bad_video_num = 0
    start = time.time()

    # Epoch-level iteration
    # NOTE(review): as shown, this loop only recomputes the learning rate. It
    # never advances `epoch` and never breaks, so the rest of the training step
    # appears to be missing from this example.
    while True:
        # Step decay: lr * rate ** ((epoch - start) // every) once epoch > start >= 0.
        if epoch > opt.learning_rate_decay_start >= 0:
            frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
            decay_factor = opt.learning_rate_decay_rate ** frac
            opt.current_lr = opt.lr * decay_factor
        else:
            opt.current_lr = opt.lr
        utils.set_lr(optimizer, opt.current_lr)
def train(opt):
    """Jointly train a forward captioning model and a reversed "twin" model.

    Both models come from ``models.setup`` (the second with ``reverse=True``)
    and share a single Adam optimizer. The optimized loss is:

        forward caption loss + backward caption loss + 1.5 * L2(forward states,
        reversed and detached backward states)

    Every ``opt.save_checkpoint_every`` iterations, the forward model is
    evaluated on the val split. The model, optimizer, infos and histories are
    then written to ``opt.checkpoint_path``.

    Args:
        opt: option namespace. It is mutated in place (use_att, vocab_size,
            seq_length, current_lr, ss_prob).
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    # Plain-text training log (one line every opt.losses_log_every iterations).
    folder_id = 'log_result'
    file_id = 'twin_show_attend_tell'
    log_file_name = os.path.join(folder_id, file_id + '.txt')
    if not os.path.isdir(folder_id):
        os.makedirs(folder_id)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check that the models are compatible.
        # Pickles are binary data, so the files are opened with 'rb'.
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Default to "no best score yet" so the first validation counts as the
    # best. Previously best_val_score was unbound unless load_best_score == 1.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()
    back_model = models.setup(opt, reverse=True)  # True for twin-net
    back_model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()
    back_model.train()

    crit = utils.LanguageModelCriterion()  # define the loss criterion
    # A single optimizer updates the forward and backward models jointly.
    all_param = chain(model.parameters(), back_model.parameters())
    optimizer = optim.Adam(all_param,
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # The context manager guarantees the log file is closed when training ends
    # or raises.
    with open(log_file_name, 'w') as log_file:
        while True:
            if update_lr_flag:
                # Step decay of the learning rate, recomputed once per epoch.
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                # Always push the value into the optimizer so the logged
                # current_lr is the rate actually used. A resumed optimizer
                # state may carry a different lr.
                utils.set_lr(optimizer, opt.current_lr)
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                      opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob
                update_lr_flag = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            # Flip labels and masks along dim 1 (presumably the time axis) to
            # produce the backward model's targets.
            reverse_labels = np.flip(data['labels'], 1).copy()
            reverse_masks = np.flip(data['masks'], 1).copy()

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                reverse_labels, data['masks'], reverse_masks
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            fc_feats, att_feats, labels, reverse_labels, masks, reverse_masks = tmp

            optimizer.zero_grad()
            out, states = model(fc_feats, att_feats, labels)
            back_out, back_states = back_model(fc_feats, att_feats, reverse_labels)

            # Reverse the backward states along dim 1 so they align step by
            # step with the forward states.
            idx = torch.LongTensor(list(range(back_states.size()[1] - 1, -1, -1)))
            idx = Variable(idx).cuda()
            invert_backstates = back_states.index_select(1, idx)

            loss = crit(out, labels[:, 1:],
                        masks[:, 1:])  # compute using the defined criterion

            back_loss = crit(back_out, reverse_labels[:, :-1],
                             reverse_masks[:, :-1])

            # Detached: the L2 term pulls the forward states towards the
            # backward states without sending gradients into back_model.
            invert_backstates = invert_backstates.detach()
            l2_loss = ((states - invert_backstates)**2).mean()

            all_loss = loss + 1.5 * l2_loss + back_loss

            all_loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()

            # Store the scalar values. .item() replaces .data[0], which raises
            # on 0-dim tensors in newer PyTorch releases.
            train_l2_loss = l2_loss.item()
            train_loss = loss.item()
            train_all_loss = all_loss.item()
            train_back_loss = back_loss.item()
            torch.cuda.synchronize()
            end = time.time()
            print("iter {} (epoch {}), train_loss = {:.3f}, l2_loss = {:.3f}, back_loss = {:.3f}, all_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, train_l2_loss, train_back_loss, train_all_loss, end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                update_lr_flag = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                if tf is not None:
                    add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                      iteration)
                    add_summary_value(tf_summary_writer, 'l2_loss', train_l2_loss,
                                      iteration)
                    add_summary_value(tf_summary_writer, 'all_loss',
                                      train_all_loss, iteration)
                    add_summary_value(tf_summary_writer, 'back_loss',
                                      train_back_loss, iteration)
                    add_summary_value(tf_summary_writer, 'learning_rate',
                                      opt.current_lr, iteration)
                    add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                      model.ss_prob, iteration)
                    tf_summary_writer.flush()

                # NOTE(review): time.clock() was removed in Python 3.8. It is
                # kept here because this example otherwise targets cPickle.
                log_line = 'Epoch [%d], Step [%d], all loss: %f,back_loss %f,train_l2_loss %f, train_loss %f, time %f ' % (
                    epoch, iteration, train_all_loss, train_back_loss,
                    train_l2_loss, train_loss, time.clock())
                log_file.write(log_line + '\n')

                loss_history[iteration] = train_loss
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    model, crit, loader, eval_kwargs)

                # Write validation result into summary
                if tf is not None:
                    add_summary_value(tf_summary_writer, 'validation loss',
                                      val_loss, iteration)
                    for k, v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                    tf_summary_writer.flush()
                val_result_history[iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
Beispiel #15
0
def train(opt):
    """Train a captioning model whose LossWrapper may also return a GPN loss.

    When ``opt.start_from`` is given, resumes from the infos/histories pickles
    and ``optimizer.pth``. The schedule is a linear learning-rate warm-up over
    ``opt.warmup_n`` iterations, followed by per-epoch step decay and
    scheduled-sampling updates. The model is periodically evaluated on the val
    split, and ``save_checkpoint`` is called afterwards.

    Args:
        opt: option namespace. It is mutated in place (vocab_size, seq_length,
            current_lr, ss_prob).

    Raises:
        ValueError: if the wrapped model returns neither ``lang_loss`` nor
            ``gpn_loss``.
    """
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:  # resume training
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()

    infos['opt'] = opt
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Default to "no best score yet". Previously best_val_score was unbound
    # (NameError at the first validation) unless load_best_score == 1.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # setup model and optimizer
    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    dp_lw_model.train()
    optimizer = utils.build_optimizer(filter(lambda p: p.requires_grad, model.parameters()), opt)
    if vars(opt).get('start_from', None) is not None:  # Load the optimizer
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # True when the schedules must be recomputed: on the first iteration, when
    # warm-up finishes, and after every epoch.
    done_flag = True
    try:
        while True:
            warmup_n = opt.warmup_n
            # Linear warm-up from 0 to opt.learning_rate. Skipped when
            # warmup_n <= 0, which previously divided by zero at iteration 0.
            if warmup_n > 0 and iteration <= warmup_n:
                opt.current_lr = iteration * opt.learning_rate / warmup_n
                utils.set_lr(optimizer, opt.current_lr)
                if iteration == warmup_n:
                    done_flag = True

            if done_flag and iteration >= warmup_n:
                # Step decay of the learning rate after warm-up.
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate ** frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # done_flag is left set so the scheduled-sampling update below
                # also runs. Previously it was skipped when an epoch ended
                # exactly at the end of warm-up.

            if done_flag:
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob
                done_flag = False

            start = time.time()
            data = loader.get_batch('train')
            if iteration % 5 == 0:
                print('Read data:', time.time() - start)
                print('learning rate: {}'.format(opt.current_lr))
            torch.cuda.synchronize()

            start = time.time()
            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'], data['trip_pred'],
                   data['obj_dist'], data['obj_box'], data['rel_ind'], data['pred_fmap'], data['pred_dist'],
                   data['gpn_obj_ind'], data['gpn_pred_ind'], data['gpn_nrel_ind'], data['gpn_pool_mtx']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks, trip_pred, obj_dist, obj_box, rel_ind, pred_fmap, pred_dist, \
                gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind, gpn_pool_mtx = tmp

            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), trip_pred,
                                    obj_dist, obj_box, rel_ind, pred_fmap, pred_dist, gpn_obj_ind, gpn_pred_ind, gpn_nrel_ind, gpn_pool_mtx)

            # Either loss term may be absent (None). Combine whichever exist.
            gpn_loss = model_out['gpn_loss'].mean() if model_out['gpn_loss'] is not None else None
            lang_loss = model_out['lang_loss'].mean() if model_out['lang_loss'] is not None else None
            if lang_loss is not None and gpn_loss is not None:
                loss = lang_loss + gpn_loss
            elif lang_loss is not None:
                loss = lang_loss  # no gpn module
            elif gpn_loss is not None:
                loss = gpn_loss  # no language loss returned
            else:
                raise ValueError('model returned neither lang_loss nor gpn_loss')

            loss.backward()
            utils.clip_gradient_norm(optimizer, 10.)
            optimizer.step()

            gpn_l = gpn_loss.item() if gpn_loss is not None else 0
            lang_l = lang_loss.item() if lang_loss is not None else 0
            train_loss = loss.item()
            torch.cuda.synchronize()

            end = time.time()
            if iteration % 5 == 0:
                print("iter {} (ep {}), gpn_loss = {:.3f}, lang_loss = {:.3f}, loss = {:.3f}, time/b = {:.3f}" \
                    .format(iteration, epoch, gpn_l, lang_l, train_loss, end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                done_flag = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tb_summary_writer, 'gpn_loss', gpn_l, iteration)
                add_summary_value(tb_summary_writer, 'lang_loss', lang_l, iteration)
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)

                loss_history[iteration] = train_loss
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0) or (epoch >= opt.max_epochs and opt.max_epochs != -1):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))

                val_loss = eval_utils.eval_split(dp_model, lw_model.crit, loader, eval_kwargs, opt=opt, val_model=model)

                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'language validation loss', val_loss, iteration)
                val_result_history[iteration] = {'loss': val_loss}

                # Save model if is improving on validation result
                current_score = - val_loss  # still using the language validation loss

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                # Stop if reaching max epochs
                if epoch >= opt.max_epochs and opt.max_epochs != -1:
                    break
    except (RuntimeError, KeyboardInterrupt):
        # Best-effort: report the failure and return instead of propagating.
        # NOTE(review): no checkpoint is saved here, so progress since the
        # last save_checkpoint call is lost.
        stack_trace = traceback.format_exc()
        print(stack_trace)
Beispiel #16
0
def train(opt):
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.vocab = loader.get_vocab()

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        # with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
        with open(os.path.join(opt.start_from, 'infos_'+opt.start_from.split('/')[-1]+'.pkl'), 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
        #     with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb') as f:
        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.start_from.split('/')[-1]+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.start_from.split('/')[-1]+'.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))


    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path, 'model%s.pth' %(append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer%s.pth' %(append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'%s.pkl' %(append)), 'wb') as f:
                utils.pickle_dump(histories, f)
    # pdb.set_trace()
    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate  ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False
                    init_scorer(opt.cached_tokens)

                epoch_done = False
                    
            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            # pdb.set_trace()
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks'], data['sents_mask']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks, sents_mask = tmp
            box_inds = None
                
            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, box_inds, epoch, sents_mask)

            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward', model_out['reward'].mean(), iteration)

                loss_history[iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            
            # make evaluation on validation set, and save model
            # eval model
            # eval_kwargs = {'split': 'val',
            #                 'dataset': opt.input_json}
            # eval_kwargs.update(vars(opt))
            # val_loss, predictions, lang_stats = eval_utils.eval_split(
            #     dp_model, lw_model.crit, loader, eval_kwargs)

            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                                'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscalleous informations
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model, infos, optimizer, append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Beispiel #17
0
def train(opt):
    """Train a captioning model with a cross-entropy language-model loss.

    Runs until ``opt.max_epochs`` is reached (indefinitely when it is -1).
    Every ``opt.losses_log_every`` iterations the training loss, learning rate
    and scheduled-sampling probability are recorded; every
    ``opt.save_checkpoint_every`` iterations the model is evaluated on the
    'val' split and ``model.pth``, ``optimizer.pth``, ``infos_<id>.pkl`` and
    ``histories_<id>.pkl`` are written to ``opt.checkpoint_path`` (plus
    ``model-best.pth`` / ``infos_<id>-best.pkl`` when the score improves).

    Args:
        opt: argparse-style options namespace. It is mutated in place
            (``use_att``, ``vocab_size``, ``seq_length``, ``current_lr``,
            ``ss_prob``) and pickled into the infos file.
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # The `tf is not None` checks below imply tf may be None; the writer is then None too.
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check that the models are compatible.
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Must be bound even when the saved score is not restored: it is compared
    # against at every checkpoint.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()
    if opt.multi_gpu:
        model = nn.DataParallel(model)
    # nn.DataParallel does not forward plain attribute reads/writes to the
    # wrapped module, so scheduled-sampling state goes through the inner model.
    core_model = model.module if opt.multi_gpu else model
    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate (step decay after learning_rate_decay_start).
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            # Always apply: a restored optimizer state carries its own lr, which
            # would otherwise disagree with the logged opt.current_lr.
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                core_model.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:], masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # .item() works on 0-dim loss tensors; loss.data[0] raises IndexError on recent PyTorch.
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
              .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', core_model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = core_model.ss_prob

        # make evaluation on validation set, and save model
        if iteration % opt.save_checkpoint_every == 0:
            # eval model
            eval_kwargs = {'split': 'val',
                           'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Beispiel #18
0
def train(opt):
    """Train CycleGAN-style generators/discriminators between two feature domains.

    Domain A is ``data['isg_feats']`` and domain B is ``data['ssg_feats']``;
    each is split along axis 1 into obj / rel / atr components (indices
    0, 1, 2) and every component gets its own G_A, G_B, D_A and D_B network.
    The captioning model built by ``models.setup`` is kept in eval mode and is
    only handed to the evaluation helpers.

    Side effects: creates ``opt.checkpoint_path`` (when not resuming), writes
    ``opts.yaml``, tensorboard summaries, ``model_D.pth`` / ``model_G.pth``,
    ``infos.pkl`` / ``histories.pkl`` and their ``-best`` variants.

    Args:
        opt: argparse-style options namespace; mutated in place (``id``,
            ``checkpoint_path``, ``use_att``, ``vocab_size``, ``seq_length``,
            ``caption_model``, ``current_lr``, ...).
    """
    if vars(opt).get('start_from', None) is not None:
        opt.checkpoint_path = opt.start_from
        opt.id = opt.checkpoint_path.split('/')[-1]
        print('Point to folder: {}'.format(opt.checkpoint_path))
    else:
        opt.id = datetime.datetime.now().strftime(
            '%Y%m%d_%H%M%S') + '_' + opt.caption_model
        opt.checkpoint_path = os.path.join(opt.checkpoint_path, opt.id)

        if not os.path.exists(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        print('Create folder: {}'.format(opt.checkpoint_path))

    # Write YAML file
    with io.open(opt.checkpoint_path + '/opts.yaml', 'w',
                 encoding='utf8') as outfile:
        yaml.dump(opt, outfile, default_flow_style=False, allow_unicode=True)

    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader_GAN(opt)
    loader_i2t = DataLoader_UP(opt)

    opt.vocab_size = loader.vocab_size
    use_rela = getattr(opt, 'use_rela', 0)
    if use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length

    try:
        tb_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)
    except Exception:
        # Fail loudly instead of stopping in an interactive debugger, which
        # would hang unattended jobs.
        print('Set tensorboard error!')
        raise

    infos = {}
    histories = {}
    if opt.start_from is not None:
        infos_path = os.path.join(opt.checkpoint_path, 'infos.pkl')
        histories_path = os.path.join(opt.checkpoint_path, 'histories.pkl')
        # Resuming the bookkeeping is best-effort. The pickles are written in
        # binary mode below, so they must be read back in binary mode.
        try:
            with open(infos_path, 'rb') as f:
                infos = cPickle.load(f)
            if os.path.isfile(histories_path):
                with open(histories_path, 'rb') as f:
                    histories = cPickle.load(f)
        except Exception as e:
            print("Can not load infos.pkl: {}".format(e))
        # The compatibility check is outside the try so a mismatch is not swallowed.
        if 'opt' in infos:
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Must be bound even when the saved score is not restored: it is compared
    # against at every checkpoint.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # opt.caption_model = 'up_gtssg_sep_self_att_sep'
    opt.caption_model = opt.caption_model_to_replace
    model = models.setup(opt).cuda()
    print('### Model summary below###\n {}\n'.format(str(model)))
    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('model parameter:{}'.format(model_params))

    # The captioning model is not optimized here; it is only used for evaluation.
    model.eval()

    # NOTE(review): train_loss is never updated below; it is printed as a placeholder.
    train_loss = 0
    fake_A_pool_obj = utils.ImagePool(opt.pool_size)
    fake_A_pool_rel = utils.ImagePool(opt.pool_size)
    fake_A_pool_atr = utils.ImagePool(opt.pool_size)

    fake_B_pool_obj = utils.ImagePool(opt.pool_size)
    fake_B_pool_rel = utils.ImagePool(opt.pool_size)
    fake_B_pool_atr = utils.ImagePool(opt.pool_size)

    netD_A_obj = GAN_init_D(opt, Discriminator(opt),
                            type='netD_A_obj').cuda().train()
    netD_A_rel = GAN_init_D(opt, Discriminator(opt),
                            type='netD_A_rel').cuda().train()
    netD_A_atr = GAN_init_D(opt, Discriminator(opt),
                            type='netD_A_atr').cuda().train()

    netD_B_obj = GAN_init_D(opt, Discriminator(opt),
                            type='netD_B_obj').cuda().train()
    netD_B_rel = GAN_init_D(opt, Discriminator(opt),
                            type='netD_B_rel').cuda().train()
    netD_B_atr = GAN_init_D(opt, Discriminator(opt),
                            type='netD_B_atr').cuda().train()

    netG_A_obj = GAN_init_G(opt, Generator(opt),
                            type='netG_A_obj').cuda().train()
    netG_A_rel = GAN_init_G(opt, Generator(opt),
                            type='netG_A_rel').cuda().train()
    netG_A_atr = GAN_init_G(opt, Generator(opt),
                            type='netG_A_atr').cuda().train()

    netG_B_obj = GAN_init_G(opt, Generator(opt),
                            type='netG_B_obj').cuda().train()
    netG_B_rel = GAN_init_G(opt, Generator(opt),
                            type='netG_B_rel').cuda().train()
    netG_B_atr = GAN_init_G(opt, Generator(opt),
                            type='netG_B_atr').cuda().train()

    # Checkpoint key -> network, shared by the regular and "-best" saves.
    nets_D = {
        'netD_A_atr': netD_A_atr,
        'netD_A_obj': netD_A_obj,
        'netD_A_rel': netD_A_rel,
        'netD_B_atr': netD_B_atr,
        'netD_B_obj': netD_B_obj,
        'netD_B_rel': netD_B_rel,
    }
    nets_G = {
        'netG_A_atr': netG_A_atr,
        'netG_A_obj': netG_A_obj,
        'netG_A_rel': netG_A_rel,
        'netG_B_atr': netG_B_atr,
        'netG_B_obj': netG_B_obj,
        'netG_B_rel': netG_B_rel,
    }

    def _save_nets(filename, nets, cur_epoch):
        """Save {'epoch': cur_epoch, key: state_dict, ...} under opt.checkpoint_path; return the path."""
        state = {'epoch': cur_epoch}
        for key, net in nets.items():
            state[key] = net.state_dict()
        path = os.path.join(opt.checkpoint_path, filename)
        torch.save(state, path)
        return path

    optimizer_G = utils.build_optimizer(
        itertools.chain(netG_A_obj.parameters(), netG_B_obj.parameters(),
                        netG_A_rel.parameters(), netG_B_rel.parameters(),
                        netG_A_atr.parameters(), netG_B_atr.parameters()), opt)
    optimizer_D = utils.build_optimizer(
        itertools.chain(netD_A_obj.parameters(), netD_B_obj.parameters(),
                        netD_A_rel.parameters(), netD_B_rel.parameters(),
                        netD_A_atr.parameters(), netD_B_atr.parameters()), opt)

    criterionGAN = GANLoss(opt.gan_mode).cuda()  # define GAN loss.
    criterionCycle = torch.nn.L1Loss()
    criterionIdt = torch.nn.L1Loss()

    optimizers = [optimizer_G, optimizer_D]
    # The schedulers are built but never stepped (see the commented-out
    # update_learning_rate call below). The construction is kept because
    # scheduler constructors may touch the optimizers' param_groups.
    schedulers = [get_scheduler(opt, optim_) for optim_ in optimizers]
    train_num = 0
    update_lr_flag = True
    min_lr = 1e-4  # floor for the decayed learning rate

    while True:
        if update_lr_flag:
            # Assign the learning rate: step decay, floored at min_lr.
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                target_lr = opt.learning_rate * decay_factor
            else:
                target_lr = opt.learning_rate
            opt.current_lr = max(target_lr, min_lr)
            # Apply to both the generator and discriminator optimizers. Both
            # are built from the same opt.
            for optim_ in optimizers:
                utils.set_lr(optim_, opt.current_lr)
            update_lr_flag = False

        # Show the percentage of data loader
        if train_num > loader.max_index:
            train_num = 0
        train_num = train_num + 1
        train_precentage = float(train_num) * 100 / float(loader.max_index)

        # Start training
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['isg_feats'][:, 0, :], data['isg_feats'][:, 1, :],
            data['isg_feats'][:, 2, :], data['ssg_feats'][:, 0, :],
            data['ssg_feats'][:, 1, :], data['ssg_feats'][:, 2, :]
        ]

        tmp = [
            _ if _ is None else torch.from_numpy(_).float().cuda() for _ in tmp
        ]

        real_A_obj, real_A_rel, real_A_atr, real_B_obj, real_B_rel, real_B_atr = tmp

        iteration += 1

        # Rel
        fake_B_rel = netG_A_rel(real_A_rel)
        rec_A_rel = netG_B_rel(fake_B_rel)
        idt_B_rel = netG_B_rel(real_A_rel)

        fake_A_rel = netG_B_rel(real_B_rel)
        rec_B_rel = netG_A_rel(fake_A_rel)
        idt_A_rel = netG_A_rel(real_B_rel)

        # Obj
        fake_B_obj = netG_A_obj(real_A_obj)
        rec_A_obj = netG_B_obj(fake_B_obj)
        idt_B_obj = netG_B_obj(real_A_obj)

        fake_A_obj = netG_B_obj(real_B_obj)
        rec_B_obj = netG_A_obj(fake_A_obj)
        idt_A_obj = netG_A_obj(real_B_obj)

        # Atr
        fake_B_atr = netG_A_atr(real_A_atr)
        rec_A_atr = netG_B_atr(fake_B_atr)
        idt_B_atr = netG_B_atr(real_A_atr)

        fake_A_atr = netG_B_atr(real_B_atr)
        rec_B_atr = netG_A_atr(fake_A_atr)
        idt_A_atr = netG_A_atr(real_B_atr)

        domain_A = [
            real_A_obj, real_A_rel, real_A_atr, fake_A_obj, fake_A_rel,
            fake_A_atr, rec_A_obj, rec_A_rel, rec_A_atr, idt_A_obj, idt_A_rel,
            idt_A_atr
        ]
        domain_B = [
            real_B_obj, real_B_rel, real_B_atr, fake_B_obj, fake_B_rel,
            fake_B_atr, rec_B_obj, rec_B_rel, rec_B_atr, idt_B_obj, idt_B_rel,
            idt_B_atr
        ]
        # G_A and G_B
        utils.set_requires_grad([
            netD_A_obj, netD_A_rel, netD_A_atr, netD_B_obj, netD_B_rel,
            netD_B_atr
        ], False)  # Ds require no gradients when optimizing Gs
        optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero

        loss_G = cycle_GAN_backward_G(opt, criterionGAN, criterionCycle,
                                      criterionIdt, netG_A_obj, netG_A_rel,
                                      netG_A_atr, netG_B_obj, netG_B_rel,
                                      netG_B_atr, netD_A_obj, netD_A_rel,
                                      netD_A_atr, netD_B_obj, netD_B_rel,
                                      netD_B_atr, domain_A, domain_B)
        loss_G.backward()
        optimizer_G.step()

        # D_A and D_B
        utils.set_requires_grad([
            netD_A_obj, netD_A_rel, netD_A_atr, netD_B_obj, netD_B_rel,
            netD_B_atr
        ], True)
        optimizer_D.zero_grad()  # set D_A and D_B's gradients to zero
        loss_D_A = cycle_GAN_backward_D(opt, fake_B_pool_obj, fake_B_pool_rel,
                                        fake_B_pool_atr, netD_A_obj,
                                        netD_A_rel, netD_A_atr, criterionGAN,
                                        real_B_obj, real_B_rel, real_B_atr,
                                        fake_B_obj, fake_B_rel, fake_B_atr)
        loss_D_A.backward()
        loss_D_B = cycle_GAN_backward_D(opt, fake_A_pool_obj, fake_A_pool_rel,
                                        fake_A_pool_atr, netD_B_obj,
                                        netD_B_rel, netD_B_atr, criterionGAN,
                                        real_A_obj, real_A_rel, real_A_atr,
                                        fake_A_obj, fake_A_rel, fake_A_atr)
        loss_D_B.backward()
        optimizer_D.step()  # update D_A and D_B's weights

        # Synchronize before reading the clock so time/batch includes the queued GPU work.
        torch.cuda.synchronize()
        end = time.time()
        train_loss_G = loss_G.item()
        train_loss_D_A = loss_D_A.item()
        train_loss_D_B = loss_D_B.item()
        print(
            "{}/{:.1f}/{}/{}|train_loss={:.3f}|train_loss_G={:.3f}|train_loss_D_A={:.3f}|train_loss_D_B={:.3f}|time/batch = {:.3f}"
            .format(opt.id, train_precentage, iteration, epoch, train_loss,
                    train_loss_G, train_loss_D_A, train_loss_D_B, end - start))

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0) and (iteration != 0):
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_G', train_loss_G,
                              iteration)
            add_summary_value(tb_summary_writer, 'train_loss_D_A',
                              train_loss_D_A, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_D_B',
                              train_loss_D_B, iteration)

            # add hyper-parameters
            add_summary_value(tb_summary_writer, 'beam_size', opt.beam_size,
                              iteration)
            add_summary_value(tb_summary_writer, 'lambdaA', opt.lambda_A,
                              iteration)
            add_summary_value(tb_summary_writer, 'lambdaB', opt.lambda_B,
                              iteration)
            add_summary_value(tb_summary_writer, 'pool_size', opt.pool_size,
                              iteration)
            add_summary_value(tb_summary_writer, 'gan_type', opt.gan_type,
                              iteration)
            add_summary_value(tb_summary_writer, 'gan_d_type', opt.gan_d_type,
                              iteration)
            add_summary_value(tb_summary_writer, 'gan_g_type', opt.gan_g_type,
                              iteration)

        if (iteration % opt.save_checkpoint_every == 0) and (iteration != 0):
            val_loss = eval_utils_gan.eval_split_gan(opt, model, netG_A_obj,
                                                     netG_A_rel, netG_A_atr,
                                                     loader, loader_i2t)
            val_loss = val_loss.item()
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            current_score = -val_loss
            best_flag = False

            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            _save_nets('model_D.pth', nets_D, epoch)
            _save_nets('model_G.pth', nets_G, epoch)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos.pkl'),
                      'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories.pkl'),
                      'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                _save_nets('model_D-best.pth', nets_D, epoch)
                checkpoint_path = _save_nets('model_G-best.pth', nets_G, epoch)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos-best.pkl'),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            # current_lr = update_learning_rate(schedulers, optimizers)
            epoch += 1
            update_lr_flag = True
            # make evaluation on validation set, and save model
            # lang_stats_isg = eval_utils_gan.eval_split_i2t(opt, model, netG_A_obj, netG_A_rel, netG_A_atr, loader, loader_i2t)
            lang_stats_isg = eval_utils_gan.eval_split_g2t(
                opt, model, netG_A_obj, netG_A_rel, netG_A_atr, loader,
                loader_i2t)

            if lang_stats_isg is not None:
                for k, v in lang_stats_isg.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Beispiel #19
0
def train(opt):
    """Train an ``EncoderDecoder`` model on ``PropSeqDataset`` proposals.

    Each epoch recomputes the learning rate (step decay) and the
    scheduled-sampling probability, then runs one pass over the training
    loader. Every ``opt.save_checkpoint_every`` epochs, once
    ``opt.min_epoch_when_save`` is reached, it does the following:

    - evaluates on the validation set;
    - saves ``model-last.pth`` (or ``model_iter_<n>.pth``);
    - updates ``model-best.pth`` when the mean ``'f1'`` score improves;
    - rewrites ``info.json``.

    Training stops once ``epoch >= opt.epoch``.

    Args:
        opt: argparse-style namespace with all training options. It is
            mutated in place (``current_lr``, ``ss_prob``, ``feature_dim``,
            ``pretrain``). When resuming, it is overwritten with the options
            stored by the previous run, except ``start_from``,
            ``start_from_mode`` and ``pretrain``.

    Returns:
        dict: ``saved_info`` with ``'best'``, ``'last'``, ``'history'`` and
        ``'eval_history'`` entries (the structure dumped to ``info.json``).

    Raises:
        ValueError: when resuming with ``opt.start_from_mode`` other than
            ``'best'`` or ``'last'``.
    """
    set_seed(opt.seed)
    save_folder = build_floder(opt)
    logger = create_logger(save_folder, 'train.log')
    tf_writer = SummaryWriter(os.path.join(save_folder, 'tf_summary'))

    if not opt.start_from:
        backup_envir(save_folder)
        logger.info('backup evironment completed !')

    saved_info = {'best': {}, 'last': {}, 'history': {}, 'eval_history': {}}

    # continue training: restore the options stored by the previous run
    if opt.start_from:
        opt.pretrain = False
        infos_path = os.path.join(save_folder, 'info.json')
        with open(infos_path) as f:
            logger.info('Load info from {}'.format(infos_path))
            saved_info = json.load(f)
            prev_opt = saved_info[opt.start_from_mode[:4]]['opt']

            # these options keep the values given on the current command line
            exclude_opt = ['start_from', 'start_from_mode', 'pretrain']
            for opt_name in prev_opt.keys():
                if opt_name not in exclude_opt:
                    vars(opt).update({opt_name: prev_opt.get(opt_name)})
                if prev_opt.get(opt_name) != vars(opt).get(opt_name):
                    logger.info('Change opt {} : {} --> {}'.format(
                        opt_name, prev_opt.get(opt_name),
                        vars(opt).get(opt_name)))
        opt.feature_dim = opt.raw_feature_dim

    train_dataset = PropSeqDataset(opt.train_caption_file,
                                   opt.visual_feature_folder, True,
                                   opt.train_proposal_type, logger, opt)

    val_dataset = PropSeqDataset(opt.val_caption_file,
                                 opt.visual_feature_folder, False, 'gt',
                                 logger, opt)

    train_loader = DataLoader(train_dataset,
                              batch_size=opt.batch_size,
                              shuffle=True,
                              num_workers=opt.nthreads,
                              collate_fn=collate_fn)

    val_loader = DataLoader(val_dataset,
                            batch_size=opt.batch_size,
                            shuffle=False,
                            num_workers=opt.nthreads,
                            collate_fn=collate_fn)

    epoch = saved_info[opt.start_from_mode[:4]].get('epoch', 0)
    iteration = saved_info[opt.start_from_mode[:4]].get('iter', 0)
    best_val_score = saved_info[opt.start_from_mode[:4]].get(
        'best_val_score', -1e5)
    val_result_history = saved_info['history'].get('val_result_history', {})
    loss_history = saved_info['history'].get('loss_history', {})
    lr_history = saved_info['history'].get('lr_history', {})
    opt.current_lr = vars(opt).get('current_lr', opt.lr)

    # Build model
    model = EncoderDecoder(opt)
    model.train()

    # Map checkpoints onto the CPU when no GPU is present (used for both the
    # resume and the pre-train loads).
    map_location = None if torch.cuda.is_available() else torch.device('cpu')

    # Recover the parameters
    if opt.start_from and (not opt.pretrain):
        if opt.start_from_mode == 'best':
            # The best model is saved below as 'model-best.pth'; fall back to
            # the legacy 'model-best-CE.pth' name for older runs.
            best_path = os.path.join(save_folder, 'model-best.pth')
            if not os.path.isfile(best_path):
                best_path = os.path.join(save_folder, 'model-best-CE.pth')
            model_pth = torch.load(best_path, map_location=map_location)
        elif opt.start_from_mode == 'last':
            model_pth = torch.load(os.path.join(save_folder, 'model-last.pth'),
                                   map_location=map_location)
        else:
            raise ValueError(
                "start_from_mode must be 'best' or 'last', got {!r}".format(
                    opt.start_from_mode))
        logger.info('Loading pth from {}, iteration:{}'.format(
            save_folder, iteration))
        model.load_state_dict(model_pth['model'])

    # Load the pre-trained model
    if opt.pretrain and (not opt.start_from):
        logger.info('Load pre-trained parameters from {}'.format(
            opt.pretrain_path))
        model_pth = torch.load(opt.pretrain_path, map_location=map_location)
        model.load_state_dict(model_pth['model'])

    if torch.cuda.is_available():
        model.cuda()

    if opt.optimizer_type == 'adam':
        optimizer = optim.Adam(model.parameters(),
                               lr=opt.lr,
                               weight_decay=opt.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(),
                              lr=opt.lr,
                              weight_decay=opt.weight_decay)

    if opt.start_from:
        optimizer.load_state_dict(model_pth['optimizer'])

    # print the args for debugging
    print_opt(opt, model, logger)
    print_alert_message('Strat training !', logger)

    loss_sum = np.zeros(3)
    bad_video_num = 0
    start = time.time()
    # Log roughly five times per epoch; clamp to 1 so loaders with fewer than
    # five batches do not trigger a modulo-by-zero.
    losses_log_every = max(1, len(train_loader) // 5)

    # Epoch-level iteration
    while True:
        # lr decay: multiply by decay_rate once every decay_every epochs
        if epoch > opt.learning_rate_decay_start >= 0:
            frac = (epoch - opt.learning_rate_decay_start
                    ) // opt.learning_rate_decay_every
            decay_factor = opt.learning_rate_decay_rate**frac
            opt.current_lr = opt.lr * decay_factor
        else:
            opt.current_lr = opt.lr
        utils.set_lr(optimizer, opt.current_lr)

        # scheduled sampling rate update
        if epoch > opt.scheduled_sampling_start >= 0:
            frac = (epoch - opt.scheduled_sampling_start
                    ) // opt.scheduled_sampling_increase_every
            opt.ss_prob = min(
                opt.basic_ss_prob +
                opt.scheduled_sampling_increase_prob * frac,
                opt.scheduled_sampling_max_prob)
            model.decoder.ss_prob = opt.ss_prob

        # Batch-level iteration
        for dt in tqdm(train_loader):
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            if opt.debug:
                # each epoch contains less mini-batches for debugging
                if (iteration + 1) % 5 == 0:
                    iteration += 1
                    break
            elif epoch == 0:
                # NOTE(review): outside debug mode epoch 0 trains on no
                # batches at all; confirm this skip is intended.
                break
            iteration += 1

            # Reset gradients every iteration regardless of device; otherwise
            # they accumulate across optimizer steps.
            optimizer.zero_grad()
            if torch.cuda.is_available():
                dt = {
                    key: value.cuda() if isinstance(value, torch.Tensor) else value
                    for key, value in dt.items()
                }

            # missing keys read as None inside the model
            dt = collections.defaultdict(lambda: None, dt)

            loss = model(dt, mode='train')
            loss_sum[0] = loss_sum[0] + loss.item()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            if torch.cuda.is_available():
                torch.cuda.synchronize()

            if iteration % losses_log_every == 0:
                end = time.time()
                losses = np.round(loss_sum / losses_log_every, 3)
                logger.info(
                    "ID {} iter {} (epoch {}, lr {}), avg_iter_loss = {}, time/iter = {:.3f}, bad_vid = {:.3f}"
                    .format(opt.id, iteration, epoch, opt.current_lr, losses,
                            (end - start) / losses_log_every, bad_video_num))

                tf_writer.add_scalar('lr', opt.current_lr, iteration)
                tf_writer.add_scalar('ss_prob', model.decoder.ss_prob,
                                     iteration)
                tf_writer.add_scalar('train_caption_loss', losses[0].item(),
                                     iteration)

                loss_history[iteration] = losses.tolist()
                lr_history[iteration] = opt.current_lr
                loss_sum = 0 * loss_sum
                start = time.time()
                bad_video_num = 0
                torch.cuda.empty_cache()

        # evaluation
        if (epoch % opt.save_checkpoint_every
                == 0) and (epoch >= opt.min_epoch_when_save) and (epoch != 0):
            model.eval()

            result_json_path = os.path.join(
                save_folder, 'prediction',
                'num{}_epoch{}_score{}_nms{}_top{}.json'.format(
                    len(val_dataset), epoch, opt.eval_score_threshold,
                    opt.eval_nms_threshold, opt.eval_top_n))
            eval_score = evaluate(model,
                                  val_loader,
                                  result_json_path,
                                  opt.eval_score_threshold,
                                  opt.eval_nms_threshold,
                                  opt.eval_top_n,
                                  False,
                                  1,
                                  logger=logger)
            current_score = np.array(eval_score['f1']).mean()

            # add to tf summary
            for key in eval_score.keys():
                tf_writer.add_scalar(key,
                                     np.array(eval_score[key]).mean(),
                                     iteration)
            # append the mean to every list-valued metric (mutated in place)
            for item in eval_score.values():
                if isinstance(item, list):
                    item.append(np.array(item).mean())
            print_info = '\n'.join([
                key + ":" + str(eval_score[key]) for key in eval_score.keys()
            ])
            logger.info(
                '\nValidation results of iter {}:\n'.format(iteration) +
                print_info)
            val_result_history[epoch] = {'eval_score': eval_score}

            # Save model
            saved_pth = {
                'epoch': epoch,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
            }

            if opt.save_all_checkpoint:
                checkpoint_path = os.path.join(
                    save_folder, 'model_iter_{}.pth'.format(iteration))
            else:
                checkpoint_path = os.path.join(save_folder, 'model-last.pth')

            torch.save(saved_pth, checkpoint_path)
            logger.info('Save model at iter {} to {}.'.format(
                iteration, checkpoint_path))

            # save the model parameters and info of the best epoch
            if current_score > best_val_score:
                best_val_score = current_score
                saved_info['best'] = {
                    'opt': vars(opt),
                    'iter': iteration,
                    'epoch': epoch,
                    'best_val_score': best_val_score,
                    'result_json_path': result_json_path,
                    'avg_proposal_num': eval_score['avg_proposal_number'],
                    'Precision': eval_score['Precision'],
                    'Recall': eval_score['Recall']
                }

                torch.save(saved_pth,
                           os.path.join(save_folder, 'model-best.pth'))
                logger.info(
                    'Save Best-model at iter {} to checkpoint file.'.format(
                        iteration))

            saved_info['last'] = {
                'opt': vars(opt),
                'iter': iteration,
                'epoch': epoch,
                'best_val_score': best_val_score,
            }
            saved_info['history'] = {
                'val_result_history': val_result_history,
                'loss_history': loss_history,
                'lr_history': lr_history,
            }
            with open(os.path.join(save_folder, 'info.json'), 'w') as f:
                json.dump(saved_info, f)
            logger.info('Save info to info.json')

            model.train()

        epoch += 1
        torch.cuda.empty_cache()
        # Stop criterion
        if epoch >= opt.epoch:
            tf_writer.close()
            break

    return saved_info
def train(opt):
    """Train a captioning model with cross-entropy, then optionally self-critical RL.

    Batches come from ``loader.get_batch('train')``. Whenever the loader
    wraps around (end of an epoch), the learning rate, the scheduled-sampling
    probability and the self-critical flag are recomputed. Every
    ``opt.save_checkpoint_every`` iterations the model is evaluated on the
    ``'val'`` split. Model, optimizer, infos and histories files are then
    written to ``opt.checkpoint_path``, plus a ``model-best-...`` copy when
    the validation score improves.

    Args:
        opt: argparse-style namespace of training options. It is mutated in
            place (``use_att``, ``att_feat_size``, ``vocab_size``,
            ``seq_length``, ``current_lr``, ``ss_prob``).
    """
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # `tf` may be None; every summary write below is guarded by `tf is not None`
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check if models are compatible. Pickle files
        # must be read in binary mode (they are written with 'wb' below).
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always define best_val_score: the checkpoint logic reads it even when
    # the previous best score is not restored.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    # True forces lr / ss_prob / sc_flag to be initialised before the first batch
    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate (step decay every learning_rate_decay_every epochs)
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        data_time = time.time() - start

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            # cross-entropy against the labels shifted by one position
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:])
        else:
            # self-critical: sample captions and weight log-probs by reward
            gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if iteration % opt.print_freq == 0:
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, batch time = {:.3f}, data time = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start, data_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, batch time = {:.3f}, data time = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:, 0]), end - start, data_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):

            # save the current weights before evaluating
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)

            # eval model
            eval_kwargs = {'split': 'val',
                           'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best-i{}-score{}.pth'.format(iteration, best_val_score))
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Beispiel #21
0
def train(opt):
    """Train a captioning model with XE, self-critical or structure loss.

    The model is wrapped in ``LossWrapper`` so the loss is computed inside
    ``DataParallel``. At the start of each epoch the function recomputes the
    learning rate (unless ``noamopt`` / ``reduce_on_plateau`` manage it), the
    scheduled-sampling probability, and the self-critical / structure flags.

    Every ``opt.save_checkpoint_every`` iterations it evaluates on ``'val'``
    and saves a checkpoint via ``utils.save_checkpoint``. It also saves an
    extra ``'best'`` checkpoint whenever the validation score improves.

    Args:
        opt: argparse-style namespace of training options. It is mutated in
            place (``vocab_size``, ``seq_length``, ``current_lr``,
            ``ss_prob``; ``vocab`` is set temporarily for model setup).
    """

    ################################
    # Build dataloader
    ################################
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if there are any) and check if models are compatible
    infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl') if opt.start_from is not None else None
    if infos_path is not None and os.path.isfile(infos_path):
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert getattr(saved_model_opt, checkme) == getattr(opt, checkme), "Command line argument and saved model disagree on '%s' " % checkme
    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'), 'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    # the vocab is only needed while the model is being constructed
    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, 'model.pth')):
        model.load_state_dict(torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap generation model with loss function (used for training).
    # This allows the loss function to be computed separately on each machine.
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model, factor=opt.noamopt_factor, warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For back compatibility with infos that stored iterators/split_ix
    if 'iterators' in infos:
        infos['loader_state_dict'] = {split: {'index_list': infos['split_ix'][split], 'iter_counter': infos['iterators'][split]} for split in ['train', 'val', 'test']}
    loader.load_state_dict(infos['loader_state_dict'])
    # Always define best_val_score: the checkpoint logic reads it even when
    # the previous best score is not restored.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    if opt.noamopt:
        optimizer._step = iteration
    # flag indicating finish of an epoch
    # Always set to True at the beginning to initialize the lr or etc.
    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    # Start training
    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate ** frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # If start structure loss training
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks, att_masks, data['gts'], torch.arange(0, len(data['gts'])), sc_flag, struc_flag)

            # DataParallel returns one loss per replica; average them
            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if struc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(), model_out['struc_loss'].mean().item(), end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                tb_summary_writer.add_scalar('train_loss', train_loss, iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr, iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    tb_summary_writer.add_scalar('avg_reward', model_out['reward'].mean(), iteration)
                elif struc_flag:
                    tb_summary_writer.add_scalar('lm_loss', model_out['lm_loss'].mean().item(), iteration)
                    tb_summary_writer.add_scalar('struc_loss', model_out['struc_loss'].mean().item(), iteration)
                    tb_summary_writer.add_scalar('reward', model_out['reward'].mean().item(), iteration)

                histories['loss_history'][iteration] = train_loss if not sc_flag else model_out['reward'].mean()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    # lang_stats may be None (see the summary check below);
                    # fall back to the validation loss in that case.
                    if lang_stats is not None and 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = - val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(opt, model, infos, optimizer, append=str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt, model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        # Best effort: persist the current state and report the error instead
        # of re-raising it.
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
Beispiel #22
0
    elif opt.optim == 'adamax':
        optimizer = optim.Adamax(params)

    # if opt.cnn_optim == 'sgd':
    #     cnn_optimizer = optim.SGD(cnn_params, momentum=0.9)
    # else:
    #     cnn_optimizer = optim.Adam(cnn_params)
    # load optimizer
    # learning_rate_list = np.linspace(opt.learning_rate, 0.0005, opt.max_epochs)

    for epoch in range(start_epoch, opt.max_epochs):
        if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
            if (epoch - opt.learning_rate_decay_start
                ) % opt.learning_rate_decay_every == 0:
                # decay the learning rate.
                utils.set_lr(optimizer, opt.learning_rate_decay_rate)
                opt.learning_rate = opt.learning_rate * opt.learning_rate_decay_rate

        if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
            frac = (epoch - opt.scheduled_sampling_start
                    ) // opt.scheduled_sampling_increase_every
            opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                              opt.scheduled_sampling_max_prob)
            model.ss_prob = opt.ss_prob

        if not opt.inference_only:
            train(epoch, opt)

        if epoch % opt.val_every_epoch == 0:
            lang_stats = eval(opt)
            # Save model if is improving on validation result
def train(opt):
    """Jointly fine-tune a pretrained Transformer captioner and a Seq2Seq
    language model with a mixed cross-entropy / weighted policy-gradient loss.

    Every iteration (regardless of ``sc_flag``):
      * computes a teacher-forced cross-entropy loss (``loss_xe``) for ``model``;
      * samples captions from ``lang_model`` without gradients, re-scores them
        with ``model`` and weights their log-probs by a clamped likelihood
        ratio between the two models;
      * computes a reward with ``get_self_critical_reward`` minus a KL penalty;
      * optimises ``lamb * loss_rl + (1 - lamb) * loss_xe`` with ``lamb = 0.5``.

    ``sc_flag`` only controls ``init_scorer`` and what gets printed / logged.
    Every ``opt.save_checkpoint_every`` iterations both models and the
    optimizer are saved and ``model`` is evaluated on the 'test' split.
    Training stops once ``opt.max_epochs`` epochs have run. If that option is
    -1 or absent, training never stops.

    Args:
        opt: options namespace. It is mutated in place (``vocab_size``,
            ``seq_length``, ``current_lr``).
    """
    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Pickles must be read in binary mode (text mode breaks on Python 3).
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s' " % checkme
        histories_path = os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score so the comparison after evaluation cannot
    # raise UnboundLocalError when load_best_score != 1.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create the captioner (hard-coded sizes/paths kept from the original setup)
    encoder = MemoryAugmentedEncoder(3, 0, attention_module=ScaledDotProductAttentionMemory,
                                     attention_module_kwargs={'m': 40})
    decoder = MeshedDecoder(8668, 180, 3, 0)
    models = Transformer(8667, encoder, decoder)
    model = models.cuda()
    lang_model = Seq2Seq().cuda()
    model.load_state_dict(torch.load('log_meshed/all2model20000.pth'))
    lang_model.load_state_dict(torch.load('language_model/langmodel06000.pth'))
    # A single optimizer updates both networks.
    optimizer = utils.build_optimizer_adam(list(models.parameters()) + list(lang_model.parameters()), opt)
    update_lr_flag = True

    beta = 0.2   # mixing weight inside the truncated likelihood ratio
    lamb = 0.5   # weight of the RL loss vs. the XE loss
    max_epochs = vars(opt).get('max_epochs', -1)

    while True:
        # Update learning rate / flags once per epoch
        if update_lr_flag:
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Scheduled sampling is intentionally disabled for this model.

            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['dist'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, attmasks = tmp
        labels = labels.long()
        batchsize = fc_feats.size(0)

        # Ground-truth captions as cleaned strings, used by the reward function.
        captions = utils.decode_sequence(loader.get_vocab(), labels.view(batchsize, -1), None)
        captions_all = [
            caption.replace('<start>', '').replace(' ,', '').replace('  ', ' ')
            for caption in captions
        ]

        model.train()
        optimizer.zero_grad()

        # Teacher-forced cross-entropy: logits at step t predict token t+1.
        wordact, _ = model(att_feats, labels.view(batchsize, -1))
        wordact_t = wordact[:, :-1, :]
        wordact_t = wordact_t.contiguous().view(wordact_t.size(0) * wordact_t.size(1), -1)
        wordclass_v = labels.view(batchsize, -1)[:, 1:]
        wordclass_t = wordclass_v.contiguous().view(wordclass_v.size(0) * wordclass_v.size(1), -1)
        loss_xe = F.cross_entropy(wordact_t, wordclass_t.contiguous().view(-1))

        # Sample captions from the language model, then score them under `model`.
        with torch.no_grad():
            outcap, sampled_ids, sample_logprobs, _, _, log_probs_all = lang_model.sample(
                labels.view(batchsize, -1).transpose(1, 0), att_feats.transpose(1, 0), loader.get_vocab())
        logprobs_input, _ = model(att_feats, sampled_ids.cuda().long())
        log_probs = F.log_softmax(logprobs_input, 2)
        sample_logprobs_true = log_probs.gather(2, sampled_ids.cuda().long().unsqueeze(2))

        with torch.no_grad():
            reward, _, _, _, _ = get_self_critical_reward(
                batchsize, lang_model, labels.view(batchsize, -1).transpose(1, 0),
                att_feats.transpose(1, 0), outcap, captions_all, loader, 180)
            kl_div = F.kl_div(log_probs.squeeze().cuda().detach(),
                              torch.exp(log_probs_all.transpose(1, 0)).cuda().detach(),
                              reduction='none')
            # Truncated likelihood ratio p_model / ((1-beta) p_lang + beta p_model),
            # clamped from below and multiplied over time steps.
            ratio_no_f = torch.exp(sample_logprobs_true.squeeze().cpu().double())
            ratio_de_f = torch.exp(sample_logprobs.cpu().double())
            ratio = ratio_no_f / ((1 - beta) * ratio_de_f + beta * ratio_no_f)
            ratio = torch.clamp(ratio, min=0.96)
            ratio_prod = ratio.prod(1)
            reward = torch.as_tensor(reward).cuda() - 0.05 * kl_div.mean()

        # NOTE(review): rl_crit is not created in this function; it must exist at
        # module level (sibling trainers use utils.RewardCriterion()) — confirm.
        loss_rl1 = rl_crit(ratio_prod.cuda().unsqueeze(1).detach() * sample_logprobs_true.squeeze()[:, :-1],
                           sampled_ids[:, 1:].cpu(), reward.float().cuda().detach())
        train_loss = lamb * loss_rl1 + (1 - lamb) * loss_xe
        train_loss.backward()
        optimizer.step()
        # Plain float for logging/history so no autograd graph is kept alive.
        train_loss_value = train_loss.item()
        total_time = time.time() - start

        if iteration % opt.print_freq == 1:
            print('Read data:', time.time() - start)
            if not sc_flag:
                print(ratio.mean())
                print(reward.mean())
                print(kl_div.mean())
                print("iter {} (epoch {}), train_loss = {:.4f}, xe_loss = {:.3f}, train_time = {:.3f}"
                      .format(iteration, epoch, train_loss_value, loss_xe.item(), data_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, reward[:, 0].mean().item(), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss_value, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', reward[:, 0].mean().item(), iteration)
            loss_history[iteration] = train_loss_value if not sc_flag else reward[:, 0].mean().item()
            lr_history[iteration] = opt.current_lr

        # Save both models + optimizer, then evaluate
        if iteration % opt.save_checkpoint_every == 0:
            checkpoint_path = os.path.join(opt.checkpoint_path, 'all2model{:05d}.pth'.format(iteration))
            torch.save(model.state_dict(), checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'lang_model{:05d}.pth'.format(iteration))
            torch.save(lang_model.state_dict(), checkpoint_path)
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            crit = utils.LanguageModelCriterion()
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Our metric is CIDEr if available, otherwise validation loss
            current_score = lang_stats['CIDEr'] if lang_stats is not None else -val_loss

            # Keep the best-scoring captioner in model.pth
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)

        # Stop if reaching max epochs
        if max_epochs != -1 and epoch >= max_epochs:
            break
def train(opt):
    """Train the VRG caption decoder with XE (words + tags) or self-critical RL.

    Before ``opt.self_critical_after`` the loss is word XE plus 0.15 times tag
    XE. After that, the loss is the reward-weighted log-likelihood of sampled
    captions. Evaluation on the 'val' split plus checkpointing runs either every
    ``opt.save_checkpoint_every`` iterations, or at each epoch end when that
    option is -1. Training stops once ``opt.max_epochs`` epochs have run
    (never if -1).

    Args:
        opt: options namespace. It is mutated in place (``use_att``,
            ``use_fc``, ``vocab_size``, ``seq_length``, ``current_lr``,
            ``ss_prob``).
    """
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_fc = utils.if_use_fc(opt.caption_model)

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = load_info(opt)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score; otherwise the comparison after evaluation
    # raises UnboundLocalError when load_best_score != 1.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Define and load model, optimizer, critics
    decoder = setup(opt).train().cuda()
    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing).cuda()
    else:
        crit = utils.LanguageModelCriterion().cuda()
    rl_crit = utils.RewardCriterion().cuda()
    optimizer = utils.build_optimizer(decoder.parameters(), opt)
    if opt.reduce_on_plateau:
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    models = {'decoder': decoder}
    optimizers = {'decoder': optimizer}
    save_nets_structure(models, opt)
    load_checkpoint(models, optimizers, opt)
    print('opt', opt)

    epoch_done = True
    sc_flag = False
    while True:
        if epoch_done:
            # Step-decayed learning rate for this epoch
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            # With reduce-on-plateau the scheduler owns the LR. Resetting it
            # every epoch would undo each plateau reduction.
            if not opt.reduce_on_plateau:
                utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                decoder.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False
            epoch_done = False

        # 1. fetch a batch of data from train split
        data = loader.get_batch('train')
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['tags'],
            data['masks'], data['att_masks'], data['verbs']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, tags, masks, att_masks, weak_relas = tmp
        vrg_data = {
            key: value if value is None else torch.from_numpy(value).cuda()
            for key, value in data['vrg_data'].items()
        }

        # 2. Forward model and compute loss
        torch.cuda.synchronize()
        optimizer.zero_grad()
        if not sc_flag:
            out = decoder(vrg_data, fc_feats, att_feats, labels, weak_relas,
                          att_masks)
            # out[0]: word predictions, out[1]: tag predictions (shifted by one step)
            loss_words = crit(out[0], labels[:, 1:], masks[:, 1:])
            loss_tags = crit(out[1], tags[:, 1:], masks[:, 1:])
            loss = loss_words + loss_tags * 0.15
        else:
            gen_result, sample_logprobs, core_args = decoder(
                vrg_data,
                fc_feats,
                att_feats,
                weak_relas,
                att_masks,
                opt={
                    'sample_max': 0,
                    'return_core_args': True
                },
                mode='sample')
            reward = get_self_critical_reward(decoder, core_args, vrg_data,
                                              fc_feats, att_feats, weak_relas,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # 3. Update model
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Update the iteration and epoch
        iteration += 1
        # Write the training loss summary
        if iteration % opt.log_loss_every == 0:
            logger.info("{} ({}), loss: {:.3f}".format(iteration, epoch,
                                                       train_loss))
            tb.add_values('loss', {'train': train_loss}, iteration)

        if data['bounds']['wrapped']:
            epoch += 1
            epoch_done = True

        # Make evaluation and save checkpoint
        periodic_save = (opt.save_checkpoint_every > 0
                         and iteration % opt.save_checkpoint_every == 0)
        epoch_save = opt.save_checkpoint_every == -1 and epoch_done
        if periodic_save or epoch_save:
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'expand_features': False
            }
            eval_kwargs.update(vars(opt))
            predictions, lang_stats = eval_utils.eval_split(
                decoder, loader, eval_kwargs)

            if opt.reduce_on_plateau:
                # Guard against None so a missing metric yields the intended
                # assertion message instead of a TypeError.
                assert lang_stats is not None and 'CIDEr' in lang_stats, \
                    'Error: cider should be in eval list'
                optimizer.scheduler_step(-lang_stats['CIDEr'])

            # log val results
            if lang_stats is not None:
                logger.info("Scores: {}".format(lang_stats))
                tb.add_values('scores', lang_stats, epoch)
            val_result_history[epoch] = {
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            current_score = 0 if lang_stats is None else lang_stats['CIDEr']
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            infos['val_result_history'] = val_result_history

            save_checkpoint(models, optimizers, infos, best_flag, opt)

        # Stop once max_epochs epochs have run (">=" matches the other trainers;
        # ">" ran one extra epoch).
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
def train(opt):
    """Train a captioning model (with an auxiliary feature-reconstruction head)
    through ``torch.nn.DataParallel``, switching to self-critical RL after
    ``opt.self_critical_after``.

    In XE mode the loss is ``5 * cross_entropy + mse(reconstruct, max-pooled
    att_feats)``. In RL mode it is the reward-weighted log-likelihood of
    sampled captions. Once ``iteration`` reaches
    ``opt.save_checkpoint_start`` (default 60000), every
    ``opt.save_checkpoint_every`` iterations the function does three things:
    it saves the model and optimizer, evaluates on the 'test' split, and dumps
    the infos/histories pickles. Training stops once ``opt.max_epochs`` epochs
    have run (never if -1).

    Args:
        opt: options namespace. It is mutated in place (``vocab_size``,
            ``seq_length``, ``current_lr``).
    """
    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        # The pickles are written with 'wb' below, so read them back in binary mode.
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], \
                    "Command line argument and saved model disagree on '%s' " % checkme
        histories_path = os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    # Always bind best_val_score; otherwise the comparison after evaluation
    # raises UnboundLocalError when load_best_score != 1.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create model
    model = models.setup(opt).cuda()

    num_params = get_n_params(model)
    print('number of parameteres:', num_params)

    dp_model = torch.nn.DataParallel(model)
    dp_model.train()

    # Loss function
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    # Optimizer and learning rate adjustment flag
    optimizer = utils.build_optimizer(model.parameters(), opt)
    update_lr_flag = True

    # Load the optimizer
    optimizer_resume_path = None
    if vars(opt).get('start_from', None) is not None:
        optimizer_resume_path = os.path.join(opt.start_from, 'optimizer.pth')
    if optimizer_resume_path is not None and os.path.isfile(optimizer_resume_path):
        optimizer.load_state_dict(torch.load(optimizer_resume_path))

    # Checkpointing/evaluation is skipped before this iteration (was hard-coded).
    save_checkpoint_start = vars(opt).get('save_checkpoint_start', 60000)

    # Training loop
    while True:

        # Update learning rate once per epoch
        if update_lr_flag:
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Scheduled sampling is intentionally disabled for this model.

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['dist'],
            data['masks'], data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, dist_label, masks, att_masks = tmp

        # Forward pass and loss
        optimizer.zero_grad()
        if not sc_flag:
            wordact, reconstruct = dp_model(fc_feats, att_feats, labels)
            # Auxiliary loss: reconstruct the max-pooled attention features.
            fc_feats_max, _ = att_feats.max(1)
            loss_rec = F.mse_loss(reconstruct.cpu(), fc_feats_max.cpu())
            mask = masks[:, 1:].contiguous()
            # wordact appears to be (batch, vocab, time): drop the last step,
            # move vocab last and flatten to (batch * time, vocab). TODO confirm.
            wordact = wordact[:, :, :-1]
            wordact_t = wordact.permute(0, 2, 1).contiguous()
            wordact_t = wordact_t.view(
                wordact_t.size(0) * wordact_t.size(1), -1)
            # NOTE(review): hard-coded label length 6 * 30 — presumably
            # 6 captions x 30 tokens per image; confirm against the loader.
            labels = labels.contiguous().view(-1, 6 * 30).cpu()
            wordclass_v = labels[:, 1:]
            wordclass_t = wordclass_v.contiguous().view(
                wordclass_v.size(0) * wordclass_v.size(1), 1)
            # Only score positions whose mask is non-zero.
            maskids = torch.nonzero(mask.view(-1).cpu()).numpy().reshape(-1)
            loss_xe = F.cross_entropy(
                wordact_t[maskids, ...],
                wordclass_t[maskids, ...].contiguous().view(maskids.shape[0]))
            loss = 5 * loss_xe + loss_rec
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # Backward pass
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        if iteration % opt.print_freq == 1:
            print('Read data:', time.time() - start)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, data_time, total_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, np.mean(reward[:, 0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr

        # Validate and save model
        if iteration >= save_checkpoint_start and iteration % opt.save_checkpoint_every == 0:
            # Save before evaluating so a crash during evaluation loses nothing.
            # (The original saved the identical state a second time afterwards.)
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Evaluate model
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            with open(os.path.join(opt.checkpoint_path,
                                   'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path,
                                   'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best.pth'
                infos_fname = 'model-best.pkl'
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Beispiel #26
0
def train(opt):
    """Adversarially train a paragraph generator and a hybrid discriminator.

    Phases:
      1. Generator pre-training: ``train_generator`` runs every step until
         ``g_epoch >= opt.g_pre_nepoch`` (when that option is >= 0).
      2. Discriminator training: once switched, ``dis_flag`` stays True and
         only ``train_discriminator`` runs, with the generator in eval mode.

    Whenever the active data loader wraps around the training split, the
    active model is evaluated on ``val``, best/latest checkpoints plus
    ``infos.pkl``/``histories.pkl`` are written to ``opt.checkpoint_path``,
    and the per-epoch schedules (learning rates, scheduled sampling, context
    usage) are refreshed at the top of the next iteration.

    Args:
        opt: parsed options namespace. Mutated in place (``vocab_size``,
            ``activity_size``, ``seq_length``, ``video``, ``current_lr``,
            ``ss_prob``).

    Note:
        The training loop has no exit condition; it runs until the process
        is stopped externally.
    """
    if not os.path.exists(opt.checkpoint_path):
        os.mkdir(opt.checkpoint_path)

    with open(os.path.join(opt.checkpoint_path, 'config.json'), 'w') as f:
        json.dump(vars(opt), f)

    # Load iterators: `loader` feeds generator training and validation,
    # `dis_loader` feeds discriminator training.
    loader = DataLoader(opt)
    dis_loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.activity_size = loader.activity_size
    opt.seq_length = loader.seq_length
    opt.video = 1

    # set up models
    gen, dis = models.setup(opt)
    gen_model = gen.cuda()
    gen_model.train()
    dis_model = dis.cuda()
    dis_model.train()
    gen_optimizer = utils.build_optimizer(gen_model.parameters(), opt)
    dis_optimizer = utils.build_optimizer(dis_model.parameters(), opt)

    # loss functions
    crit = utils.LanguageModelCriterion()
    gan_crit = nn.BCELoss().cuda()

    # keep track of iteration
    g_iter = 0
    g_epoch = 0
    d_iter = 0
    d_epoch = 0
    dis_flag = False
    update_lr_flag = True

    # Load from checkpoint path
    infos = {'opt': opt}
    histories = {}
    infos['vocab'] = loader.get_vocab()
    if opt.g_start_from is not None:
        # Open old infos and check if models are compatible. The pickles are
        # written with 'wb' further below, so they must be read back in
        # binary mode (text mode makes pickle.load fail on Python 3).
        with open(os.path.join(opt.g_start_from, 'infos.pkl'), 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        # Load train/val histories
        with open(os.path.join(opt.g_start_from, 'histories.pkl'), 'rb') as f:
            histories = cPickle.load(f)

        # Load generator. g_start_epoch must be a string: an epoch number,
        # or a tag such as "latest" / "best".
        g_start_epoch = opt.g_start_epoch
        g_model_path = os.path.join(opt.g_start_from, "gen_%s.pth" % g_start_epoch)
        g_optimizer_path = os.path.join(opt.g_start_from, "gen_optimizer_%s.pth" % g_start_epoch)
        assert os.path.isfile(g_model_path) and os.path.isfile(g_optimizer_path)
        gen_model.load_state_dict(torch.load(g_model_path))
        gen_optimizer.load_state_dict(torch.load(g_optimizer_path))
        if "latest" not in g_start_epoch and "best" != g_start_epoch:
            g_epoch = int(g_start_epoch) + 1
            g_iter = g_epoch * loader.split_size['train'] // opt.batch_size
        else:
            g_epoch = infos['g_epoch_' + g_start_epoch] + 1
            # 'g_iter_best' was never recorded by older runs; fall back to the
            # iteration count implied by the epoch, as in the numbered case.
            g_iter = infos.get('g_iter_' + g_start_epoch,
                               g_epoch * loader.split_size['train'] // opt.batch_size)
        print('loaded %s (epoch: %d iter: %d)' % (g_model_path, g_epoch, g_iter))

        # Load discriminator
        # assume that discriminator is loaded only if generator has been trained and saved in the same directory.
        if opt.d_start_from is not None:
            d_start_epoch = opt.d_start_epoch
            d_model_path = os.path.join(opt.d_start_from, "dis_%s.pth" % d_start_epoch)
            d_optimizer_path = os.path.join(opt.d_start_from, "dis_optimizer_%s.pth" % d_start_epoch)
            assert os.path.isfile(d_model_path) and os.path.isfile(d_optimizer_path)
            dis_model.load_state_dict(torch.load(d_model_path))
            dis_optimizer.load_state_dict(torch.load(d_optimizer_path))
            if "latest" not in d_start_epoch and "best" != d_start_epoch:
                d_epoch = int(d_start_epoch) + 1
                d_iter = d_epoch * loader.split_size['train'] // opt.batch_size
            else:
                d_epoch = infos['d_epoch_' + d_start_epoch] + 1
                d_iter = infos['d_iter_' + d_start_epoch]
            print('loaded %s (epoch: %d iter: %d)' % (d_model_path, d_epoch, d_iter))
    infos['opt'] = opt
    loader.iterators = infos.get('g_iterators', loader.iterators)
    # Fall back to dis_loader's OWN iterators. Defaulting to loader.iterators
    # made both loaders reference the same iterator object (presumably
    # mutated in place as batches are drawn -- verify against DataLoader).
    dis_loader.iterators = infos.get('d_iterators', dis_loader.iterators)

    # hybrid discriminator weight (a weight of 0 disables that branch)
    v_weight = opt.visual_weight
    l_weight = opt.lang_weight
    p_weight = opt.par_weight

    # misc
    best_val_score = infos.get('g_best_score', None)
    best_d_val_score = infos.get('d_best_score', None)
    g_val_result_history = histories.get('g_val_result_history', {})
    d_val_result_history = histories.get('d_val_result_history', {})
    g_loss_history = histories.get('g_loss_history', {})
    d_loss_history = histories.get('d_loss_history', {})

    """ START TRAINING """
    while True:
        gc.collect()
        # set every epoch
        if update_lr_flag:
            # Assign the learning rate for generator
            if g_epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (g_epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(gen_optimizer, opt.current_lr)

            # Assign the learning rate for discriminator (decay only starts
            # counting once discriminator training is active).
            if dis_flag and d_epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (d_epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(dis_optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if g_epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (g_epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                gen.ss_prob = opt.ss_prob

            # Start using previous sentence as context for generator
            if opt.g_context_epoch >= 0 and g_epoch >= opt.g_context_epoch:
                gen_model.use_context()

            # Switch to training discriminator (one-way: dis_flag never resets)
            if opt.g_pre_nepoch >= 0 and g_epoch >= opt.g_pre_nepoch and not dis_flag:
                print('Switching to pre-training discrimiator...')
                loader.reset_iterator('train')
                dis_loader.reset_iterator('train')
                dis_flag = True

            update_lr_flag = False

        """ TRAIN GENERATOR """
        if not dis_flag:
            gen_model.train()

            # train generator
            start = time.time()
            gen_loss, wrapped, sent_num = train_generator(gen_model, gen_optimizer, crit, loader)
            end = time.time()

            # Print Info
            if g_iter % opt.losses_print_every == 0:
                print("g_iter {} (g_epoch {}), gen_loss = {:.3f}, time/batch = {:.3f}, num_sent = {} {}" \
                    .format(g_iter, g_epoch, gen_loss, end - start, sum(sent_num), sent_num))

            # Log Losses
            if g_iter % opt.losses_log_every == 0:
                g_loss_history[g_iter] = {'g_loss': gen_loss, 'g_epoch': g_epoch}

            # Update the iteration
            g_iter += 1

            #########################
            # Evaluate & Save Model #
            #########################
            if wrapped:
                # evaluate model on dev set
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json,
                               'sample_max' : 1,
                               'language_eval': opt.language_eval,
                               'id' : opt.id,
                               'val_videos_use' : opt.val_videos_use,
                               'remove' : 1} # remove generated caption
                val_loss, predictions, lang_stats, _, _ = eval_split(gen_model, crit, loader, eval_kwargs=eval_kwargs)
                if opt.language_eval == 1:
                    current_score = lang_stats['METEOR']
                else:
                    current_score = - val_loss
                g_val_result_history[g_epoch] = {'g_loss': val_loss, 'g_score': current_score, 'lang_stats': lang_stats}

                # Save the best generator model
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'gen_best.pth')
                    torch.save(gen_optimizer.state_dict(), os.path.join(opt.checkpoint_path, 'gen_optimizer_best.pth'))
                    infos['g_epoch_best'] = g_epoch
                    # Needed to resume with g_start_epoch == "best".
                    infos['g_iter_best'] = g_iter
                    infos['g_best_score'] = best_val_score
                    torch.save(gen_model.state_dict(), checkpoint_path)
                    print("best generator saved to {}".format(checkpoint_path))

                # Dump miscalleous informations and save
                infos['g_epoch_latest'] = g_epoch
                infos['g_iter_latest'] = g_iter
                infos['g_iterators'] = loader.iterators
                histories['g_val_result_history'] = g_val_result_history
                histories['g_loss_history'] = g_loss_history
                with open(os.path.join(opt.checkpoint_path, 'infos.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                # save the latest model
                if opt.save_checkpoint_every > 0 and g_epoch % opt.save_checkpoint_every == 0:
                    torch.save(gen.state_dict(), os.path.join(opt.checkpoint_path, 'gen_%d.pth'% g_epoch))
                    torch.save(gen.state_dict(), os.path.join(opt.checkpoint_path, 'gen_latest.pth'))
                    torch.save(gen_optimizer.state_dict(), os.path.join(opt.checkpoint_path, 'gen_optimizer_%d.pth'% g_epoch))
                    torch.save(gen_optimizer.state_dict(), os.path.join(opt.checkpoint_path, 'gen_optimizer_latest.pth'))
                    print("model saved to {} at epoch {}".format(opt.checkpoint_path, g_epoch))

                # update epoch and lr
                g_epoch += 1
                update_lr_flag = True

        """ TRAIN DISCRIMINATOR """
        if dis_flag:
            dis_model.train()
            gen_model.eval()
            # choose negatives to use for visual discriminator: from epoch 2
            # on, alternate hard and random negatives every other iteration.
            if d_epoch >= 2 and d_iter % 2 == 0:
                dis_loader.set_negatives('hard')
            else:
                dis_loader.set_negatives('random')

            # set temperature (cycled through temp_range when dynamic)
            if opt.dynamic_temperature:
                temp_range = [1.0, 0.8, 0.6, 0.4, 0.2]
                temperature = temp_range[d_iter % len(temp_range)]
            else:
                temperature = opt.train_temperature

            # train discriminator
            start = time.time()
            losses, accuracies, wrapped, sent_num = train_discriminator(dis_model, gen_model, dis_optimizer, gan_crit, dis_loader,
                                                                        temperature=temperature, gen_weight=opt.d_gen_weight, mm_weight=opt.d_mm_weight,
                                                                        use_vis=(v_weight > 0), use_lang=(l_weight > 0), use_pair=(p_weight > 0))
            dis_v_loss, dis_l_loss, dis_p_loss = losses
            end = time.time()

            # Print Info
            if d_iter % opt.losses_print_every == 0:
                print("d_iter {} (d_epoch {}), v_loss = {:.8f}, l_loss = {:.8f}, p_loss={:.8f}, time/batch = {:.3f}, num_sent = {} {}" \
                    .format(d_iter, d_epoch, dis_v_loss, dis_l_loss, dis_p_loss, end - start, sum(sent_num), sent_num))
                print("accuracies:", accuracies)

            # Log Losses
            if d_iter % opt.losses_log_every == 0:
                d_loss_history[d_iter] = {'dis_v_loss': dis_v_loss, 'dis_l_loss': dis_l_loss, 'dis_p_loss': dis_p_loss, 'd_epoch': d_epoch}
                for acc_type, accuracy in accuracies.items():
                    d_loss_history[d_iter][acc_type] = accuracy

            # Update the iteration
            d_iter += 1

            #########################
            # Evaluate & Save Model #
            #########################
            if wrapped:
                # evaluate model on dev set
                eval_kwargs = {'split': 'val',
                               'dataset': opt.input_json,
                               'sample_max' : (d_epoch+1) % 5 != 0,
                               'num_samples' : 30,
                               'temperature' : 0.2,
                               'language_eval' : opt.language_eval,
                               'id' : opt.id,
                               'val_videos_use': opt.val_videos_use,
                               'remove' : 1}
                _, predictions, lang_stats, val_result, _ = eval_split(gen_model, crit, loader, dis_model, gan_crit,
                                                                       eval_kwargs=eval_kwargs)
                d_val_result_history[d_epoch] = val_result

                # save the best discriminator (score = weighted sum of the
                # per-branch accuracies)
                current_d_score = v_weight * (val_result['v_gen_accuracy'] + val_result['v_mm_accuracy']) + \
                                  l_weight * (val_result['l_gen_accuracy'] + val_result['l_neg_accuracy']) + \
                                  p_weight * (val_result['p_gen_accuracy'] + val_result['p_neg_accuracy'])
                if best_d_val_score is None or current_d_score > best_d_val_score:
                    best_d_val_score = current_d_score
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'dis_best.pth')
                    torch.save(dis_model.state_dict(), checkpoint_path)
                    torch.save(dis_optimizer.state_dict(), os.path.join(opt.checkpoint_path, 'dis_optimizer_best.pth'))
                    infos['d_epoch_best'] = d_epoch
                    infos['d_iter_best'] = d_iter
                    infos['d_best_score'] = best_d_val_score
                    print("best discriminator saved to {}".format(checkpoint_path))

                # Dump miscalleous informations
                infos['d_epoch_latest'] = d_epoch
                infos['d_iter_latest'] = d_iter
                infos['d_iterators'] = dis_loader.iterators
                histories['d_loss_history'] = d_loss_history
                histories['d_val_result_history'] = d_val_result_history
                with open(os.path.join(opt.checkpoint_path, 'infos.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                # save model
                if opt.save_checkpoint_every > 0 and d_epoch % opt.save_checkpoint_every == 0:
                    torch.save(dis.state_dict(), os.path.join(opt.checkpoint_path, 'dis_%d.pth'% d_epoch))
                    torch.save(dis.state_dict(), os.path.join(opt.checkpoint_path, 'dis_latest.pth'))
                    torch.save(dis_optimizer.state_dict(), os.path.join(opt.checkpoint_path, 'dis_optimizer_%d.pth'% d_epoch))
                    torch.save(dis_optimizer.state_dict(), os.path.join(opt.checkpoint_path, 'dis_optimizer_latest.pth'))

                # update epoch and lr
                d_epoch += 1
                update_lr_flag = True
Beispiel #27
0
def train(opt):
    """Train a captioning model with cross-entropy/policy-gradient and mBLEU losses.

    The per-step loss is ``loss_ce * (1 - opt.bleu_w) + loss_mb * opt.bleu_w``,
    where ``loss_ce = opt.xe_w * loss_xe + opt.pg_w * loss_pg``. Each term is
    only computed when its weight is non-zero. The model is evaluated on
    ``val`` once before training. It is evaluated again every
    ``opt.save_checkpoint_every`` iterations, and at each of those points the
    model, optimizer, infos and histories are saved three ways:

    * unsuffixed (latest),
    * ``.iter<N>``,
    * ``.best``, on a new best validation score (CIDEr if
      ``opt.language_eval == 1``, else negative validation loss).

    Args:
        opt: parsed options namespace; mutated in place (``use_att``,
            ``vocab_size``, ``seq_length``, ``current_lr``, ``ss_prob``,
            ``current_teach_mask_prefix_length``).

    Training stops once ``epoch >= opt.max_epochs`` unless
    ``opt.max_epochs == -1``.
    """
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    def ids_to_sents(ids):
        """Decode a batch of token ids into sentences (debug printing)."""
        return utils.decode_sequence(loader, ids)

    # `tf` may be None (logging disabled); every summary write below is
    # guarded by `tf is not None`.
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    def load_infos(directory=opt.start_from, suffix=''):
        """Load ``infos_<id><suffix>.pkl`` and check model compatibility."""
        path = os.path.join(directory, 'infos_{}{}.pkl'.format(opt.id, suffix))
        # Pickles are written with 'wb' below; read them back in binary mode
        # (text mode makes pickle.load fail on Python 3).
        with open(path, 'rb') as f:
            infos = cPickle.load(f)
        saved_model_opt = infos['opt']
        need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
        for checkme in need_be_same:
            assert getattr(saved_model_opt, checkme) == getattr(
                opt, checkme
            ), "Command line argument and saved model disagree on '%s'" % checkme
        return infos

    def load_histories(directory=opt.start_from, suffix=''):
        """Load ``histories_<id><suffix>.pkl``, or return {} if it is missing.

        (A missing file used to raise UnboundLocalError.)
        """
        path = os.path.join(directory, 'histories_{}{}.pkl'.format(opt.id, suffix))
        if not os.path.isfile(path):
            return {}
        with open(path, 'rb') as f:
            return cPickle.load(f)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        infos = load_infos()
        histories = load_histories()

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Always bind best_val_score: with load_best_score != 1 it used to be
    # left undefined and the first checkpoint raised NameError.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit_ce = utils.LanguageModelCriterion()
    crit_mb = mBLEU(4)

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    def eval_model():
        """Evaluate on the val split, log to TensorBoard, restore train mode."""
        model.eval()

        eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
        eval_kwargs.update(vars(opt))
        val_loss, predictions, lang_stats = eval_utils.eval_split(
            model, crit_ce, loader, eval_kwargs)

        # Write validation result into summary
        if tf is not None:
            add_summary_value(tf_summary_writer, 'validation loss', val_loss,
                              iteration)
            for k, v in lang_stats.items():
                add_summary_value(tf_summary_writer, k, v, iteration)
            tf_summary_writer.flush()

        model.train()

        return val_loss, predictions, lang_stats

    def save_model(suffix=''):
        """Save model, optimizer, infos and histories under ``suffix``."""
        model_path = os.path.join(opt.checkpoint_path,
                                  'model{}.pth'.format(suffix))
        torch.save(model.state_dict(), model_path)
        print("model saved to {}".format(model_path))
        optimizer_path = os.path.join(
            opt.checkpoint_path, 'optimizer{}.pth'.format(suffix))
        torch.save(optimizer.state_dict(), optimizer_path)

        with open(
                os.path.join(opt.checkpoint_path,
                             'infos_{}{}.pkl'.format(opt.id, suffix)),
                'wb') as f:
            cPickle.dump(infos, f)
        with open(
                os.path.join(opt.checkpoint_path,
                             'histories_{}{}.pkl'.format(opt.id, suffix)),
                'wb') as f:
            cPickle.dump(histories, f)

    eval_model()

    opt.current_teach_mask_prefix_length = opt.teach_mask_prefix_length

    while True:
        if update_lr_flag:
            # Assign the learning rate. NOTE(review): set_lr is only called
            # in the decay branch, so the optimizer keeps whatever lr it was
            # created/loaded with otherwise.
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            # Assign the teach mask prefix length
            if epoch > opt.teach_mask_prefix_length_increase_start:
                frac = (epoch - opt.teach_mask_prefix_length_increase_start
                        ) // opt.teach_mask_prefix_length_increase_every
                opt.current_teach_mask_prefix_length = opt.teach_mask_prefix_length + frac * opt.teach_mask_prefix_length_increase_steps
            update_lr_flag = False

        verbose = (iteration % opt.verbose_iters == 0)
        # Explicit sentinel (instead of probing locals()) for the probability
        # tensor whose gradients are dumped on verbose iterations.
        verbose_probs = None

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        if iteration % opt.print_iters == 0:
            print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        teach_mask = utils.make_teach_mask(labels.size(1), opt)
        enable_ce = (opt.bleu_w != 1)
        enable_mb = (opt.bleu_w != 0)
        if enable_ce:
            enable_xe = (opt.xe_w != 0)
            enable_pg = (opt.pg_w != 0)
            if enable_xe:
                logits = model(
                    fc_feats,
                    att_feats,
                    labels,
                    teach_mask=(teach_mask if opt.teach_ce
                                and not opt.teach_all_input else None))
                if opt.teach_ce:
                    decode_length = logits.shape[1] + 1
                    teach_mask = teach_mask[:decode_length]
                    onehot = utils.to_onehot(labels[:, :decode_length],
                                             logits.shape[-1],
                                             dtype=torch.float)
                    probs = torch.exp(logits)
                    probs = torch.cat([onehot[:, :1], probs], 1)
                    probs = utils.mask_probs(probs, onehot, teach_mask)
                    if verbose:
                        verbose_probs = probs
                        verbose_probs.retain_grad()
                    # Back to log space; the (1 - 1e-6) factor keeps log() off 0.
                    logits = torch.log(1. - (1. - 1e-6) * (1. - probs))[:, 1:]
                loss_xe = crit_ce(logits, labels[:, 1:], masks[:, 1:])
            else:
                loss_xe = 0.
            if enable_pg:
                # Self-critical policy gradient: reward = BLEU(sample) - BLEU(greedy).
                ids_sample, logprobs_sample = model.sample(
                    fc_feats, att_feats, opt={'sample_max': 0})
                ids_greedy, logprobs_greedy = model.sample(
                    fc_feats, att_feats, opt={'sample_max': 1})
                seq_sample = utils.tolist(ids_sample)
                seq_greedy = utils.tolist(ids_greedy)
                seq_target = utils.tolist(labels[:, 1:])
                rewards = [
                    sentence_bleu([t], s, smooth=True) -
                    sentence_bleu([t], g, smooth=True)
                    for s, g, t in zip(seq_sample, seq_greedy, seq_target)
                ]
                rewards = torch.tensor(rewards, device='cuda')
                mask_sample = torch.ne(ids_sample,
                                       torch.tensor(0, device='cuda')).float()
                loss_pg = (rewards *
                           (logprobs_sample * mask_sample).sum(1)).mean()
            else:
                loss_pg = 0.
            loss_ce = opt.xe_w * loss_xe + opt.pg_w * loss_pg
        else:
            loss_ce = 0.
        if enable_mb:
            logits = model(
                fc_feats,
                att_feats,
                labels,
                teach_mask=(teach_mask if not opt.teach_all_input else None))
            decode_length = logits.shape[1] + 1
            teach_mask = teach_mask[:decode_length]
            onehot = utils.to_onehot(labels[:, :decode_length],
                                     logits.shape[-1],
                                     dtype=torch.float)
            probs = torch.exp(logits)
            probs = torch.cat([onehot[:, :1], probs], 1)  # pad bos
            probs = utils.mask_probs(probs, onehot, teach_mask)
            if verbose:
                verbose_probs = probs
                verbose_probs.retain_grad()
            mask = masks[:, :decode_length]
            mask = torch.cat([mask[:, :1], mask], 1)
            loss_mb = crit_mb(probs,
                              labels[:, :decode_length],
                              mask,
                              min_fn=opt.min_fn,
                              min_c=opt.min_c,
                              verbose=verbose)
        else:
            loss_mb = 0.
        loss = loss_ce * (1 - opt.bleu_w) + loss_mb * opt.bleu_w
        loss.backward()
        utils.clip_gradient(
            optimizer,
            opt.grad_clip)  #TODO: examine clip method and record grad

        if verbose and verbose_probs is not None:
            # Show, per position, the tokens with the most negative gradient.
            max_grads, max_ids = verbose_probs.grad.topk(opt.verbose_topk,
                                                         -1,
                                                         largest=False)
            max_probs = torch.gather(verbose_probs, -1, max_ids)
            max_sents = ids_to_sents(max_ids[:, :, 0])
            for sample_i in range(min(opt.samples, verbose_probs.shape[0])):
                n_show = len(max_sents[sample_i]) + 1
                print('max:\n{}'.format(max_sents[sample_i]))
                print('max probs:\n{}'.format(max_probs[sample_i][:n_show]))
                print('max grads:\n{}'.format(max_grads[sample_i][:n_show]))

        optimizer.step()
        train_loss = float(loss)
        torch.cuda.synchronize()
        end = time.time()
        if iteration % opt.print_iters == 0:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if iteration % opt.save_checkpoint_every == 0:
            val_loss, predictions, lang_stats = eval_model()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            # Dump miscalleous informations
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history

            save_model()
            save_model(".iter{}".format(iteration))

            if best_flag:
                save_model(".best")

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Beispiel #28
0
def train(opt):
    """Train a captioning model jointly with an auxiliary ``mods`` ("cw") loss.

    The loop runs over ``loader.get_batch(opt.train_split)`` until
    ``epoch >= opt.max_epochs`` (``opt.max_epochs == -1`` never stops).
    On each pass it:

    * switches once to the RL learning-rate settings
      (``learning_rate_rl`` / ``learning_rate_decay_rate_rl``) when
      ``epoch >= opt.self_critical_after`` (``-1`` disables this);
    * at the start of every epoch recomputes the learning rate, the
      scheduled-sampling probability and the self-critical flag;
    * optimises ``loss_lan + sim_lambda * loss_cw``, or only ``loss_lan``
      once ``epoch >= opt.step2_train_after``, accumulating gradients over
      ``opt.accumulate_number`` batches per optimizer step;
    * logs to TensorBoard every ``opt.losses_log_every`` steps and writes a
      numbered checkpoint every ``opt.save_checkpoint_every`` steps.

    Numbered checkpoint files live in ``opt.checkpoint_path`` and are named
    ``<prefix><opt.id><NNNN><ext>``. ``NNNN`` is the zero-padded index
    ``iteration // opt.save_checkpoint_every``. ``opt.start_from`` selects
    the index to resume from.
    """
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    ac = 0

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    def ckpt_file(prefix, save_id, ext):
        # Numbered checkpoint artefact: <checkpoint_path>/<prefix><id><NNNN><ext>
        return os.path.join(
            opt.checkpoint_path,
            prefix + opt.id + format(int(save_id), '04') + ext)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check if models are compatible.
        # The pickles are written with 'wb' further down, so they must be read
        # back in binary mode (text mode makes pickle.load fail on Python 3).
        with open(ckpt_file('infos_', opt.start_from, '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = ckpt_file('histories_', opt.start_from, '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # The checkpoint code reads best_val_score unconditionally, so it must
    # always be bound. Previously it raised NameError at the first save when
    # load_best_score != 1.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()

    # DataParallel is disabled; dp_model is simply an alias of model.
    #dp_model = torch.nn.DataParallel(model)
    #dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    for name, _ in model.named_parameters():
        print(name)

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    CE_ac = utils.CE_ac()

    optimizer = utils.build_optimizer(model.parameters(), opt)

    # Load the optimizer state saved alongside the checkpoint we resume from.
    if vars(opt).get('start_from', None) is not None:
        optimizer_resume_path = ckpt_file('optimizer', opt.start_from, '.pth')
        if os.path.isfile(optimizer_resume_path):
            optimizer.load_state_dict(torch.load(optimizer_resume_path))

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    train_loss_lan = 0
    train_loss_cw = 0
    reward = np.zeros([1, 1])
    sim_lambda = opt.sim_lambda
    reset_optimzer_index = 1
    while True:
        # One-time switch to the RL learning-rate schedule when self-critical
        # training begins.
        if opt.self_critical_after != -1 and epoch >= opt.self_critical_after and reset_optimzer_index:
            opt.learning_rate_decay_start = opt.self_critical_after
            opt.learning_rate_decay_rate = opt.learning_rate_decay_rate_rl
            opt.learning_rate = opt.learning_rate_rl
            reset_optimzer_index = 0

        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['labels'], data['masks'], data['mods']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        labels, masks, mods = tmp

        tmp = [data['att_feats'], data['att_masks'], data['attr_feats'], data['attr_masks'], data['rela_feats'], data['rela_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        att_feats, att_masks, attr_feats, attr_masks, rela_feats, rela_masks = tmp

        rs_data = {
            'att_feats': att_feats,
            'att_masks': att_masks,
            'attr_feats': attr_feats,
            'attr_masks': attr_masks,
            'rela_feats': rela_feats,
            'rela_masks': rela_masks,
        }

        if not sc_flag:
            logits, cw_logits = dp_model(rs_data, labels)
            ac = CE_ac(logits, labels[:, 1:], masks[:, 1:])
            print('ac :{0}'.format(ac))
            loss_lan = crit(logits, labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs, cw_logits = dp_model(rs_data,
                                                   opt={'sample_max': 0}, mode='sample')
            reward = get_self_critical_reward(dp_model, rs_data, data, gen_result, opt)
            loss_lan = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda())

        loss_cw = crit(cw_logits, mods[:, 1:], masks[:, 1:])
        ac2 = CE_ac(cw_logits, mods[:, 1:], masks[:, 1:])
        print('ac :{0}'.format(ac2))
        if epoch < opt.step2_train_after:
            loss = loss_lan + sim_lambda * loss_cw
        else:
            loss = loss_lan

        # Gradient accumulation: scale each batch's loss and only step the
        # optimizer every accumulate_number batches.
        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            train_loss_lan = loss_lan.item()
            train_loss_cw = loss_cw.item()
            end = time.time()

            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
                print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \
                      .format(train_loss_lan, train_loss_cw))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), end - start))
                print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \
                      .format(train_loss_lan, train_loss_cw))
            print('lr:{0}'.format(opt.current_lr))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary (only right after an optimizer step)
        if (iteration % opt.losses_log_every == 0) and (accumulate_iter % opt.accumulate_number == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_lan', train_loss_lan, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_cw', train_loss_cw, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            add_summary_value(tb_summary_writer, 'ac', ac, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Save a numbered checkpoint (validation is currently disabled)
        if (iteration % opt.save_checkpoint_every == 0) and (accumulate_iter % opt.accumulate_number == 0):
            # Validation is disabled, so every checkpoint scores 0. As a
            # result, model-best is only rewritten while best_val_score is
            # None or negative. To re-enable validation:
            # eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            # eval_kwargs.update(vars(opt))
            # val_loss, predictions, lang_stats = eval_utils_rs3.eval_split(dp_model, crit, loader, eval_kwargs)
            # add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            # if lang_stats is not None:
            #     for k,v in lang_stats.items():
            #         add_summary_value(tb_summary_writer, k, v, iteration)
            # val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}
            # current_score = lang_stats['CIDEr'] if opt.language_eval == 1 else - val_loss
            current_score = 0

            # Save model if is improving on validation result
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True

            save_id = iteration // opt.save_checkpoint_every
            checkpoint_path = ckpt_file('model', save_id, '.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            torch.save(optimizer.state_dict(), ckpt_file('optimizer', save_id, '.pth'))

            # Dump miscalleous informations
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(ckpt_file('infos_', save_id, '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(ckpt_file('histories_', save_id, '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Beispiel #29
0
            torch.load(os.path.join('models', 'model-1.ckpt')))

    current_lr = 1e-3
    optimizer = optim.Adam(model.parameters(), lr=current_lr)

    # criterion = nn.BCEWithLogitsLoss()
    criterion = FocalLoss()
    # criterion = nn.BCELoss()
    logger = Logger('./logs/')

    for epoch in range(2, opt.num_epoches):
        # schedule learning rate
        frac = epoch // 2
        decay_factor = 0.9**frac
        current_lr = current_lr * decay_factor
        utils.set_lr(optimizer, current_lr)

        # training
        model.train()
        start = time.time()

        for i, data in enumerate(train_loader):
            # prepare data and corresponding label(which is 'click')
            user_id = data['user_id'].cuda()
            hour = data['hour'].cuda()
            visual = data['visual'].cuda()
            click = data['click'].cuda()

            scale = data['scale'].cuda()
            gender = data['gender'].cuda().squeeze(1)
            age = data['age'].cuda().squeeze(1)
Beispiel #30
0
def train(opt):
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    time_history = histories.get('time_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = model

    target_actor = models.setup(opt).cuda()

    ####################### Critic pretrain #####################################################################
    ##### Critic with state as input
    # if opt.critic_model == 'state_critic':
    #     critic_model = CriticModel(opt)
    # else:
    critic_model = AttCriticModel(opt)
    target_critic = AttCriticModel(opt)
    if vars(opt).get('start_from_critic', None) is not None and True:
        # check if all necessary files exist
        assert os.path.isdir(opt.start_from_critic
                             ), " %s must be a a path" % opt.start_from_critic
        print(
            os.path.join(opt.start_from_critic,
                         opt.critic_model + '_model.pth'))
        critic_model.load_state_dict(
            torch.load(
                os.path.join(opt.start_from_critic,
                             opt.critic_model + '_model.pth')))
        target_critic.load_state_dict(
            torch.load(
                os.path.join(opt.start_from_critic,
                             opt.critic_model + '_model.pth')))
    critic_model = critic_model.cuda()
    target_critic = target_critic.cuda()
    critic_optimizer = utils.build_optimizer(critic_model.parameters(), opt)
    dp_model.eval()
    critic_iter = 0
    init_scorer(opt.cached_tokens)
    critic_model.train()
    error_sum = 0
    loss_vector_sum = 0
    while opt.pretrain_critic == 1:
        if critic_iter > opt.pretrain_critic_steps:
            print('****************Finished critic training!')
            break
        data = loader.get_batch('train')
        torch.cuda.synchronize()
        start = time.time()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        critic_model.train()
        critic_optimizer.zero_grad()
        assert opt.critic_model == 'att_critic_vocab'
        # crit_loss, reward, std = critic_loss_fun(fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data)
        crit_loss, reward, std = target_critic_loss_fun_mask(
            fc_feats, att_feats, att_masks, dp_model, critic_model, opt, data,
            target_critic, target_actor)
        crit_loss.backward()
        critic_optimizer.step()
        #TODO update target.
        for cp, tp in zip(critic_model.parameters(),
                          target_critic.parameters()):
            tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data)
        crit_train_loss = crit_loss.item()
        torch.cuda.synchronize
        end = time.time()
        error_sum += crit_train_loss**0.5 - std
        if (critic_iter % opt.losses_log_every == 0):
            print("iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f}, time/batch = {:.3f}" \
                .format(critic_iter, crit_train_loss**0.5, crit_train_loss**0.5-std, error_sum, end - start))
            print(opt.checkpoint_path)
            opt.importance_sampling = 1
            critic_model.eval()
            _, _, _, _ = get_rf_loss(dp_model,
                                     fc_feats,
                                     att_feats,
                                     att_masks,
                                     data,
                                     opt,
                                     loader,
                                     critic_model,
                                     test_critic=True)

        critic_iter += 1

        # make evaluation on validation set, and save model
        if (critic_iter % opt.save_checkpoint_every == 0):
            if not os.path.isdir(opt.checkpoint_path):
                os.mkdir(opt.checkpoint_path)
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           opt.critic_model + '_model.pth')
            torch.save(critic_model.state_dict(), checkpoint_path)

    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    first_order = 0
    second_order = 0
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')
        if data['bounds']['it_pos_now'] > 5000:
            loader.reset_iterator('train')
            continue
        dp_model.train()
        critic_model.eval()

        torch.cuda.synchronize()
        start = time.time()
        gen_result = None
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats,
                                                  att_feats, att_masks, data,
                                                  gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats,
                                                       att_feats,
                                                       att_masks,
                                                       opt={'sample_max': 0},
                                                       mode='sample')
                reward = get_reward(data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())
            elif opt.rl_type == 'arsm':
                loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks,
                                    data, opt, loader)
                #print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'rf4':
                loss, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats,
                                            att_masks, data, opt, loader)
                # print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'importance_sampling':
                opt.importance_sampling = 1
                loss, gen_result, reward, sample_logprobs_total = get_rf_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1],
                                   1)
                std = np.std(reward)
            elif opt.rl_type == 'importance_sampling_critic':
                opt.importance_sampling = 1
                loss, gen_result, reward, sample_logprobs_total = get_rf_loss(
                    target_actor, fc_feats, att_feats, att_masks, data, opt,
                    loader, target_critic)
                reward = np.repeat(reward[:, np.newaxis], gen_result.shape[1],
                                   1)
                std = np.std(reward)
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks,
                                   data, opt, loader)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(
                    sample_logprobs, gen_result.data,
                    torch.from_numpy(reward).float().cuda() - mct_baseline)
            elif opt.rl_type == 'arsm_baseline':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0]
                if opt.arm_step_sample == 'greedy' and False:
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data,
                               reward_cuda - arm_baseline)
            elif opt.rl_type == 'ars_indicator':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader)
                reward = get_self_critical_reward(dp_model, fc_feats,
                                                  att_feats, att_masks, data,
                                                  gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data,
                               reward_cuda * arm_baseline)
            elif opt.rl_type == 'arsm_baseline_critic':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(
                    dp_model, fc_feats, att_feats, att_masks, data, opt,
                    loader, critic_model)
                reward, std = get_reward(data, gen_result, opt, critic=True)
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(
                    sample_logprobs, gen_result.data,
                    torch.from_numpy(reward).float().cuda() - arm_baseline)
            elif opt.rl_type == 'arsm_critic':
                #print(opt.critic_model)
                tic = time.time()
                loss = get_arm_loss(dp_model, fc_feats, att_feats, att_masks,
                                    data, opt, loader, critic_model)
                #print('arm_loss time', str(time.time()-tic))
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'critic_vocab_sum':
                assert opt.critic_model == 'att_critic_vocab'
                tic = time.time()
                gen_result, sample_logprobs_total = dp_model(
                    fc_feats,
                    att_feats,
                    att_masks,
                    opt={'sample_max': 0},
                    total_probs=True,
                    mode='sample')  #batch, seq, vocab
                #print('generation time', time.time()-tic)
                gen_result_pad = torch.cat([
                    gen_result.new_zeros(
                        gen_result.size(0), 1, dtype=torch.long), gen_result
                ], 1)
                tic = time.time()
                critic_value = critic_model(gen_result_pad, fc_feats,
                                            att_feats, True, opt,
                                            att_masks)  #batch, seq, vocab
                #print('critic time', time.time() - tic)
                probs = torch.sum(
                    F.softmax(sample_logprobs_total, 2) *
                    critic_value.detach(), 2)
                mask = (gen_result > 0).float()
                mask = torch.cat(
                    [mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1)
                loss = -torch.sum(probs * mask) / torch.sum(mask)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'reinforce_critic':
                #TODO change the critic to attention
                if opt.critic_model == 'state_critic':
                    critic_value, gen_result, sample_logprobs = critic_model(
                        dp_model, fc_feats, att_feats, opt, att_masks)
                    reward, std = get_reward(data,
                                             gen_result,
                                             opt,
                                             critic=True)
                    loss = rl_crit(
                        sample_logprobs, gen_result.data,
                        torch.from_numpy(reward).float().cuda() -
                        critic_value[:, :-1].data)
                elif opt.critic_model == 'att_critic':
                    gen_result, sample_logprobs = dp_model(
                        fc_feats,
                        att_feats,
                        att_masks,
                        opt={'sample_max': 0},
                        mode='sample')
                    gen_result_pad = torch.cat([
                        gen_result.new_zeros(gen_result.size(0),
                                             1,
                                             dtype=torch.long), gen_result
                    ], 1)
                    critic_value = critic_model(gen_result_pad, fc_feats,
                                                att_feats, True, opt,
                                                att_masks).squeeze(2)

                    reward, std = get_reward(data,
                                             gen_result,
                                             opt,
                                             critic=True)
                    loss = rl_crit(
                        sample_logprobs, gen_result.data,
                        torch.from_numpy(reward).float().cuda() -
                        critic_value.data)
        if opt.mle_weights != 0:
            loss += opt.mle_weights * crit(
                dp_model(fc_feats, att_feats, labels, att_masks),
                labels[:, 1:], masks[:, 1:])
        #TODO make sure all sampling replaced by greedy for critic
        #### update the actor
        loss.backward()
        # with open(os.path.join(opt.checkpoint_path, 'best_embed.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f)
        # with open(os.path.join(opt.checkpoint_path, 'best_logit.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.logit.parameters())[0].data.cpu().numpy(), f)
        ## compute variance
        gradient = torch.zeros([0]).cuda()
        for i in model.parameters():
            gradient = torch.cat((gradient, i.grad.view(-1)), 0)
        first_order = 0.9999 * first_order + 0.0001 * gradient
        second_order = 0.9999 * second_order + 0.0001 * gradient.pow(2)
        # print(torch.max(torch.abs(gradient)))
        variance = torch.mean(torch.abs(second_order -
                                        first_order.pow(2))).item()
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # ### update the critic
        if 'critic' in opt.rl_type:
            dp_model.eval()
            critic_model.train()
            utils.set_lr(critic_optimizer, opt.critic_learning_rate)
            critic_optimizer.zero_grad()
            assert opt.critic_model == 'att_critic_vocab'
            crit_loss, reward, std = target_critic_loss_fun_mask(
                fc_feats,
                att_feats,
                att_masks,
                dp_model,
                critic_model,
                opt,
                data,
                target_critic,
                target_actor,
                gen_result=gen_result,
                sample_logprobs_total=sample_logprobs_total,
                reward=reward)
            crit_loss.backward()
            critic_optimizer.step()
            for cp, tp in zip(critic_model.parameters(),
                              target_critic.parameters()):
                tp.data = tp.data + opt.gamma_critic * (cp.data - tp.data)
            for cp, tp in zip(dp_model.parameters(),
                              target_actor.parameters()):
                tp.data = tp.data + opt.gamma_actor * (cp.data - tp.data)
            crit_train_loss = crit_loss.item()
            error_sum += crit_train_loss**0.5 - std
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if (iteration % opt.losses_log_every == 0):
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            elif 'critic' in opt.rl_type:
                print(
                    "iter {} , crit_train_loss = {:.3f}, difference = {:.3f}, difference_sum = {:.3f},variance = {:g}, time/batch = {:.3f}" \
                    .format(iteration, crit_train_loss ** 0.5, crit_train_loss ** 0.5 - std, error_sum, variance, end - start))
                print(opt.checkpoint_path)
                critic_model.eval()
                _, _, _, _ = get_rf_loss(dp_model,
                                         fc_feats,
                                         att_feats,
                                         att_masks,
                                         data,
                                         opt,
                                         loader,
                                         critic_model,
                                         test_critic=True)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance,
                                  iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward)
            critic_loss_history[
                iteration] = crit_train_loss if 'critic' in opt.rl_type else 0
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            time_history[iteration] = end - start

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               opt.critic_model + '_model.pth')
                torch.save(critic_model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscalleous informations
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['critic_loss_history'] = critic_loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                histories['variance_history'] = variance_history
                histories['time'] = time_history
                # histories['variance'] = 0
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Beispiel #31
0
def train(opt):
    """Train a binary-coded captioning model, optionally with RL fine-tuning.

    Runs an endless loop over ``loader.get_batch('train')`` that:
      * refreshes the learning rate, the scheduled-sampling probability and
        the self-critical flag whenever a new epoch starts;
      * computes either the binary MLE loss or one of several RL losses
        selected by ``opt.rl_type`` (once ``opt.self_critical_after`` is
        reached);
      * keeps exponential moving averages of the flattened gradient and its
        square, and logs their gap as a gradient-variance estimate;
      * periodically logs to stdout / TensorBoard, evaluates on the ``val``
        split, and writes model / optimizer / infos / histories files to
        ``opt.checkpoint_path``.

    The loop exits once ``epoch >= opt.max_epochs`` (``max_epochs == -1``
    means train forever).

    Args:
        opt: option namespace, read via attributes and ``vars(opt)``. It is
            mutated in place: ``use_att``, ``vocab_size``, ``seq_length``,
            ``current_lr``, possibly ``att_feat_size`` and ``ss_prob``, and,
            for some ``rl_type`` values, ``rf_demean`` / ``arm_as_baseline``.

    Raises:
        ValueError: if self-critical training is active and ``opt.rl_type``
            is not one of the supported values.
    """
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # Open old infos and check that the saved model is compatible.
        # The pickles are written with 'wb' below, so read them in binary
        # mode too (text mode breaks cPickle.load on Python 3).
        # NOTE: pickle can execute arbitrary code; only resume from trusted
        # checkpoint directories.
        infos_path = os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')
        with open(infos_path, 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    # The time log is saved under the key 'time' (see the checkpoint code
    # below); accept both spellings so resuming keeps the timing history.
    time_history = histories.get('time_history', histories.get('time', {}))

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    # Must always be bound: it is compared at the first checkpoint even when
    # the previous best score is not restored.
    best_val_score = None
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = model

    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    # TODO: change this to a flag
    crit = utils.LanguageModelCriterion_binary()
    rl_crit = utils.RewardCriterion_binary()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer state when resuming, if it was saved.
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Exponential moving averages of the flattened gradient and of its
    # square; |E[g^2] - E[g]^2| is logged as a gradient-variance estimate.
    first_order = 0
    second_order = 0
    while True:
        if update_lr_flag:
            # Assign the learning rate (step decay after decay_start).
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')
        # Once the loader position passes 10000 the iterator is reset and
        # this batch is skipped. NOTE(review): this appears to cap how much
        # of the train split is used; confirm it is intentional.
        if data['bounds']['it_pos_now'] > 10000:
            loader.reset_iterator('train')
            continue
        dp_model.train()

        torch.cuda.synchronize()
        start = time.time()
        batch = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        batch = [t if t is None else torch.from_numpy(t).cuda() for t in batch]
        fc_feats, att_feats, labels, masks, att_masks = batch
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:], dp_model.depth,
                        dp_model.vocab2code, dp_model.phi_list, dp_model.cluster_size)
        else:
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth)
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max': 0}, mode='sample')
                reward = get_reward(data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth)
            elif opt.rl_type == 'arm':
                loss = dp_model.get_arm_loss_binary_fast(fc_feats, att_feats, att_masks, opt, data, loader)
                # Placeholder so the reward logging below still works.
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'rf4':
                loss, _, _, _ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data,
                                                                                 opt, loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                # Negative baseline entries fall back to the sampled reward.
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                # NOTE(review): unlike 'sc'/'reinforce', the baseline branches
                # do not pass dp_model.depth to the binary reward criterion.
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - mct_baseline)
            elif opt.rl_type == 'arsm_baseline':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0]
                if opt.arm_step_sample == 'greedy' and False:  # deliberately disabled
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - arm_baseline)
            elif opt.rl_type == 'ars_indicator':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda * arm_baseline)
            else:
                # Without this, `loss` would be unbound (or a stale tensor
                # from a previous iteration whose graph is already freed).
                raise ValueError("Unsupported rl_type for self-critical training: {}".format(opt.rl_type))
        if opt.mle_weights != 0:
            # Same binary criterion call signature as the MLE branch above.
            loss += opt.mle_weights * crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:],
                                           dp_model.depth, dp_model.vocab2code, dp_model.phi_list, dp_model.cluster_size)
        # TODO make sure all sampling replaced by greedy for critic
        #### update the actor
        loss.backward()
        ## compute variance
        # Flatten all parameter gradients with a single concatenation
        # (repeated torch.cat in a loop copies the growing tensor each time).
        gradient = torch.cat([param.grad.view(-1) for param in model.parameters()], 0)
        first_order = 0.999 * first_order + 0.001 * gradient
        second_order = 0.999 * second_order + 0.001 * gradient.pow(2)
        variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if iteration % opt.losses_log_every == 0:
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                      .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}"
                      .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance, iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward)
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            time_history[iteration] = end - start

        # make evaluation on validation set, and save model
        if iteration % opt.save_checkpoint_every == 0:
            # eval model
            eval_kwargs = {'split': 'val',
                           'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils_binary.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            if not os.path.isdir(opt.checkpoint_path):
                os.mkdir(opt.checkpoint_path)
            # Only the actor model exists in this routine; report the path
            # that was actually written.
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information needed to resume training.
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['critic_loss_history'] = critic_loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            histories['variance_history'] = variance_history
            histories['time'] = time_history
            with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, 'infos_' + opt.id + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break