Example #1
def train(loader, model, crit, optimizer, lr_scheduler, opt, rl_crit=None):
    model.train()
    model = nn.DataParallel(model)
    for epoch in range(opt["epochs"]):
        lr_scheduler.step()
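        # Note: stepping the scheduler at the start of the epoch follows the
        # pre-1.1 PyTorch convention; from 1.1 onward optimizer.step() should be
        # called before lr_scheduler.step().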

        iteration = 0
        # Start self-critical training once the configured epoch is reached
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for data in loader:
            torch.cuda.synchronize()
            fc_feats = Variable(data['fc_feats']).cuda()
            labels = Variable(data['labels']).long().cuda()
            masks = Variable(data['masks']).cuda()

            optimizer.zero_grad()
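            # Two training modes: plain cross-entropy against the ground-truth
            # captions, or (once sc_flag is set) self-critical training where
            # sampled captions are scored with the CIDEr scorer initialised above.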
            if not sc_flag:
                seq_probs, _ = model(fc_feats, labels, 'train')
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            else:
                seq_probs, seq_preds = model(fc_feats,
                                             mode='inference',
                                             opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  seq_preds)
                print(reward.shape)
                loss = rl_crit(
                    seq_probs, seq_preds,
                    Variable(torch.from_numpy(reward).float().cuda()))

            loss.backward()
            utils.clip_gradient(optimizer, opt["grad_clip"])
            optimizer.step()
            train_loss = loss.data[0]
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch, np.mean(reward[:, 0])))

        if epoch != 0 and epoch % opt["save_checkpoint_every"] == 0:
            model_path = os.path.join(opt["checkpoint_path"],
                                      'model_%d.pth' % (epoch))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'model_score.txt')
            torch.save(model.state_dict(), model_path)
            print("model saved to %s" % (model_path))
            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))
Example #2
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Ensure the model is in training mode
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
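        # One pass of this loop processes a single mini-batch; the epoch counter
        # advances when the data loader wraps around, at which point the learning
        # rate, scheduled sampling probability and self-critical flag are refreshed.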
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
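            # (ss_prob is the probability of feeding the model its own previous
            # prediction instead of the ground-truth token during training; it is
            # ramped up on a schedule.)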
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # Start self-critical training once the configured epoch is reached
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_cider_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                        masks[:, 1:])
        else:
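            # Self-critical branch: {'sample_max': 0} draws captions from the
            # softmax, get_self_critical_reward scores them (typically CIDEr of
            # the sample against a greedy-decoded baseline), and rl_crit weights
            # the sampled log-probabilities by that reward.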
            gen_result, sample_logprobs = model.sample(fc_feats, att_feats,
                                                       {'sample_max': 0})
            reward = get_self_critical_reward(model, fc_feats, att_feats, data,
                                              gen_result)
            loss = rl_crit(
                sample_logprobs, gen_result,
                Variable(torch.from_numpy(reward).float().cuda(),
                         requires_grad=False))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it improves on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Example #3
def train(dataset,
          loader,
          model,
          rem,
          crit,
          optimizer,
          lr_scheduler,
          opt,
          rl_crit=None):
    writer = SummaryWriter('./runs/video_caption22')
    model.load_state_dict(
        torch.load(
            '/home/diml/video-caption.pytorch/save/RECON222_model_200.pth'))
    rem.load_state_dict(
        torch.load(
            '/home/diml/video-caption.pytorch/save/RECON222_module_200.pth'))
    #model.load_state_dict(torch.load('/home/diml/video-caption.pytorch/save/new_model_200.pth'))
    #model = nn.DataParallel(model)
    model.train()
    rem.train()

    vocab = dataset.get_vocab()

    for epoch in trange(opt["epochs"]):
        t_loss = [0, 0, 0]
        # =============================================================================
        #         model.eval()
        #         ev.demov(model,crit, dataset, dataset.get_vocab(),opt)
        # =============================================================================

        lr_scheduler.step()
        iteration = 0

        # Start self-critical training once the configured epoch is reached
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for idx, data in enumerate(loader):
            torch.cuda.synchronize()
            fc_feats = data['fc_feats'].cuda()
            labels = data['labels'].cuda()
            labels2 = data['labels2'].cuda()
            masks2 = data['masks2'].cuda()
            masks = data['masks'].cuda()
            optimizer.zero_grad()
            if not sc_flag:
                seq_probs, seq_preds, hn, de_hn = model(
                    fc_feats, labels, 'train')
                loss_C = crit(seq_probs, labels[:, 1:], masks[:, 1:])
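                # Reconstruction branch: rem rebuilds an encoder-like hidden state
                # from the decoder states and output distribution, and the decoder
                # is run again from that state against a second reference
                # (labels2), giving the reconstruction loss loss_R.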
                fake_en_hn = rem(de_hn, seq_probs)
                f_seq_probs, f_seq_preds, hn, de_hn = model(fc_feats,
                                                            labels2,
                                                            'train',
                                                            h=fake_en_hn)
                loss_R = crit(f_seq_probs, labels2[:, 1:], masks2[:, 1:])
                loss = loss_R + loss_C
            else:
                seq_probs, seq_preds = model(fc_feats,
                                             mode='inference',
                                             opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  seq_preds)
                print(reward.shape)
                loss = rl_crit(seq_probs, seq_preds,
                               torch.from_numpy(reward).float().cuda())

            t_loss[0] += loss.item()
            t_loss[1] += loss_C.item()
            t_loss[2] += loss_R.item()
            loss.backward()
            clip_grad_value_(model.parameters(), opt['grad_clip'])
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            iteration += 1
            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch, np.mean(reward[:, 0])))
        writer.add_scalar('training total loss', t_loss[0] / 140, epoch + 200)
        writer.add_scalar('training Caption loss', t_loss[1] / 140,
                          epoch + 200)
        writer.add_scalar('training Reconstruction loss', t_loss[2] / 140,
                          epoch + 200)
        if epoch % opt["save_checkpoint_every"] == 0:

            model_path = os.path.join(opt["checkpoint_path"],
                                      'RECON222_model_%d.pth' % (epoch + 200))
            rem_path = os.path.join(opt["checkpoint_path"],
                                    'RECON222_module_%d.pth' % (epoch + 200))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'RECON222_model_score.txt')
            torch.save(model.state_dict(), model_path)
            torch.save(rem.state_dict(), rem_path)
            print("model saved to %s" % (model_path))

            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))

        with torch.no_grad():
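            # End-of-epoch check: decode the last batch once from the raw features
            # and once from the reconstructed hidden state (fake_en_hn), then log
            # both captions side by side in training_versus.txt.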
            _, seq_preds, __, ___ = model(fc_feats, mode='inference', opt=opt)
            _, f_seq_preds, __, ___ = model(fc_feats,
                                            mode='inference',
                                            h=fake_en_hn,
                                            opt=opt)
            origin = utils.decode_sequence(vocab, seq_preds)[0]
            revision = utils.decode_sequence(vocab, f_seq_preds)[0]
            with open('./results/training_versus.txt', 'a') as f:
                f.write("epoch is %d \n" % epoch)
                origin = "origin caption: " + origin + "\n"
                revision = "revision caption: " + revision + "\n"
                f.write(origin)
                f.write(revision)
Example #4
def train(dataset,
          loader,
          model,
          crit,
          optimizer,
          lr_scheduler,
          opt,
          rl_crit=None):
    writer = SummaryWriter('./runs/video_caption_basic')
    model.load_state_dict(
        torch.load('/home/diml/video-caption.pytorch/save/new_model_200.pth'))
    #model = nn.DataParallel(model)
    model.train()
    vocab = dataset.get_vocab()
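    # Baseline run without the reconstruction module: weights are restored from
    # the 200-epoch checkpoint above, so logs and saved checkpoints are offset
    # by +200 epochs.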

    for epoch in trange(300):
        t_loss = 0
        # =============================================================================
        #         model.eval()
        #         ev.demov(model,crit, dataset, dataset.get_vocab(),opt)
        # =============================================================================

        lr_scheduler.step()
        iteration = 0

        # Start self-critical training once the configured epoch is reached
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for idx, data in enumerate(loader):
            torch.cuda.synchronize()
            fc_feats = data['fc_feats'].cuda()
            labels = data['labels'].cuda()
            masks = data['masks'].cuda()
            optimizer.zero_grad()
            if not sc_flag:
                seq_probs, seq_preds, hn, de_hn = model(
                    fc_feats, labels, 'train')
                loss_C = crit(seq_probs, labels[:, 1:], masks[:, 1:])

                loss = loss_C
            else:
                seq_probs, seq_preds = model(fc_feats,
                                             mode='inference',
                                             opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  seq_preds)
                print(reward.shape)
                loss = rl_crit(seq_probs, seq_preds,
                               torch.from_numpy(reward).float().cuda())

            t_loss += loss.item()
            loss.backward()
            clip_grad_value_(model.parameters(), opt['grad_clip'])
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            iteration += 1
            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch + 201, np.mean(reward[:, 0])))
        writer.add_scalar('training total loss', t_loss / 140, epoch + 200)
        if epoch % opt["save_checkpoint_every"] == 0:

            model_path = os.path.join(opt["checkpoint_path"],
                                      'new_model_%d.pth' % (epoch + 200))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'Rnew_model_score.txt')
            torch.save(model.state_dict(), model_path)
            print("model saved to %s" % (model_path))

            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))

        with torch.no_grad():
            _, seq_preds, __, ___ = model(fc_feats, mode='inference', opt=opt)
            print(utils.decode_sequence(vocab, seq_preds)[0])
Example #5
def train(loader, model, crit, optimizer, lr_scheduler, opt, rl_crit=None):
    model.train()
    if opt['visdom']:
        viz = visdom.Visdom(env='train')
        loss_win = viz.line(np.arange(1), opts={'title': 'loss'})

    for epoch in range(opt["epochs"]):
        lr_scheduler.step()

        iteration = 0
        # Start self-critical training once the configured epoch is reached
        # print(opt["self_crit_after"])
        if opt["self_crit_after"] != -1 and epoch >= opt[
                "self_crit_after"]:  # how often to save
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        # print(model)

        for data in loader:
            # print(data)
            torch.cuda.synchronize()
            fc_feats = data['fc_feats'].cuda()
            # voice_feats = data['voice_feats'].cuda()
            if opt['with_hand'] == 1:
                hand_feats = data['hand_feats'].cuda()
                hand_pro = data['hand_pro'].cuda()
            labels = data['labels'].cuda()
            masks = data['masks'].cuda()
            #print(sc_flag)
            optimizer.zero_grad()
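            # With opt['with_hand'] == 1 the model also receives hand features and
            # hand probabilities (sign-language cues); otherwise forward2 runs on
            # the video features alone.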
            if not sc_flag:
                # seq_probs, _ = model(fc_feats, voice_feats, hand_feats, labels, 'train')
                if opt['with_hand'] == 1:
                    seq_probs, _ = model(fc_feats, hand_feats, hand_pro,
                                         labels, 'train')
                else:
                    seq_probs, _ = model.forward2(fc_feats, labels, 'train')
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            # TODO: the else branch below has not been updated for the voice and sign-language inputs
            else:
                seq_probs, seq_preds = model(fc_feats,
                                             mode='inference',
                                             opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  seq_preds)
                print(reward.shape)
                loss = rl_crit(seq_probs, seq_preds,
                               torch.from_numpy(reward).float().cuda())
            loss.backward()
            clip_grad_value_(model.parameters(), opt['grad_clip'])
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
                if opt['visdom']:
                    viz.line(Y=np.array([train_loss]),
                             X=np.array([epoch]),
                             win=loss_win,
                             update='append')
            else:
                print("iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch, np.mean(reward[:, 0])))

        if epoch % opt["save_checkpoint_every"] == 0:
            model_path = os.path.join(opt["checkpoint_path"],
                                      'model_%d.pth' % (epoch))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'model_score.txt')
            torch.save(model.state_dict(), model_path)
            # print("model saved to %s" % (model_path))
            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))
Example #6
def train(train_loader,
          val_loader,
          model,
          crit,
          optimizer,
          lr_scheduler,
          opt,
          rl_crit=None):
    model.train()
    model = nn.DataParallel(model)
    # lowest val loss
    best_loss = None
    for epoch in range(opt.epochs):
        lr_scheduler.step()

        iteration = 0
        # Start self-critical training once the configured epoch is reached
        if opt.self_crit_after != -1 and epoch >= opt.self_crit_after:
            sc_flag = True
            init_cider_scorer(opt.cached_tokens)
        else:
            sc_flag = False

        for data in train_loader:
            torch.cuda.synchronize()
            fc_feats = Variable(data['fc_feats']).cuda()
            labels = Variable(data['labels']).long().cuda()
            masks = Variable(data['masks']).cuda()
            if not sc_flag:
                seq_probs, predicts = model(fc_feats, labels)
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            else:
                gen_result, sample_logprobs = model.sample(fc_feats, vars(opt))
                # print(gen_result)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  gen_result)
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(reward).float().cuda()))

            optimizer.zero_grad()
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.data[0]
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.3f" %
                      (iteration, epoch, np.mean(reward[:, 0])))

        # lowest val loss

        if epoch % opt.save_checkpoint_every == 0:
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model_%d.pth' % (epoch))
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to %s" % (checkpoint_path))
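            # Validate after every checkpoint and keep a separate copy of the
            # weights whenever the validation loss improves.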
            val_loss = val(val_loader, model, crit)
            print("Val loss is: %.6f" % (val_loss))
            model.train()
            if best_loss is None or val_loss < best_loss:
                print("(epoch %d), now lowest val loss is %.6f" %
                      (epoch, val_loss))
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model_best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("best model saved to %s" % (checkpoint_path))
                best_loss = val_loss
Example #7
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Ensure the model is in training mode
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # Start self-critical training once the configured epoch is reached
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_cider_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False
                
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp
        
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:])
        else:
            gen_result, sample_logprobs = model.sample(fc_feats, att_feats, {'sample_max':0})
            reward = get_self_critical_reward(model, fc_feats, att_feats, data, gen_result)
            loss = rl_crit(sample_logprobs, gen_result, Variable(torch.from_numpy(reward).float().cuda(), requires_grad=False))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward', np.mean(reward[:,0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward[:,0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                if lang_stats is not None:
                    for k,v in lang_stats.items():
                        add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save the model if it improves on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
Example #8
def train(loader, model, optimizer, lr_scheduler, opt, device, crit):
    # Add Tensorboard
    train_logger, valid_logger = None, None
    if opt["log_dir"] is not None:
        train_logger = tb.SummaryWriter(path.join(opt["log_dir"], 'train'),
                                        flush_secs=1)
        # valid_logger = tb.SummaryWriter(path.join(opt["log_dir"], 'valid'), flush_secs=1)

    if opt["model"] == 'S2VTACTModel':
        use_action = True
    else:
        use_action = False
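    # The action-aware model (S2VTACTModel) consumes an extra data['action']
    # tensor alongside the video features; the plain model is trained on the
    # video features only.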

    # Training Procedure
    model.train()
    global_step = 0
    # model = nn.DataParallel(model) # just ignore data parallel here
    for epoch in range(opt["epochs"]):
        iteration = 0
        # Start self-critical training once the configured epoch is reached
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for data in loader:
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            fc_feats = data['fc_feats'].to(device)
            labels = data['labels'].to(device)
            masks = data['masks'].to(device)

            if use_action:
                action = data['action'].to(device)

            optimizer.zero_grad()
            if not sc_flag:
                if use_action:
                    seq_probs, _ = model(vid_feats=fc_feats,
                                         action=action,
                                         device=device,
                                         target_variable=labels,
                                         mode='train')
                else:
                    seq_probs, _ = model(vid_feats=fc_feats,
                                         target_variable=labels,
                                         mode='train')
                # Using Language Model Loss (NLLLoss or CrossEntropy Loss)
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            else:
                print('Currently ignore RL criterion')
                # seq_probs, seq_preds = model(
                #     fc_feats, mode='inference', opt=opt)
                # reward = get_self_critical_reward(model, fc_feats, data,
                #                                   seq_preds)
                # print(reward.shape)
                # loss = rl_crit(seq_probs, seq_preds,
                #                torch.from_numpy(reward).float().cuda())
                # No RL loss is computed in this branch, so skip the update;
                # otherwise loss.backward() below would fail on an undefined loss.
                continue

            loss.backward()
            clip_grad_value_(model.parameters(), opt['grad_clip'])
            optimizer.step()
            train_loss = loss.item()
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            iteration += 1
            global_step += 1

            if not sc_flag:
                if iteration % 20 == 0:
                    print("iter %d (epoch %d), train_loss = %.6f" %
                          (iteration, epoch, train_loss))
            else:
                print('Currently ignore RL criterion')
                # print("iter %d (epoch %d), avg_reward = %.6f" %
                #       (iteration, epoch, np.mean(reward[:, 0])))

            # Add Logger
            if train_logger is not None and global_step % 100 == 0:
                # Log some real data
                pass

            # Add Loss Statistics
            if train_logger is not None and iteration % 10 == 0:
                train_logger.add_scalar('loss',
                                        train_loss,
                                        global_step=global_step)

        # Step the Learning Rate Scheduler
        lr_scheduler.step()
        if train_logger is not None:
            train_logger.add_scalar('lr',
                                    optimizer.param_groups[0]['lr'],
                                    global_step=global_step)

        if epoch % opt["save_checkpoint_every"] == 0:
            model_path = os.path.join(opt["checkpoint_path"],
                                      'model_%d.pth' % (epoch))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'model_score.txt')
            torch.save(model.state_dict(), model_path)
            print("model saved to %s" % (model_path))
            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))
Example #9
def train(loader,
          model,
          crit,
          optimizer,
          lr_scheduler,
          opt,
          rl_crit=None,
          opt_test=None,
          test_dataset=None):
    model.train()
    loss_avg = averager()
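    # loss_avg accumulates the per-batch losses so only the epoch average is
    # printed and logged to TensorBoard.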
    #model = nn.DataParallel(model)
    writer = SummaryWriter()
    for epoch in range(opt["epochs"]):
        lr_scheduler.step()

        iteration = 0
        # Start self-critical training once the configured epoch is reached
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for data in loader:
            torch.cuda.synchronize()
            fc_feats = data['fc_feats'].cuda()
            labels = data['labels'].cuda()
            masks = data['masks'].cuda()

            # clip_nums = data['clip_num']
            # sorted_clip_nums, indices = torch.sort(clip_nums, descending=True)
            # _, desorted_indices = torch.sort(indices, descending=False)
            # fc_feats = fc_feats[indices]
            # pack = rnn.pack_padded_sequence(fc_feats, sorted_clip_nums, batch_first=True)
            optimizer.zero_grad()
            if not sc_flag:
                seq_probs, _ = model(fc_feats, labels, 'train')
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            else:
                seq_probs, seq_preds = model(fc_feats,
                                             mode='inference',
                                             opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  seq_preds)
                print(reward.shape)
                loss = rl_crit(seq_probs, seq_preds,
                               torch.from_numpy(reward).float().cuda())

            loss_avg.add(loss)
            loss.backward()
            clip_grad_value_(model.parameters(), opt['grad_clip'])
            optimizer.step()
            # train_loss = loss.item()
            torch.cuda.synchronize()
            iteration += 1

            # if not sc_flag:
            #     print("iter %d (epoch %d), train_loss = %.6f" %
            #           (iteration, epoch, train_loss))
            # else:
            #     print("iter %d (epoch %d), avg_reward = %.6f" %
            #           (iteration, epoch, np.mean(reward[:, 0])))
        print("[epoch %d]->train_loss = %.6f" % (epoch, loss_avg.val()))
        writer.add_scalar('scalar/train_loss_epoch', loss_avg.val(), epoch)
        if epoch % opt["save_checkpoint_every"] == 0:
            test(model, crit, test_dataset, test_dataset.get_vocab(), opt_test,
                 writer)
            model.train()
            model_path = os.path.join(opt["checkpoint_path"],
                                      'model_%d.pth' % (epoch))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'model_score.txt')
            torch.save(model.state_dict(), model_path)
            print("model saved to %s" % (model_path))
            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, loss_avg.val()))
        loss_avg.reset()