Example #1
def train(train_loader, model, criterion, optimizer, epoch, opt):
    """
    train for one epoch on the training set
    """
    batch_time = utils.AverageMeter() 
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter() 

    # training mode
    model.train() 

    end = time.time() 
    for i, (input_points, labels) in enumerate(train_loader):
        # bz x 2048 x 3 
        input_points = Variable(input_points)
        input_points = input_points.transpose(2, 1) 
        labels = Variable(labels[:, 0])

        # shift data to GPU
        if opt.cuda:
            input_points = input_points.cuda() 
            labels = labels.long().cuda() # must be long cuda tensor  
        
        # forward, backward optimize 
        output, _ = model(input_points)
        loss = criterion(output, labels)
        ##############################
        # measure accuracy
        ##############################
        prec1 = utils.accuracy(output.data, labels.data, topk=(1,))[0]
        losses.update(loss.data[0], input_points.size(0))
        top1.update(prec1[0], input_points.size(0))

        ##############################
        # compute gradient and do sgd 
        ##############################
        optimizer.zero_grad() 
        loss.backward() 
        ##############################
        # gradient clip stuff 
        ##############################
        utils.clip_gradient(optimizer, opt.gradient_clip)
        
        optimizer.step() 

        # measure elapsed time
        batch_time.update(time.time() - end) 
        end = time.time() 
        if i % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
              'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
              'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
              'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                  epoch, i, len(train_loader), batch_time=batch_time,
                  loss=losses, top1=top1)) 
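
The utils helpers this example leans on (AverageMeter, accuracy) are not shown. A minimal sketch of what they typically look like, following the canonical PyTorch ImageNet example; the actual utils module here may differ:

class AverageMeter(object):
    """Track the latest value, running sum, count, and average of a metric."""
    def __init__(self):
        self.val, self.avg, self.sum, self.count = 0, 0, 0, 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


def accuracy(output, target, topk=(1,)):
    """Compute precision@k for the specified values of k."""
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)  # top-k class indices per row
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res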
Example #2
def train(train_loader, model, criterion, optimizer, epoch, opt):
    """
    train for one epoch on the training set
    """
    # training mode
    model.train() 

    for i, (input_points, _labels, segs) in enumerate(train_loader):
        # bz x 2048 x 3 
        input_points = Variable(input_points)
        input_points = input_points.transpose(2, 1)
        _labels = _labels.long() 
        segs = segs.long() 
        labels_onehot = utils.labels_batch2one_hot_batch(_labels, opt.num_classes)
        labels_onehot = Variable(labels_onehot) # we do not calculate gradients here
        # labels_onehot.requires_grad = True
        segs = Variable(segs) 

        if opt.cuda:
            input_points = input_points.cuda() 
            segs = segs.cuda() # must be long cuda tensor 
            labels_onehot = labels_onehot.float().cuda()  # this will be fed into the network
        
        optimizer.zero_grad()
        # forward, backward optimize 
        # pred, _ = model(input_points, labels_onehot)
        pred, _, _ = model(input_points, labels_onehot)
        pred = pred.view(-1, opt.num_seg_classes)
        segs = segs.view(-1, 1)[:, 0] 
        loss = criterion(pred, segs) 
        loss.backward() 
        ##############################
        # gradient clip stuff 
        ##############################
        utils.clip_gradient(optimizer, opt.gradient_clip)
        optimizer.step() 
        pred_choice = pred.data.max(1)[1]
        correct = pred_choice.eq(segs.data).cpu().sum()

        if i % opt.print_freq == 0:
            print('[%d: %d] train loss: %f accuracy: %f' %(i, len(train_loader), loss.data[0], correct/float(opt.batch_size * opt.num_points)))
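
utils.labels_batch2one_hot_batch is assumed above to turn a batch of class indices into one-hot vectors. A plausible minimal version using scatter_ (the real helper may differ):

import torch

def labels_batch2one_hot_batch(labels, num_classes):
    # labels: LongTensor of shape (bz,) or (bz, 1) holding class indices
    labels = labels.view(-1, 1)
    one_hot = torch.zeros(labels.size(0), num_classes)
    one_hot.scatter_(1, labels, 1)  # put a 1 at each sample's class index
    return one_hot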
Example #3
def train(opt):
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    ss_prob_history = infos.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()
    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    cnn_optimizer = optim.Adam(cnn_model.parameters(),
                               lr=opt.cnn_learning_rate,
                               weight_decay=opt.cnn_weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer.pth')):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        if os.path.isfile(os.path.join(opt.start_from, 'optimizer-cnn.pth')):
            cnn_optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer-cnn.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            # Update the training stage of cnn
            if opt.finetune_cnn_after == -1 or epoch < opt.finetune_cnn_after:
                for p in cnn_model.parameters():
                    p.requires_grad = False
                cnn_model.eval()
            else:
                for p in cnn_model.parameters():
                    p.requires_grad = True
                cnn_model.train()
            update_lr_flag = False

        torch.cuda.synchronize()
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        torch.cuda.synchronize()
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['images'], data['labels'], data['masks']]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        images, labels, masks = tmp

        att_feats = cnn_model(images)
        fc_feats = att_feats.mean(2).mean(3).squeeze(2).squeeze(2)

        att_feats = att_feats.unsqueeze(1).expand(*((
            att_feats.size(0),
            opt.seq_per_img,
        ) + att_feats.size()[1:])).contiguous().view(
            *((att_feats.size(0) * opt.seq_per_img, ) + att_feats.size()[1:]))
        fc_feats = fc_feats.unsqueeze(1).expand(*((
            fc_feats.size(0),
            opt.seq_per_img,
        ) + fc_feats.size()[1:])).contiguous().view(
            *((fc_feats.size(0) * opt.seq_per_img, ) + fc_feats.size()[1:]))

        optimizer.zero_grad()
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            cnn_optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                    masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            utils.clip_gradient(cnn_optimizer, opt.grad_clip)
            cnn_optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                cnn_model, model, crit, loader, eval_kwargs)

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-cnn.pth')
                torch.save(model.state_dict(), checkpoint_path)
                torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                print("cnn model saved to {}".format(cnn_checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                cnn_optimizer_path = os.path.join(opt.checkpoint_path,
                                                  'optimizer-cnn.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['val_result_history'] = val_result_history
                infos['loss_history'] = loss_history
                infos['lr_history'] = lr_history
                infos['ss_prob_history'] = ss_prob_history
                infos['vocab'] = loader.get_vocab()
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                       'model-cnn-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    print("cnn model saved to {}".format(cnn_checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
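
Every example above calls into an unshown utils module for gradient clipping, learning-rate setting, and (here) the captioning loss. Plausible minimal versions, modeled on common image-captioning codebases; these are assumptions, not the verbatim helpers:

import torch
import torch.nn as nn

def clip_gradient(optimizer, grad_clip):
    # Elementwise clamp of every parameter gradient to [-grad_clip, grad_clip].
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

def set_lr(optimizer, lr):
    # Overwrite the learning rate of every parameter group.
    for group in optimizer.param_groups:
        group['lr'] = lr

class LanguageModelCriterion(nn.Module):
    # Masked cross-entropy over log-probabilities: input is B x T x V
    # log-probs, target and mask are B x T.
    def forward(self, input, target, mask):
        target = target[:, :input.size(1)]
        mask = mask[:, :input.size(1)].float()
        output = -input.gather(2, target.unsqueeze(2)).squeeze(2) * mask
        return torch.sum(output) / torch.sum(mask)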
Example #4
def train(epoch, opt):
    model.train()

    #########################################################################################
    # Training begins here
    #########################################################################################
    data_iter = iter(dataloader)
    lm_loss_temp = 0
    bn_loss_temp = 0
    fg_loss_temp = 0
    cider_temp = 0
    rl_loss_temp = 0
    start = time.time()
    for step in range(len(dataloader)-1):
        data = next(data_iter)
        img, iseq, gts_seq, num, proposals, bboxs, box_mask, img_id = data
        proposals = proposals[:,:max(int(max(num[:,1])),1),:]
        bboxs = bboxs[:,:int(max(num[:,2])),:]
        box_mask = box_mask[:,:,:max(int(max(num[:,2])),1),:]

        input_imgs.data.resize_(img.size()).copy_(img)
        input_seqs.data.resize_(iseq.size()).copy_(iseq)
        gt_seqs.data.resize_(gts_seq.size()).copy_(gts_seq)
        input_num.data.resize_(num.size()).copy_(num)
        input_ppls.data.resize_(proposals.size()).copy_(proposals)
        gt_bboxs.data.resize_(bboxs.size()).copy_(bboxs)
        mask_bboxs.data.resize_(box_mask.size()).copy_(box_mask)
        loss = 0

        # If using RL for self-critical sequence training
        if opt.self_critical:
            rl_loss, bn_loss, fg_loss, cider_score = model(input_imgs, input_seqs, gt_seqs, input_num, input_ppls, gt_bboxs, mask_bboxs, 'RL')
            cider_temp += cider_score.sum().item() / cider_score.numel()
            loss += (rl_loss.sum() + bn_loss.sum() + fg_loss.sum()) / rl_loss.numel()
            rl_loss_temp += loss.item()

        # If using MLE
        else:
            lm_loss, bn_loss, fg_loss = model(input_imgs, input_seqs, gt_seqs, input_num, input_ppls, gt_bboxs, mask_bboxs, 'MLE')
            loss += ((lm_loss.sum() + bn_loss.sum() + fg_loss.sum()) / lm_loss.numel())

            lm_loss_temp += (lm_loss.sum().item() / lm_loss.numel())
            bn_loss_temp += (bn_loss.sum().item() / lm_loss.numel())
            fg_loss_temp += (fg_loss.sum().item() / lm_loss.numel())

        model.zero_grad()
        
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), opt.grad_clip)
        #utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        if opt.finetune_cnn:
            utils.clip_gradient(cnn_optimizer, opt.grad_clip)
            cnn_optimizer.step()

        if step % opt.disp_interval == 0 and step != 0:
            end = time.time()
            lm_loss_temp /= opt.disp_interval
            bn_loss_temp /= opt.disp_interval
            fg_loss_temp /= opt.disp_interval
            rl_loss_temp /= opt.disp_interval
            cider_temp /= opt.disp_interval

            print("step {}/{} (epoch {}), lm_loss = {:.3f}, bn_loss = {:.3f}, fg_loss = {:.3f}, rl_loss = {:.3f}, cider_score = {:.3f}, lr = {:.5f}, time/batch = {:.3f}" \
                .format(step, len(dataloader), epoch, lm_loss_temp, bn_loss_temp, fg_loss_temp, rl_loss_temp, cider_temp, opt.learning_rate, end - start))
            
            start = time.time()

            lm_loss_temp = 0
            bn_loss_temp = 0
            fg_loss_temp = 0
            cider_temp = 0
            rl_loss_temp = 0

        # Write the training loss summary

        loss_history[iteration] = loss.data.item()
        lr_history[iteration] = opt.learning_rate
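
Example #4 copies each batch into pre-allocated global buffers (input_imgs, input_seqs, ...) via data.resize_().copy_(), which avoids reallocating GPU memory every iteration. A sketch of how such buffers are typically created before the training loop; the names and dtypes are assumptions inferred from how they are used above:

import torch
from torch.autograd import Variable

# One reusable CUDA tensor per input, grown in place by resize_/copy_.
input_imgs = Variable(torch.FloatTensor(1).cuda())
input_seqs = Variable(torch.LongTensor(1).cuda())
gt_seqs    = Variable(torch.LongTensor(1).cuda())
input_num  = Variable(torch.LongTensor(1).cuda())
input_ppls = Variable(torch.FloatTensor(1).cuda())
gt_bboxs   = Variable(torch.FloatTensor(1).cuda())
mask_bboxs = Variable(torch.ByteTensor(1).cuda())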
Example #5
def train(opt):
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # cnn_model = utils.build_cnn(opt)
    cnn_model = create_extractor(
        "/root/PycharmProjects/vgg_vae_best_model.pth")
    cnn_model = cnn_model.cuda()

    if vars(opt).get('start_from', None) is not None:
        cnn_model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model-cnn.pth')))
        print("load cnn model parameters from {}".format(
            os.path.join(opt.start_from, 'model-cnn.pth')))

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)
    dp_lw_model = torch.nn.DataParallel(lw_model)
    # dp_lw_model = lw_model
    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)

    # if opt.finetune_cnn_after != -1:
    #     # only finetune the layer2 to layer4
    cnn_optimizer = optim.Adam([{
        'params': module.parameters()
    } for module in cnn_model.finetune_modules],
                               lr=opt.cnn_learning_rate,
                               weight_decay=opt.cnn_weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        if os.path.isfile(os.path.join(opt.start_from, "optimizer.pth")):
            optimizer.load_state_dict(
                torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        if opt.finetune_cnn_after != -1:
            if os.path.isfile(os.path.join(opt.start_from,
                                           'optimizer-cnn.pth')):
                cnn_optimizer.load_state_dict(
                    torch.load(
                        os.path.join(opt.start_from, 'optimizer-cnn.pth')))

    def save_checkpoint(model,
                        cnn_model,
                        infos,
                        optimizer,
                        cnn_optimizer,
                        histories=None,
                        append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))

        cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model-cnn%s.pth' % (append))
        torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
        print("cnn model saved to {}".format(cnn_checkpoint_path))

        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)

        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            cnn_optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer%s-cnn.pth' % (append))
            torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

        with open(
                os.path.join(opt.checkpoint_path,
                             'infos_' + opt.id + '%s.pkl' % (append)),
                'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '%s.pkl' % (append)),
                    'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    # set the decayed rate
                    utils.set_lr(optimizer, opt.current_lr)
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # Update the training stage of cnn
                if opt.finetune_cnn_after == -1 or epoch < opt.finetune_cnn_after:
                    for p in cnn_model.parameters():
                        p.requires_grad = False
                    cnn_model.eval()
                else:
                    for p in cnn_model.parameters():
                        p.requires_grad = True
                    # Fix the first few layers:
                    for module in cnn_model.fixed_modules:
                        for p in module.parameters():
                            p.requires_grad = False
                    cnn_model.train()

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from train split (0)
            data = loader.get_batch('train')
            torch.cuda.synchronize()
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            # att_feats 8x672x224
            att_feats = att_feats.view(att_feats.size(0), 3, 224, 224)
            att_feats, fc_feats = cnn_model(att_feats)
            # fc_feats = att_feats.mean(3).mean(2)
            # att_feats = torch.nn.functional.adaptive_avg_pool2d(
            #     att_feats, [7, 7]).permute(0, 2, 3, 1)
            att_feats = att_feats.permute(0, 2, 3, 1)
            att_feats = att_feats.view(att_feats.size(0), 49, -1)

            att_feats = att_feats.unsqueeze(1).expand(*((
                att_feats.size(0),
                opt.seq_per_img,
            ) + att_feats.size()[1:])).contiguous().view(
                (att_feats.size(0) * opt.seq_per_img), -1,
                att_feats.size()[-1])
            fc_feats = fc_feats.unsqueeze(1).expand(*((
                fc_feats.size(0),
                opt.seq_per_img,
            ) + fc_feats.size()[1:])).contiguous().view(
                *((fc_feats.size(0) * opt.seq_per_img, ) +
                  fc_feats.size()[1:]))

            optimizer.zero_grad()
            if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                cnn_optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                    att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()

            if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                utils.clip_gradient(cnn_optimizer, opt.grad_clip)
                cnn_optimizer.step()

            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if not sc_flag:
                print(
                    "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print(
                    "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                    .format(iteration, epoch, model_out['reward'].mean(),
                            end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                                  iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward',
                                      model_out['reward'].mean(), iteration)

                loss_history[
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss',
                                  val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, cnn_model, infos, optimizer,
                                cnn_optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model,
                                    cnn_model,
                                    infos,
                                    optimizer,
                                    cnn_optimizer,
                                    append=str(iteration))

                if best_flag:
                    save_checkpoint(model,
                                    cnn_model,
                                    infos,
                                    optimizer,
                                    cnn_optimizer,
                                    append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, cnn_model, infos, optimizer, cnn_optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)

    # test model
    test_kwargs = {'split': 'test', 'dataset': opt.input_json}
    test_kwargs.update(vars(opt))
    val_loss, predictions, lang_stats = eval_utils.eval_split(
        cnn_model, model, lw_model.crit, loader, test_kwargs)

    if opt.reduce_on_plateau:
        if 'CIDEr' in lang_stats:
            optimizer.scheduler_step(-lang_stats['CIDEr'])
        else:
            optimizer.scheduler_step(val_loss)
    # Write test result into summary
    add_summary_value(tb_summary_writer, 'test loss', val_loss, iteration)
    if lang_stats is not None:
        for k, v in lang_stats.items():
            add_summary_value(tb_summary_writer, k, v, iteration)
    val_result_history[iteration] = {
        'loss': val_loss,
        'lang_stats': lang_stats,
        'predictions': predictions
    }
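
Example #5 (like the later examples) logs scalars through add_summary_value, a thin guard around the SummaryWriter created with tb and tb.SummaryWriter(...): when the tensorboard package is unavailable the writer is None and logging becomes a no-op. A likely minimal definition:

def add_summary_value(writer, key, value, iteration):
    # writer is None when the tensorboard package failed to import
    if writer:
        writer.add_scalar(key, value, iteration)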
Example #6
def train(opt):

    # Load data
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    # Tensorboard summaries (they're great!)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    # Load pretrained model, info file, histories file
    infos = {}
    histories = {}
    if opt.start_from is not None:
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = ["rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)
    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    # Create model
    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)
    dp_model.train()

    # Loss function
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    # Optimizer and learning rate adjustment flag
    optimizer = utils.build_optimizer(model.parameters(), opt)
    update_lr_flag = True

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    # Training loop
    while True:

        # Update learning rate once per epoch
        if update_lr_flag:

            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        start = time.time()
        data = loader.get_batch('train')
        data_time = time.time() - start
        start = time.time()

        # Unpack data
        torch.cuda.synchronize()
        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        # Forward pass and loss
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        # Backward pass
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()

        # Print
        total_time = time.time() - start
        if iteration % opt.print_freq == 1:
            print('Read data:', data_time)
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, data_time, total_time))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, data_time = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, np.mean(reward[:,0]), data_time, total_time))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Validate and save model
        if (iteration % opt.save_checkpoint_every == 0):

            # Evaluate model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Our metric is CIDEr if available, otherwise validation loss
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            # Save model in checkpoint path
            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()
            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'infos_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

            # Save model to unique file if new best model
            if best_flag:
                model_fname = 'model-best-i{:05d}-score{:.4f}.pth'.format(
                    iteration, best_val_score)
                infos_fname = 'model-best-i{:05d}-infos.pkl'.format(iteration)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               model_fname)
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path, infos_fname),
                          'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
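
Example #6's rl_crit = utils.RewardCriterion() is the self-critical policy-gradient loss. A sketch consistent with how it is called above, where sample_logprobs holds the log-probability of each sampled word (B x T) and reward is a per-token advantage; the actual class may differ:

import torch
import torch.nn as nn

class RewardCriterion(nn.Module):
    def forward(self, input, seq, reward):
        # input: B x T log-probs of the sampled words; seq: B x T word ids.
        input = input.contiguous().view(-1)
        reward = reward.contiguous().view(-1)
        # Mask every generated word up to and including the first
        # end-of-sentence token (id 0), hence the shift-right with a leading 1.
        mask = (seq > 0).float()
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        mask = mask.contiguous().view(-1)
        output = -input * reward * mask
        return torch.sum(output) / torch.sum(mask)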
Example #7
def train(opt):
    # opt.use_att = utils.if_use_att(opt.caption_model)
    opt.use_att = True
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    print(opt.checkpoint_path)
    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    critic_loss_history = histories.get('critic_loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    variance_history = histories.get('variance_history', {})
    time_history = histories.get('time_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = model




    ######################### Actor-critic Training #####################################################################

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    #TODO: change this to a flag
    crit = utils.LanguageModelCriterion_binary()
    rl_crit = utils.RewardCriterion_binary()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(os.path.join(opt.start_from,"optimizer.pth")):
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    first_order = 0
    second_order = 0
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        # Load data from train split (0)
        data = loader.get_batch('train')
        if data['bounds']['it_pos_now'] > 10000:
            loader.reset_iterator('train')
            continue
        dp_model.train()

        torch.cuda.synchronize()
        start = time.time()
        gen_result = None
        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks'], data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp
        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:,1:], masks[:,1:], dp_model.depth,
                        dp_model.vocab2code, dp_model.phi_list, dp_model.cluster_size)
        else:
            if opt.rl_type == 'sc':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth)
            elif opt.rl_type == 'reinforce':
                gen_result, sample_logprobs = dp_model(fc_feats, att_feats, att_masks, opt={'sample_max':0}, mode='sample')
                reward = get_reward(data, gen_result, opt)
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda(), dp_model.depth)
            elif opt.rl_type == 'arm':
                loss = dp_model.get_arm_loss_binary_fast(fc_feats, att_feats, att_masks, opt, data, loader)
                #print(loss)
                reward = np.zeros([2,2])
            elif opt.rl_type == 'rf4':
                loss,_,_,_ = get_rf_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                # print(loss)
                reward = np.zeros([2, 2])
            elif opt.rl_type == 'ar':
                loss = get_ar_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = np.zeros([2,2])
            elif opt.rl_type =='mct_baseline':
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, mct_baseline = get_mct_loss(dp_model, fc_feats, att_feats, att_masks, data,
                                                                         opt, loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                mct_baseline[mct_baseline < 0] = reward_cuda[mct_baseline < 0]
                if opt.arm_step_sample == 'greedy':
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, torch.from_numpy(reward).float().cuda() - mct_baseline)
            elif opt.rl_type == 'arsm_baseline':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_reward(data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                arm_baseline[arm_baseline < 0] = reward_cuda[arm_baseline < 0]
                if opt.arm_step_sample == 'greedy' and False:
                    sample_logprobs = sample_logprobs * probs
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda - arm_baseline)
            elif opt.rl_type == 'ars_indicator':
                opt.arm_as_baseline = 1
                opt.rf_demean = 0
                gen_result, sample_logprobs, probs, arm_baseline = get_arm_loss(dp_model, fc_feats, att_feats, att_masks, data, opt, loader)
                reward = get_self_critical_reward(dp_model, fc_feats, att_feats, att_masks, data, gen_result, opt)
                reward_cuda = torch.from_numpy(reward).float().cuda()
                loss = rl_crit(sample_logprobs, gen_result.data, reward_cuda * arm_baseline)
        if opt.mle_weights != 0:
            loss += opt.mle_weights * crit(dp_model(fc_feats, att_feats, labels, att_masks), labels[:, 1:], masks[:, 1:])
        #TODO make sure all sampling replaced by greedy for critic
        #### update the actor
        loss.backward()
        # with open(os.path.join(opt.checkpoint_path, 'embeddings.pkl'), 'wb') as f:
        #     cPickle.dump(list(dp_model.embed.parameters())[0].data.cpu().numpy(), f)
        ## compute variance
        gradient = torch.zeros([0]).cuda()
        for i in model.parameters():
            gradient = torch.cat((gradient, i.grad.view(-1)), 0)
        first_order = 0.999 * first_order + 0.001 * gradient
        second_order = 0.999 * second_order + 0.001 * gradient.pow(2)
        # print(torch.max(torch.abs(gradient)))
        variance = torch.mean(torch.abs(second_order - first_order.pow(2))).item()
        if opt.rl_type != 'arsm' or not sc_flag:
            utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        # ### update the critic

        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if (iteration % opt.losses_log_every == 0):
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
                print(opt.checkpoint_path)
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, variance = {:g}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), variance, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate', opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward', np.mean(reward), iteration)
                add_summary_value(tb_summary_writer, 'variance', variance, iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(reward)
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            variance_history[iteration] = variance
            time_history[iteration] = end - start


        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils_binary.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            if lang_stats is not None:
                for k,v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['critic_loss_history'] = critic_loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                histories['variance_history'] = variance_history
                histories['time'] = time_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
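The loop above relies on utils.clip_gradient(optimizer, opt.grad_clip), which is never defined in these listings. A minimal sketch consistent with how it is called here (after backward(), before step()) is the element-wise clamp below; the real helper may differ, for example by rescaling by a global norm instead:

def clip_gradient(optimizer, grad_clip):
    # Sketch only: clamp every gradient element into [-grad_clip, grad_clip],
    # in place, across all parameter groups of the optimizer.
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)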
Exemplo n.º 8
0
            # Assign the learning rate
            ens_opt = utils.manage_lr(epoch, ens_opt, val_losses)
            utils.scale_lr(optimizer, ens_opt.scale_lr)  # set the decayed rate
            lg.log_optimizer(ens_opt, optimizer)
            update_lr_flag = False
        # Load data from train split (0)
        data = loader.get_batch('train')
        torch.cuda.synchronize()
        start = time.time()
        # Forward the ensemble
        real_loss, loss = ens_model.step(data)
        optimizer.zero_grad()
        loss.backward()
        grad_norm = []
        grad_norm.append(utils.clip_gradient(optimizer, opt.grad_clip))
        optimizer.step()
        train_loss = loss.data[0]
        if np.isnan(train_loss):
            sys.exit('Loss is nan')
        train_real_loss = real_loss.data[0]
        try:
            # kld_loss / recon_loss are only produced by some ensemble variants
            train_kld_loss = kld_loss.data[0]
            train_recon_loss = recon_loss.data[0]
        except NameError:
            pass
        #  grad_norm = [utils.get_grad_norm(optimizer)]
        torch.cuda.synchronize()
        end = time.time()
        losses = {'train_loss': train_loss, 'train_real_loss': train_real_loss}
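This fragment stores the return value of utils.clip_gradient in a grad_norm list, and the commented-out line mentions a utils.get_grad_norm helper. Neither is shown; a plausible norm helper, offered purely as an illustration, would be:

import math

def get_grad_norm(optimizer):
    # Sketch only: total L2 norm over all gradients the optimizer touches.
    total = 0.0
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                total += float(param.grad.data.norm()) ** 2
    return math.sqrt(total)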
Exemplo n.º 9
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.maxlen_sen
    opt.inc_seg = loader.inc_seg
    opt.seg_ix = loader.seg_ix
    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    score_list = []
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = {}
    score_splits = ['val', 'test']
    score_type = ['Bleu_4', 'METEOR', 'CIDEr']
    for split_i in score_splits:
        for score_item in score_type:
            if split_i not in best_val_score.keys():
                best_val_score[split_i] = {}
            best_val_score[split_i][score_item] = 0.0
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', best_val_score)

    model = models.setup(opt)
    device_ids = [0, 1]

    torch.cuda.set_device(device_ids[0])
    model = nn.DataParallel(model, device_ids=device_ids)
    model = model.cuda()
    update_lr_flag = True
    # Assure in training mode
    model.module.train()
    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.module.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)
    #optimizer = nn.DataParallel(optimizer, device_ids=device_ids)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.module.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['labels'], data['x_phrase_mask_0'], data['x_phrase_mask_1'], \
               data['label_masks'], data['salicy_seg'], data['seg_mask']]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, seq, phrase_mask_0, phrase_mask_1, masks, salicy_seg, seg_mask = tmp

        optimizer.zero_grad()
        remove_len = 2
        outputs, alphas = model.module(fc_feats, seq, phrase_mask_0,
                                       phrase_mask_1, masks, seg_mask,
                                       remove_len)
        loss = crit(outputs, seq[remove_len:, :].permute(1, 0),
                    masks[remove_len:, :].permute(1, 0))
        alphas = alphas.permute(1, 0, 2)
        if not opt.salicy_hard:
            if opt.salicy_loss_type == 'l2':
                salicy_loss = (((((salicy_seg * seg_mask[:, :, None] -
                                   alphas * seg_mask[:, :, None])**2).sum(0)
                                 ).sum(-1))**(0.5)).mean()
            if opt.salicy_loss_type == 'kl':
                #alphas: len_sen, batch_size, num_frame
                salicy_loss = kullback_leibler2(
                    alphas * seg_mask[:, :, None],
                    salicy_seg * seg_mask[:, :, None])
                salicy_loss = (((salicy_loss *
                                 seg_mask[:, :, None]).sum(-1)).sum(0)).mean()
        else:
            #salicy len_sen, batch_size, num_frame
            salicy_loss = -torch.log((alphas * salicy_seg).sum(-1) + 1e-8)
            #salicy_loss len_sen, batch_size
            salicy_loss = ((salicy_loss * seg_mask).sum(0)).mean()
        loss = loss + opt.salicy_alpha * salicy_loss
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.module.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.module.ss_prob

        # evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.dataset,
                'remove_len': remove_len
            }
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats, score_list_i = eval_utils.eval_split(
                model.module, crit, loader, eval_kwargs)
            score_list.append(score_list_i)
            np.savetxt('./save/train_valid_test.txt', score_list, fmt='%.3f')
            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k in lang_stats.keys():
                    for v in lang_stats[k].keys():
                        add_summary_value(tf_summary_writer, k + v,
                                          lang_stats[k][v], iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['val']['CIDEr']
            else:
                current_score = -val_loss
            best_flag = {}
            for split_i in score_splits:
                for score_item in score_type:
                    if split_i not in best_flag.keys():
                        best_flag[split_i] = {}
                    best_flag[split_i][score_item] = False
            if True:  # unconditional: a checkpoint is written at every save interval
                for split_i in score_splits:
                    for score_item in score_type:
                        if best_val_score is None or lang_stats[split_i][
                                score_item] > best_val_score[split_i][
                                    score_item]:
                            best_val_score[split_i][score_item] = lang_stats[
                                split_i][score_item]
                            best_flag[split_i][score_item] = True

                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.module.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                for split_i in score_splits:
                    for score_item in score_type:
                        if best_flag[split_i][score_item]:
                            checkpoint_path = os.path.join(
                                opt.checkpoint_path, 'model-best_' + split_i +
                                '_' + score_item + '.pth')
                            torch.save(model.module.state_dict(),
                                       checkpoint_path)
                            print("model saved to {}".format(checkpoint_path))
                            with open(
                                    os.path.join(
                                        opt.checkpoint_path,
                                        'infos_' + split_i + '_' + score_item +
                                        '_' + opt.id + '-best.pkl'),
                                    'wb') as f:
                                cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
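Exemplo n.º 9 calls a kullback_leibler2 helper that the excerpt never defines. Since the caller masks its output and then reduces over frames and time itself, an element-wise KL term is a reasonable guess; the argument order and epsilon below are assumptions, not the original code:

import torch

def kullback_leibler2(p, q, eps=1e-8):
    # Sketch only: per-element contribution p * (log p - log q).
    # Callers apply the segment mask and do the summation themselves;
    # eps guards against log(0) on fully masked entries.
    return p * (torch.log(p + eps) - torch.log(q + eps))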
Exemplo n.º 10
0
def train(rank, model, opt, optimizer=None):
    torch.manual_seed(opt.seed + rank)
    if opt.use_cuda:
        torch.cuda.manual_seed(opt.seed + rank)

    loader = DataLoader(opt)
    index_2_word = loader.get_vocab()
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(
                os.path.join(opt.start_from,
                             'infos_' + opt.load_model_id + '.pkl'),
                'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            # for checkme in need_be_same:
            #     assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    ss_prob_history = infos.get('ss_prob_history', {})

    sorted_lr = sorted(lr_history.items(), key=operator.itemgetter(1))
    if opt.load_lr and len(lr_history) > 0:
        opt.optim_rl_lr = sorted_lr[0][1] / opt.optim_rl_lr_ratio

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_image_id = infos.get('split_image_id', loader.split_image_id)

    entropy_reg = opt.entropy_reg
    best_val_score = 0
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    update_lr_flag = True

    if opt.caption_model == 'show_tell':
        crit = utils.LanguageModelCriterion(opt)
        rl_crit = utils.RewardCriterion(opt)

    elif opt.caption_model == 'review_net':
        crit = utils.ReviewNetCriterion(opt)
        rl_crit = utils.ReviewNetRewardCriterion(opt)

    elif opt.caption_model == 'recurrent_fusion_model':
        crit = utils.ReviewNetEnsembleCriterion(opt)
        rl_crit = utils.ReviewNetRewardCriterion(opt)

    else:
        raise Exception("caption_model not supported: {}".format(
            opt.caption_model))

    if optimizer is None:
        if opt.optim == 'adam':
            optimizer = optim.Adam(model.parameters(),
                                   lr=opt.optim_rl_lr,
                                   betas=(opt.optim_adam_beta1,
                                          opt.optim_adam_beta2),
                                   weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'rmsprop':
            optimizer = optim.RMSprop(model.parameters(),
                                      lr=opt.optim_rl_lr,
                                      momentum=opt.optim_momentum,
                                      alpha=opt.optim_rmsprop_alpha,
                                      weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'sgd':
            optimizer = optim.SGD(model.parameters(),
                                  lr=opt.optim_rl_lr,
                                  momentum=opt.optim_momentum,
                                  weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'adagrad':
            optimizer = optim.Adagrad(model.parameters(),
                                      lr=opt.optim_rl_lr,
                                      lr_decay=opt.optim_lr_decay,
                                      weight_decay=opt.optim_weight_decay)
        elif opt.optim == 'adadelta':
            optimizer = optim.Adadelta(model.parameters(),
                                       rho=opt.optim_rho,
                                       eps=opt.optim_epsilon,
                                       lr=opt.optim_rl_lr,
                                       weight_decay=opt.optim_weight_decay)
        else:
            raise Exception("optim not supported: {}".format(opt.feature_type))

        # Load the optimizer
        if opt.load_lr and vars(opt).get(
                'start_from', None) is not None and os.path.isfile(
                    os.path.join(opt.start_from,
                                 'optimizer_' + opt.load_model_id + '.pth')):
            optimizer.load_state_dict(
                torch.load(
                    os.path.join(opt.start_from,
                                 'optimizer_' + opt.load_model_id + '.pth')))
            utils.set_lr(optimizer, opt.optim_rl_lr)

    num_period_best = 0
    current_score = 0
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.optim_rl_lr * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.optim_rl_lr
            update_lr_flag = False

        start = time.time()
        data = loader.get_batch('train')

        if opt.use_cuda:
            torch.cuda.synchronize()

        if opt.feature_type == 'feat_array':
            fc_feat_array = data['fc_feats_array']
            att_feat_array = data['att_feats_array']
            assert (len(fc_feat_array) == len(att_feat_array))
            for feat_id in range(len(fc_feat_array)):
                if opt.use_cuda:
                    fc_feat_array[feat_id] = Variable(
                        torch.from_numpy(fc_feat_array[feat_id]),
                        requires_grad=False).cuda()
                    att_feat_array[feat_id] = Variable(
                        torch.from_numpy(att_feat_array[feat_id]),
                        requires_grad=False).cuda()
                else:
                    fc_feat_array[feat_id] = Variable(torch.from_numpy(
                        fc_feat_array[feat_id]),
                                                      requires_grad=False)
                    att_feat_array[feat_id] = Variable(torch.from_numpy(
                        att_feat_array[feat_id]),
                                                       requires_grad=False)

            tmp = [data['labels'], data['masks'], data['top_words']]
            if opt.use_cuda:
                tmp = [
                    Variable(torch.from_numpy(_), requires_grad=False).cuda()
                    for _ in tmp
                ]
            else:
                tmp = [
                    Variable(torch.from_numpy(_), requires_grad=False)
                    for _ in tmp
                ]
            labels, masks, top_words = tmp

        else:
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['top_words']
            ]
            if opt.use_cuda:
                tmp = [
                    Variable(torch.from_numpy(_), requires_grad=False).cuda()
                    for _ in tmp
                ]
            else:
                tmp = [
                    Variable(torch.from_numpy(_), requires_grad=False)
                    for _ in tmp
                ]
            fc_feats, att_feats, labels, masks, top_words = tmp

        optimizer.zero_grad()

        if opt.caption_model == 'show_tell':
            gen_result, sample_logprobs, logprobs_all = model.sample(
                fc_feats, att_feats, {'sample_max': 0})
            rewards = get_rewards.get_self_critical_reward(
                index_2_word, model, fc_feats, att_feats, data, gen_result,
                opt)
            sample_logprobs_old = Variable(sample_logprobs.data,
                                           requires_grad=False)

            if opt.use_cuda:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float().cuda(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    sample_logprobs_old, opt)
            else:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    sample_logprobs_old, opt)

        elif opt.caption_model == 'recurrent_fusion_model':
            gen_result, sample_logprobs, logprobs_all, top_pred = model.sample(
                fc_feat_array, att_feat_array, {'sample_max': 0})
            rewards = get_rewards.get_self_critical_reward_feat_array(
                index_2_word, model, fc_feat_array, att_feat_array, data,
                gen_result, opt)
            sample_logprobs_old = Variable(sample_logprobs.data,
                                           requires_grad=False)

            if opt.use_cuda:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float().cuda(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    top_pred, top_words, opt.reason_weight,
                    sample_logprobs_old, opt)
            else:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    top_pred, top_words, opt.reason_weight,
                    sample_logprobs_old, opt)

        elif opt.caption_model == 'review_net':
            gen_result, sample_logprobs, logprobs_all, top_pred = model.sample(
                fc_feats, att_feats, {'sample_max': 0})
            rewards = get_rewards.get_self_critical_reward(
                index_2_word, model, fc_feats, att_feats, data, gen_result,
                opt)
            sample_logprobs_old = Variable(sample_logprobs.data,
                                           requires_grad=False)

            if opt.use_cuda:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float().cuda(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    top_pred, top_words, opt.reason_weight,
                    sample_logprobs_old, opt)
            else:
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(rewards).float(),
                             requires_grad=False), logprobs_all, entropy_reg,
                    top_pred, top_words, opt.reason_weight,
                    sample_logprobs_old, opt)

        else:
            raise Exception("caption_model not supported: {}".format(
                opt.caption_model))

        if opt.use_ppo and opt.ppo_k > 0:
            loss.backward(retain_graph=True)
        else:
            loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.data[0]
        if opt.use_ppo:
            for i in range(opt.ppo_k):
                # repeat the update on the retained graph (PPO inner steps)
                optimizer.zero_grad()
                loss.backward(retain_graph=True)
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()

        if opt.use_cuda:
            torch.cuda.synchronize()
        end = time.time()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if iteration % opt.losses_log_every == 0:
            loss_history[iteration] = np.mean(rewards[:, 0])
            lr_history[iteration] = opt.current_lr

        # evaluate on the validation set and save the model
        if iteration % opt.save_checkpoint_every == 0:
            # eval model
            eval_kwargs = {
                'eval_split': 'val',
                'dataset': opt.input_json,
                'caption_model': opt.caption_model,
                'reason_weight': opt.reason_weight,
                'guiding_l1_penality': opt.guiding_l1_penality,
                'use_cuda': opt.use_cuda,
                'feature_type': opt.feature_type,
                'rank': rank
            }
            eval_kwargs.update(vars(opt))
            eval_kwargs['eval_split'] = 'val'
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)

            # Write validation result into summary
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }
            print("iter {} (epoch {}), val_loss = {:.3f}".format(
                iteration, epoch, val_loss))

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
                num_period_best = 1
            else:
                num_period_best = num_period_best + 1

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_image_id'] = loader.split_image_id
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['val_result_history'] = val_result_history
            infos['loss_history'] = loss_history
            infos['lr_history'] = lr_history
            infos['ss_prob_history'] = ss_prob_history
            infos['vocab'] = loader.get_vocab()
            with open(
                    os.path.join(
                        opt.checkpoint_path,
                        'rl_infos_' + opt.id + '_' + str(rank) + '.pkl'),
                    'wb') as f:
                cPickle.dump(infos, f)

            if best_flag:
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'rl_model_' + opt.id + '_' + str(rank) + '-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                optimizer_path = os.path.join(
                    opt.checkpoint_path,
                    'rl_optimizer_' + opt.id + '_' + str(rank) + '-best.pth')
                torch.save(optimizer.state_dict(), optimizer_path)
                print("model saved to {}".format(checkpoint_path))
                with open(
                        os.path.join(
                            opt.checkpoint_path, 'rl_infos_' + opt.id + '_' +
                            str(rank) + '-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

            if num_period_best >= opt.num_eval_no_improve:
                print('no improvement, exit')
                sys.exit()
        print("rank {}, iter {}, (epoch {}), avg_reward: {:.3f}, train_loss: {}, learning rate: {}, current cider: {:.3f}, best cider: {:.3f}, time: {:.3f}" \
              .format(rank, iteration, epoch, np.mean(rewards[:, 0]), train_loss, opt.current_lr, current_score, best_val_score, (end-start)))

        iteration += 1
        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
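get_rewards.get_self_critical_reward is used in several of these loops but never listed. In self-critical sequence training the per-image reward is the metric score (typically CIDEr) of the sampled caption minus the score of a greedy-decoded baseline from the same model, repeated across time steps. A schematic version, assuming the per-image scores have already been computed (the score arrays here are hypothetical inputs), is:

import numpy as np

def self_critical_reward(sample_scores, greedy_scores, seq_length):
    # Sketch only: the real helper also greedy-decodes the model and runs the
    # metric itself. sample_scores / greedy_scores are per-image scores for
    # the sampled and baseline captions.
    advantage = np.asarray(sample_scores) - np.asarray(greedy_scores)
    # Broadcast the advantage over every time step: shape (batch, seq_length).
    return np.repeat(advantage[:, np.newaxis], seq_length, axis=1)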
Exemplo n.º 11
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    modelT = Att2inModel(opt)
    if vars(opt).get('start_from', None) is not None:
        assert os.path.isdir(
            opt.start_from), " %s must be a a path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")
        ), "infos.pkl file does not exist in path %s" % opt.start_from
        modelT.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))
        modelT.cuda()

    modelS = Att2inModel(opt)
    if vars(opt).get('start_from', None) is not None:
        assert os.path.isdir(
            opt.start_from), " %s must be a a path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")
        ), "infos.pkl file does not exist in path %s" % opt.start_from
        modelS.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))
        modelS.cuda()

    logger = Logger(opt)

    update_lr_flag = True
    # Assure in training mode
    modelT.train()
    modelS.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer_S = optim.Adam(modelS.parameters(),
                             lr=opt.learning_rate,
                             weight_decay=opt.weight_decay)
    optimizer_T = optim.Adam(modelT.parameters(),
                             lr=opt.learning_rate,
                             weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer_S.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            opt, sc_flag, update_lr_flag, modelS, optimizer_S = update_lr(
                opt, epoch, modelS, optimizer_S)
            opt, sc_flag, update_lr_flag, modelT, optimizer_T = update_lr(
                opt, epoch, modelT, optimizer_T)

        # Load data from train split (0)
        data = loader.get_batch('train', seq_per_img=opt.seq_per_img)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer_S.zero_grad()
        optimizer_T.zero_grad()
        if not sc_flag:
            loss = crit(modelS(fc_feats, labels), labels[:, 1:], masks[:, 1:])
            loss.backward()
        else:
            gen_result_S, sample_logprobs_S = modelS.sample(
                fc_feats, att_feats, {'sample_max': 0})
            reward_S = get_self_critical_reward_forTS(modelT, modelS, fc_feats,
                                                      att_feats, data,
                                                      gen_result_S, logger)

            gen_result_T, sample_logprobs_T = modelT.sample(
                fc_feats, att_feats, {'sample_max': 0})
            reward_T = get_self_critical_reward_forTS(modelS, modelT, fc_feats,
                                                      att_feats, data,
                                                      gen_result_T, logger)

            loss_S = rl_crit(
                sample_logprobs_S, gen_result_S,
                Variable(torch.from_numpy(reward_S).float().cuda(),
                         requires_grad=False))
            loss_T = rl_crit(
                sample_logprobs_T, gen_result_T,
                Variable(torch.from_numpy(reward_T).float().cuda(),
                         requires_grad=False))

            loss_S.backward()
            loss_T.backward()

            loss = loss_S + loss_T
            #reward = reward_S + reward_T

        utils.clip_gradient(optimizer_S, opt.grad_clip)
        utils.clip_gradient(optimizer_T, opt.grad_clip)
        optimizer_S.step()
        optimizer_T.step()
        train_loss = loss.data[0]

        torch.cuda.synchronize()
        end = time.time()

        if not sc_flag:
            log = "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start)
            logger.write(log)
        else:
            log = "iter {} (epoch {}), S_avg_reward = {:.3f}, T_avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration,  epoch, np.mean(reward_S[:,0]), np.mean(reward_T[:,0]), end - start)
            logger.write(log)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  modelS.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward_S',
                                      np.mean(reward_S[:, 0]), iteration)
                    add_summary_value(tf_summary_writer, 'avg_reward_T',
                                      np.mean(reward_T[:, 0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward_S[:, 0] + reward_T[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = modelS.ss_prob

        # evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))

            val_loss, predictions, lang_stats = eval_utils.eval_split(
                modelS, crit, loader, logger, eval_kwargs)
            logger.write_dict(lang_stats)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss',
                                  val_loss, iteration)
                for k, v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # unconditional: a checkpoint is written at every save interval
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'modelS.pth')
                torch.save(modelS.state_dict(), checkpoint_path)
                print("modelS saved to {}".format(checkpoint_path))
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'modelT.pth')
                torch.save(modelT.state_dict(), checkpoint_path)
                print("modelT saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'S_optimizer.pth')
                torch.save(optimizer_S.state_dict(), optimizer_path)
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'T_optimizer.pth')
                torch.save(optimizer_T.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'modelS-best.pth')
                    torch.save(modelS.state_dict(), checkpoint_path)
                    print("modelS saved to {}".format(checkpoint_path))
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'modelT-best.pth')
                    torch.save(modelT.state_dict(), checkpoint_path)
                    print("modelT saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
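utils.RewardCriterion, used for loss_S and loss_T above, is likewise not shown. In this family of captioning repositories it is conventionally a REINFORCE-style loss: negative sampled log-probabilities weighted by the reward and masked one step past the first end-of-sequence token. A sketch under that assumption:

import torch
import torch.nn as nn

class RewardCriterion(nn.Module):
    # Sketch only: the exact masking convention of the original is assumed.
    def forward(self, logprobs, seq, reward):
        logprobs = logprobs.view(-1)
        reward = reward.view(-1)
        # Keep every step up to and including the first 0 (end) token.
        mask = (seq > 0).float()
        mask = torch.cat([mask.new(mask.size(0), 1).fill_(1), mask[:, :-1]], 1).view(-1)
        return -(logprobs * reward * mask).sum() / mask.sum()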
Exemplo n.º 12
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, labels, masks = tmp

        optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:, 1:],
                    masks[:, 1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # Stop if reaching max epochs (hard-coded here to 8 rather than opt.max_epochs)
        if epoch >= 8:
            break
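Every example that logs to TensorBoard goes through add_summary_value, whose body never appears in the listings. With the TensorFlow 1.x FileWriter created at the top of these functions, a matching helper would wrap each scalar in a tf.Summary protobuf; this is a sketch of that common pattern, not necessarily the exact original:

def add_summary_value(writer, key, value, iteration):
    # Sketch only: callers above already guard on tf being importable,
    # and flush the writer themselves after a batch of values.
    summary = tf.Summary(value=[tf.Summary.Value(tag=key, simple_value=value)])
    writer.add_summary(summary, iteration)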
Exemplo n.º 13
0
def train(opt):
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    infos = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = infos.get('val_result_history', {})
    loss_history = infos.get('loss_history', {})
    lr_history = infos.get('lr_history', {})
    ss_prob_history = infos.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats']]
        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats = tmp

        optimizer.zero_grad()

        gen_result, sample_logprobs = model.sample(fc_feats, att_feats,
                                                   {'sample_max': 0})

        rewards = get_rewards.get_self_critical_reward(model, fc_feats,
                                                       att_feats, data,
                                                       gen_result)
        loss = rl_crit(
            sample_logprobs, gen_result,
            Variable(torch.from_numpy(rewards).float().cuda(),
                     requires_grad=False))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, np.mean(rewards[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            loss_history[iteration] = np.mean(rewards[:, 0])
            lr_history[iteration] = opt.current_lr

        # evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, crit, loader, eval_kwargs)

            # Write validation result into summary
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # unconditional: a checkpoint is written at every save interval
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['val_result_history'] = val_result_history
                infos['loss_history'] = loss_history
                infos['lr_history'] = lr_history
                infos['ss_prob_history'] = ss_prob_history
                infos['vocab'] = loader.get_vocab()
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
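utils.set_lr, invoked whenever the decayed learning rate is assigned, is another undefined helper; the conventional implementation simply overwrites the rate in every parameter group:

def set_lr(optimizer, lr):
    # Sketch only: assign the same learning rate to all parameter groups.
    for group in optimizer.param_groups:
        group['lr'] = lr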
Exemplo n.º 14
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)

    from dataloader import DataLoader
    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.vocab_ccg_size = loader.vocab_ccg_size
    opt.seq_length = loader.seq_length

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme
        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)

    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    cnn_model = utils.build_cnn(opt)
    cnn_model.cuda()

    model = models.setup(opt)
    model.cuda()
    # model = DataParallel(model)

    if vars(opt).get('start_from', None) is not None:
        # check if all necessary files exist
        assert os.path.isdir(
            opt.start_from), "%s must be a valid path" % opt.start_from
        assert os.path.isfile(
            os.path.join(opt.start_from, "infos_" + opt.id + ".pkl")
        ), "infos.pkl file does not exist in path %s" % opt.start_from
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    update_lr_flag = True
    model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    multilabel_crit = nn.MultiLabelSoftMarginLoss().cuda()
    #    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)
    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate)
    if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
        print('finetune mode')
        cnn_optimizer = optim.Adam(
            [{'params': module.parameters()}
             for module in list(cnn_model._modules.values())[5:]],
            lr=opt.cnn_learning_rate, weight_decay=opt.cnn_weight_decay)

    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, 'optimizer.pth')):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
            if os.path.isfile(os.path.join(opt.start_from,
                                           'optimizer-cnn.pth')):
                cnn_optimizer.load_state_dict(
                    torch.load(
                        os.path.join(opt.start_from, 'optimizer-cnn.pth')))

    eval_kwargs = {'split': 'val', 'dataset': opt.input_json, 'verbose': True}
    eval_kwargs.update(vars(opt))
    val_loss, predictions, lang_stats = eval_utils.eval_split(
        cnn_model, model, crit, loader, eval_kwargs, True)
    epoch_start = time.time()
    while True:
        if update_lr_flag:
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
                #model.module.ss_prob = opt.ss_prob
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
            else:
                sc_flag = False

            # Update the training stage of cnn
            for p in cnn_model.parameters():
                p.requires_grad = True
            # Fix the first few layers:
            for module in list(cnn_model._modules.values())[:5]:
                for p in module.parameters():
                    p.requires_grad = False
            cnn_model.train()
            update_lr_flag = False

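        # set_bn_fix / set_bn_eval presumably freeze BatchNorm parameters and
        # running statistics while the CNN is being fine-tuned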
        cnn_model.apply(utils.set_bn_fix)
        cnn_model.apply(utils.set_bn_eval)

        start = time.time()
        torch.cuda.synchronize()
        data = loader.get_batch('train')
        if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:

            multilabels = [
                data['detection_infos'][i]['label']
                for i in range(len(data['detection_infos']))
            ]

            tmp = [
                data['labels'], data['masks'],
                np.array(multilabels, dtype=np.int16)
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            labels, masks, multilabels = tmp
            images = data[
                'images']  # cannot be stacked into one tensor: image sizes differ
            _fc_feats_2048 = []
            _fc_feats_81 = []
            _att_feats = []
            for i in range(loader.batch_size):
                x = Variable(torch.from_numpy(images[i]),
                             requires_grad=False).cuda()
                x = x.unsqueeze(0)
                att_feats, fc_feats_81 = cnn_model(x)
                fc_feats_2048 = att_feats.mean(3).mean(2).squeeze()
                att_feats = F.adaptive_avg_pool2d(
                    att_feats, [14, 14]).squeeze().permute(1, 2, 0)  # CHW -> HWC
                _fc_feats_2048.append(fc_feats_2048)
                _fc_feats_81.append(fc_feats_81)
                _att_feats.append(att_feats)
            _fc_feats_2048 = torch.stack(_fc_feats_2048)
            _fc_feats_81 = torch.stack(_fc_feats_81)
            _att_feats = torch.stack(_att_feats)
            # replicate per-image features seq_per_img times so they line up with captions
            att_feats = _att_feats.unsqueeze(1).expand(
                *((_att_feats.size(0), loader.seq_per_img) +
                  _att_feats.size()[1:])).contiguous().view(
                    *((_att_feats.size(0) * loader.seq_per_img,) +
                      _att_feats.size()[1:]))
            fc_feats_2048 = _fc_feats_2048.unsqueeze(1).expand(
                *((_fc_feats_2048.size(0), loader.seq_per_img) +
                  _fc_feats_2048.size()[1:])).contiguous().view(
                    *((_fc_feats_2048.size(0) * loader.seq_per_img,) +
                      _fc_feats_2048.size()[1:]))
            fc_feats_81 = _fc_feats_81
            #
            cnn_optimizer.zero_grad()
        else:

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks']
            ]
            tmp = [
                Variable(torch.from_numpy(_), requires_grad=False).cuda()
                for _ in tmp
            ]
            fc_feats_2048, att_feats, labels, masks = tmp
            # note: the losses below also use fc_feats_81 and multilabels,
            # which are only defined in the fine-tuning branch above

        optimizer.zero_grad()

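        # XE stage: blend caption cross-entropy (0.8) with the CNN multi-label loss (0.2);
        # SCST stage: blend the policy-gradient loss (0.995) with small XE and multi-label terms (0.005)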
        if not sc_flag:
            loss1 = crit(model(fc_feats_2048, att_feats, labels),
                         labels[:, 1:], masks[:, 1:])
            loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double())
            loss = 0.8 * loss1 + 0.2 * loss2.float()
        else:
            gen_result, sample_logprobs = model.sample(fc_feats_2048,
                                                       att_feats,
                                                       {'sample_max': 0})
            reward = get_self_critical_reward(model, fc_feats_2048, att_feats,
                                              data, gen_result)
            loss1 = rl_crit(
                sample_logprobs, gen_result,
                Variable(torch.from_numpy(reward).float().cuda(),
                         requires_grad=False))
            loss2 = multilabel_crit(fc_feats_81.double(), multilabels.double())
            loss3 = crit(model(fc_feats_2048, att_feats, labels),
                         labels[:, 1:], masks[:, 1:])
            loss = 0.995 * loss1 + 0.005 * (loss2.float() + loss3)
        loss.backward()

        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()

        train_loss = loss.data[0]
        mle_loss = loss1.data[0]
        multilabel_loss = loss2.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag and iteration % 2500 == 0:
            print("iter {} (epoch {}), mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, mle_loss, multilabel_loss, train_loss, end - start))

        if sc_flag and iteration % 2500 == 0:
            print("iter {} (epoch {}), avg_reward = {:.3f}, mle_loss = {:.3f}, multilabel_loss = {:.3f}, train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), mle_loss, multilabel_loss, train_loss, end - start))
        iteration += 1
        if (iteration % opt.losses_log_every == 0):
            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        if (iteration % opt.save_checkpoint_every == 0):
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': True
            }
            eval_kwargs.update(vars(opt))

            if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, crit, loader, eval_kwargs, True)
            else:
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    cnn_model, model, crit, loader, eval_kwargs, False)

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # always save the latest checkpoint
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))

                cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-cnn.pth')
                torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                print("cnn model saved to {}".format(cnn_checkpoint_path))

                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                if opt.finetune_cnn_after != -1 and epoch >= opt.finetune_cnn_after:
                    cnn_optimizer_path = os.path.join(opt.checkpoint_path,
                                                      'optimizer-cnn.pth')
                    torch.save(cnn_optimizer.state_dict(), cnn_optimizer_path)

                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))

                    cnn_checkpoint_path = os.path.join(opt.checkpoint_path,
                                                       'model-cnn-best.pth')
                    torch.save(cnn_model.state_dict(), cnn_checkpoint_path)
                    print("cnn model saved to {}".format(cnn_checkpoint_path))

                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True
            print("epoch: " + str(epoch) + " during: " +
                  str(time.time() - epoch_start))
            epoch_start = time.time()

        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
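
Example 14's self-critical branch feeds sampled log-probabilities, the sampled sequence, and a per-caption reward into utils.RewardCriterion. The repository's exact implementation is not shown here, but the standard SCST loss it presumably computes is sketched below: the negative log-probability of each sampled token, weighted by the reward and masked to the valid steps of each caption.

import torch
import torch.nn as nn

class RewardCriterionSketch(nn.Module):
    # Sketch of a reward-weighted policy-gradient loss (an assumption,
    # not the repository's verbatim code).
    def forward(self, logprobs, seq, reward):
        logprobs = logprobs.reshape(-1)
        reward = reward.reshape(-1)
        # a step counts up to and including the first end token (index 0)
        mask = (seq > 0).float()
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1)
        mask = mask.reshape(-1)
        return -(logprobs * reward * mask).sum() / mask.sum()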
Exemplo n.º 15
0
def train(opt):
    print("=================Training Information==============")
    print("start from {}".format(opt.start_from))
    print("box from {}".format(opt.input_box_dir))
    print("input json {}".format(opt.input_json))
    print("attributes from {}".format(opt.input_att_dir))
    print("features from {}".format(opt.input_fc_dir))
    print("batch size ={}".format(opt.batch_size))
    print("#GPU={}".format(torch.cuda.device_count()))
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box:
        opt.att_feat_size = opt.att_feat_size + 5

    acc_steps = getattr(opt, 'acc_steps', 1)
    name_append = opt.name_append
    if len(name_append) > 0 and name_append[0] != '-':
        name_append = '_' + name_append

    loader = DataLoader(opt)

    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length
    opt.write_summary = write_summary
    if opt.write_summary:
        print("write summary to {}".format(opt.checkpoint_path))
        tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}

    if opt.start_from is not None:
        # open old infos and check if models are compatible
        infos_path = os.path.join(opt.start_from,
                                  'infos' + name_append + '.pkl')
        print("Load model information {}".format(infos_path))
        with open(infos_path, 'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        histories_path = os.path.join(opt.start_from,
                                      'histories_' + name_append + '.pkl')
        if os.path.isfile(histories_path):
            with open(histories_path, 'rb') as f:
                histories = utils.pickle_load(f)
    else:  # start from scratch
        print("Initialize training process from all begining")
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    # sanity check that the saved model name carries the correct index
    if opt.name_append.isdigit() and int(opt.name_append) < 100:
        assert int(
            opt.name_append
        ) == epoch, "mismatch between the model index and the actual epoch number"
        epoch += 1

    print(
        "==================start from {} epoch================".format(epoch))
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})
    # pdb.set_trace()
    loader.iterators = infos.get('iterators', loader.iterators)
    start_Img_idx = loader.iterators['train']
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    model = models.setup(opt).cuda()
    del opt.vocab
    dp_model = torch.nn.DataParallel(model)
    lw_model = LossWrapper(model, opt)  # wrap loss into model
    dp_lw_model = torch.nn.DataParallel(lw_model)

    epoch_done = True
    # Ensure the model is in training mode
    dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'aoa'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer_path = os.path.join(opt.start_from,
                                      'optimizer' + name_append + '.pth')
        if os.path.isfile(optimizer_path):
            print("Loading optimizer............")
            optimizer.load_state_dict(torch.load(optimizer_path))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '_' + append
        # create checkpoint_path if it doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("Save model state to {}".format(checkpoint_path))

        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        print("Save model optimizer to {}".format(optimizer_path))

        with open(
                os.path.join(opt.checkpoint_path,
                             'infos' + '%s.pkl' % (append)), 'wb') as f:
            utils.pickle_dump(infos, f)
            print("Save training information to {}".format(
                os.path.join(opt.checkpoint_path,
                             'infos' + '%s.pkl' % (append))))

        if histories:
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + '%s.pkl' % (append)),
                    'wb') as f:
                utils.pickle_dump(histories, f)
                print("Save training historyes to {}".format(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '%s.pkl' % (append))))

    try:
        while True:
            # pdb.set_trace()
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False
            print("{}th Epoch Training starts now!".format(epoch))
            with tqdm(total=len(loader.split_ix['train']),
                      initial=start_Img_idx) as pbar:
                for i in range(start_Img_idx, len(loader.split_ix['train']),
                               opt.batch_size):
                    # import ipdb; ipdb.set_trace()
                    start = time.time()
                    if (opt.use_warmup
                            == 1) and (iteration < opt.noamopt_warmup):
                        opt.current_lr = opt.learning_rate * (
                            iteration + 1) / opt.noamopt_warmup
                        utils.set_lr(optimizer, opt.current_lr)
                    # Load data from train split (0)
                    data = loader.get_batch('train')
                    # print('Read data:', time.time() - start)

                    if (iteration % acc_steps == 0):
                        optimizer.zero_grad()

                    torch.cuda.synchronize()
                    start = time.time()
                    tmp = [
                        data['fc_feats'], data['att_feats'], data['labels'],
                        data['masks'], data['att_masks']
                    ]
                    tmp = [_ if _ is None else _.cuda() for _ in tmp]
                    fc_feats, att_feats, labels, masks, att_masks = tmp

                    model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                            att_masks, data['gts'],
                                            torch.arange(0, len(data['gts'])),
                                            sc_flag)

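                    # dividing by acc_steps keeps the accumulated gradient
                    # equal to that of one batch acc_steps times larger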
                    loss = model_out['loss'].mean()
                    loss_sp = loss / acc_steps

                    loss_sp.backward()
                    if ((iteration + 1) % acc_steps == 0):
                        utils.clip_gradient(optimizer, opt.grad_clip)
                        optimizer.step()
                    torch.cuda.synchronize()
                    train_loss = loss.item()
                    end = time.time()
                    # if not sc_flag:
                    #     print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                    #         .format(iteration, epoch, train_loss, end - start))
                    # else:
                    #     print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                    #         .format(iteration, epoch, model_out['reward'].mean(), end - start))
                    if not sc_flag:
                        pbar.set_description(
                            "iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch, train_loss, end - start))
                    else:
                        pbar.set_description(
                            "iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}"
                            .format(iteration, epoch,
                                    model_out['reward'].mean(), end - start))

                    # Update the iteration and epoch
                    iteration += 1
                    pbar.update(opt.batch_size)
                    if data['bounds']['wrapped']:
                        # save after each epoch
                        save_checkpoint(model,
                                        infos,
                                        optimizer,
                                        append=str(epoch))
                        epoch += 1
                        # infos['epoch'] = epoch
                        epoch_done = True

                    # Write validation result into summary
                    if (iteration % opt.losses_log_every
                            == 0) and opt.write_summary:
                        add_summary_value(tb_summary_writer, 'loss/train_loss',
                                          train_loss, iteration)
                        if opt.noamopt:
                            opt.current_lr = optimizer.rate()
                        elif opt.reduce_on_plateau:
                            opt.current_lr = optimizer.current_lr
                        add_summary_value(tb_summary_writer,
                                          'hyperparam/learning_rate',
                                          opt.current_lr, iteration)
                        add_summary_value(
                            tb_summary_writer,
                            'hyperparam/scheduled_sampling_prob',
                            model.ss_prob, iteration)
                        if sc_flag:
                            add_summary_value(tb_summary_writer, 'avg_reward',
                                              model_out['reward'].mean(),
                                              iteration)

                        loss_history[
                            iteration] = train_loss if not sc_flag else model_out[
                                'reward'].mean()
                        lr_history[iteration] = opt.current_lr
                        ss_prob_history[iteration] = model.ss_prob

                    # update infos
                    infos['iter'] = iteration
                    infos['epoch'] = epoch
                    infos['iterators'] = loader.iterators
                    infos['split_ix'] = loader.split_ix

                    # make evaluation on validation set, and save model
                    # TODO modify it to evaluate by each epoch
                    # ipdb.set_trace()
                    if (iteration % opt.save_checkpoint_every
                            == 0) and eval_ and epoch > 20:
                        model_path = os.path.join(
                            opt.checkpoint_path,
                            'model_itr%s.pth' % (iteration))
                        eval_kwargs = {
                            'split': 'val',
                            'dataset': opt.input_json,
                            'model': model_path
                        }
                        eval_kwargs.update(vars(opt))
                        val_loss, predictions, lang_stats = eval_utils.eval_split(
                            dp_model, lw_model.crit, loader, eval_kwargs)

                        if opt.reduce_on_plateau:
                            if 'CIDEr' in lang_stats:
                                optimizer.scheduler_step(-lang_stats['CIDEr'])
                            else:
                                optimizer.scheduler_step(val_loss)

                        # Write validation result into summary
                        if opt.write_summary:
                            add_summary_value(tb_summary_writer,
                                              'loss/validation loss', val_loss,
                                              iteration)

                            if lang_stats is not None:
                                bleu_dict = {}
                                for k, v in lang_stats.items():
                                    if 'Bleu' in k:
                                        bleu_dict[k] = v
                                if len(bleu_dict) > 0:
                                    tb_summary_writer.add_scalars(
                                        'val/Bleu', bleu_dict, epoch)

                                for k, v in lang_stats.items():
                                    if 'Bleu' not in k:
                                        add_summary_value(
                                            tb_summary_writer, 'val/' + k, v,
                                            iteration)
                        val_result_history[iteration] = {
                            'loss': val_loss,
                            'lang_stats': lang_stats,
                            'predictions': predictions
                        }

                        # Save model if is improving on validation result
                        if opt.language_eval == 1:
                            current_score = lang_stats['CIDEr']
                        else:
                            current_score = -val_loss

                        best_flag = False

                        if best_val_score is None or current_score > best_val_score:
                            best_val_score = current_score
                            best_flag = True

                        # Dump miscellaneous information
                        infos['best_val_score'] = best_val_score
                        histories['val_result_history'] = val_result_history
                        histories['loss_history'] = loss_history
                        histories['lr_history'] = lr_history
                        histories['ss_prob_history'] = ss_prob_history

                        # save_checkpoint(model, infos, optimizer, histories, append=str(iteration))
                        save_checkpoint(model, infos, optimizer, histories)
                        # if opt.save_history_ckpt:
                        #     save_checkpoint(model, infos, optimizer, append=str(iteration))

                        if best_flag:
                            save_checkpoint(model,
                                            infos,
                                            optimizer,
                                            append='best')
                            print(
                                "update best model at {} iteration--{} epoch".
                                format(iteration, epoch))

                    start_Img_idx = 0
                    # if epoch_done: # go through the set, start a new epoch loop
                    #     break
            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                print("epoch {} break all".format(epoch))
                save_checkpoint(model, infos, optimizer)
                tb_summary_writer.close()
                print("============{} Training Done !==============".format(
                    'Refine' if opt.use_test or opt.use_val else ''))
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer, append='_interrupt')
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
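
Example 15 accumulates gradients over acc_steps mini-batches: the loss is scaled down, backward() runs every batch, but clipping and the optimizer step fire only once per accumulation window. A minimal sketch of that pattern follows, using torch.nn.utils.clip_grad_value_ as a stand-in for the repository's utils.clip_gradient helper (the helper's exact behavior is an assumption).

import torch

def accumulated_step(loss, iteration, model, optimizer, acc_steps, grad_clip):
    # scale so the summed gradients match one batch acc_steps times larger
    (loss / acc_steps).backward()
    if (iteration + 1) % acc_steps == 0:
        # stand-in for utils.clip_gradient (assumed to clip by value)
        torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
        optimizer.step()
        optimizer.zero_grad()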
Exemplo n.º 16
0
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    ac = 0

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(
                os.path.join(
                    opt.checkpoint_path, 'infos_' + opt.id +
                    format(int(opt.start_from), '04') + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(
                    opt.checkpoint_path, 'histories_' + opt.id +
                    format(int(opt.start_from), '04') + '.pkl')):
            with open(
                    os.path.join(
                        opt.checkpoint_path, 'histories_' + opt.id +
                        format(int(opt.start_from), '04') + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()

    #dp_model = torch.nn.DataParallel(model)
    #dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()

    for name, param in model.named_parameters():
        print(name)

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()
    CE_ac = utils.CE_ac()

    optim_para = model.parameters()
    optimizer = utils.build_optimizer(optim_para, opt)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(
                opt.checkpoint_path, 'optimizer' + opt.id +
                format(int(opt.start_from), '04') + '.pth')):
        optimizer.load_state_dict(
            torch.load(
                os.path.join(
                    opt.checkpoint_path, 'optimizer' + opt.id +
                    format(int(opt.start_from), '04') + '.pth')))

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])
    sim_lambda = opt.sim_lambda
    reset_optimizer_index = 1
    while True:
        if opt.self_critical_after != -1 and epoch >= opt.self_critical_after and reset_optimizer_index:
            opt.learning_rate_decay_start = opt.self_critical_after
            opt.learning_rate_decay_rate = opt.learning_rate_decay_rate_rl
            opt.learning_rate = opt.learning_rate_rl
            reset_optimizer_index = 0

        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch(opt.train_split)
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['labels'], data['masks'], data['mods']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        labels, masks, mods = tmp

        tmp = [
            data['att_feats'], data['att_masks'], data['attr_feats'],
            data['attr_masks'], data['rela_feats'], data['rela_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        att_feats, att_masks, attr_feats, attr_masks, rela_feats, rela_masks = tmp

        rs_data = {}
        rs_data['att_feats'] = att_feats
        rs_data['att_masks'] = att_masks
        rs_data['attr_feats'] = attr_feats
        rs_data['attr_masks'] = attr_masks
        rs_data['rela_feats'] = rela_feats
        rs_data['rela_masks'] = rela_masks

        if not sc_flag:
            logits, cw_logits = dp_model(rs_data, labels)
            ac = CE_ac(logits, labels[:, 1:], masks[:, 1:])
            print('ac: {0}'.format(ac))
            loss_lan = crit(logits, labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs, cw_logits = dp_model(
                rs_data, opt={'sample_max': 0}, mode='sample')
            reward = get_self_critical_reward(dp_model, rs_data, data,
                                              gen_result, opt)
            loss_lan = rl_crit(sample_logprobs, gen_result.data,
                               torch.from_numpy(reward).float().cuda())

        loss_cw = crit(cw_logits, mods[:, 1:], masks[:, 1:])
        ac2 = CE_ac(cw_logits, mods[:, 1:], masks[:, 1:])
        print('cw ac: {0}'.format(ac2))
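        # stage 1 optimizes the language loss and the control-word loss jointly;
        # after step2_train_after only the language loss remains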
        if epoch < opt.step2_train_after:
            loss = loss_lan + sim_lambda * loss_cw
        else:
            loss = loss_lan

        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            train_loss_lan = loss_lan.item()
            train_loss_cw = loss_cw.item()
            end = time.time()

            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
                print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \
                      .format(train_loss_lan, train_loss_cw))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), end - start))
                print("train_loss_lan = {:.3f}, train_loss_cw = {:.3f}" \
                      .format(train_loss_lan, train_loss_cw))
            print('lr:{0}'.format(opt.current_lr))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every
                == 0) and (accumulate_iter % opt.accumulate_number == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'train_loss_lan',
                              train_loss_lan, iteration)
            add_summary_value(tb_summary_writer, 'train_loss_cw',
                              train_loss_cw, iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            add_summary_value(tb_summary_writer, 'ac', ac, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every
                == 0) and (accumulate_iter % opt.accumulate_number == 0):
            # eval model
            eval_kwargs = {'split': 'test', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            #val_loss, predictions, lang_stats = eval_utils_rs3.eval_split(dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            # add_summary_value(tb_summary_writer, 'validation loss', val_loss, iteration)
            # if lang_stats is not None:
            #     for k,v in lang_stats.items():
            #         add_summary_value(tb_summary_writer, k, v, iteration)
            # val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            # if opt.language_eval == 1:
            #     current_score = lang_stats['CIDEr']
            # else:
            #     current_score = - val_loss
            current_score = 0

            best_flag = False
            if True:  # always save the latest checkpoint
                save_id = iteration / opt.save_checkpoint_every
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'model' + opt.id + format(int(save_id), '04') + '.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(
                    opt.checkpoint_path,
                    'optimizer' + opt.id + format(int(save_id), '04') + '.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(
                            opt.checkpoint_path, 'infos_' + opt.id +
                            format(int(save_id), '04') + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(
                            opt.checkpoint_path, 'histories_' + opt.id +
                            format(int(save_id), '04') + '.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
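
Every example in this section ramps up scheduled sampling the same way: past a start epoch, the probability that the decoder is fed its own previous prediction instead of the ground-truth token grows stepwise and is capped. A small pure-function sketch of that schedule (the function name and parameter names are illustrative):

def scheduled_sampling_prob(epoch, start, increase_every, increase_prob, max_prob):
    # stepwise ramp: one increment of increase_prob per increase_every epochs
    if start >= 0 and epoch > start:
        frac = (epoch - start) // increase_every
        return min(increase_prob * frac, max_prob)
    return 0.0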
Exemplo n.º 17
0
def train(opt):
    # setup dataloader
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    #set the checkpoint path
    opt.checkpoint_path = os.path.join(opt.checkpoint_path, opt.id)
    if not os.path.exists(opt.checkpoint_path):
        os.makedirs(opt.checkpoint_path)
        os.makedirs(opt.checkpoint_path + '/logs')
        print(opt.checkpoint_path + ' created!')
    else:
        print(opt.checkpoint_path + ' already exists!')

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(
                os.path.join(
                    opt.checkpoint_path, 'infos_' + opt.id +
                    format(int(opt.start_from), '04') + '.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "att_feat_size", "rnn_size",
                "input_encoding_size"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(
                    opt.checkpoint_path, 'histories_' + opt.id +
                    format(int(opt.start_from), '04') + '.pkl')):
            with open(
                    os.path.join(
                        opt.checkpoint_path, 'histories_' + opt.id +
                        format(int(opt.start_from), '04') + '.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)
    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    word_loss_history = histories.get('word_loss_history', {})
    MAD_loss_history = histories.get('MAD_loss_history', {})
    SAP_loss_history = histories.get('SAP_loss_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    lr_history = histories.get('lr_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    #set up model, assure in training mode
    threshold = opt.threshold
    sc_flag = False
    num_gpu = opt.num_gpu

    model = models.setup(opt).cuda(device=0)
    model.train()
    update_lr_flag = True
    dp_model = torch.nn.parallel.DataParallel(model)

    optimizer = optim.Adam(model.parameters(),
                           opt.learning_rate,
                           (opt.optim_alpha, opt.optim_beta),
                           opt.optim_epsilon,
                           weight_decay=opt.weight_decay)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(
                opt.checkpoint_path, 'optimizer' + opt.id +
                format(int(opt.start_from), '04') + '.pth')):
        optimizer.load_state_dict(
            torch.load(
                os.path.join(
                    opt.checkpoint_path, 'optimizer' + opt.id +
                    format(int(opt.start_from), '04') + '.pth')))

    if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
        frac = (epoch - opt.scheduled_sampling_start
                ) // opt.scheduled_sampling_increase_every
        opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                          opt.scheduled_sampling_max_prob)
        model.ss_prob = opt.ss_prob

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0

    subsequent_mat = np.load('data/markov_mat.npy')
    subsequent_mat = torch.from_numpy(subsequent_mat).cuda(device=0).float()
    subsequent_mat_all = subsequent_mat.clone()
    # for multi-GPU training
    for i in range(opt.num_gpu - 1):
        subsequent_mat_all = torch.cat([subsequent_mat_all, subsequent_mat],
                                       dim=0)

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            for group in optimizer.param_groups:
                group['lr'] = opt.current_lr

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if sc_flag == False and opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                print('initializing CIDEr scorer...')
                s = time.time()
                global CiderD_scorer
                if (CiderD_scorer is None):
                    CiderD_scorer = CiderD(df=opt.cached_tokens)
                    # takes about 30s
                    print('initialized CIDEr scorer in {:.3f}s'.format(
                        time.time() - s))
                sc_flag = True
                opt.learning_rate_decay_every = opt.learning_rate_decay_every * 2  # default 5 for XE, 10 for SCST

            update_lr_flag = False

        print('current_lr is {}'.format(opt.current_lr))
        start = time.time()
        data = loader.get_batch('train', opt.batch_size)

        torch.cuda.synchronize()

        fc_feats = None
        att_feats = None

        tmp = [
            data['fc_feats'], data['labels'], data['masks'], data['att_feats'],
            data['attr_labels'], data['subsequent_labels']
        ]
        tmp = [
            _ if _ is None else torch.from_numpy(_).cuda(device=0) for _ in tmp
        ]
        fc_feats, labels, masks, att_feats, attr_labels, subsequent_labels = tmp

        # convert labels from 1-1000 to 0-999 (could also be done in preprocessing)
        subsequent_labels = subsequent_labels - 1
        subsequent_mask = (subsequent_labels[:, 1:] >= 0).float()
        subsequent_labels = torch.where(
            subsequent_labels > 0, subsequent_labels,
            torch.zeros_like(subsequent_labels).int().cuda(device=0))

        print('Read and process data:', time.time() - start)

        if not sc_flag:
            SAP_loss, word_loss, MAD_loss = dp_model(
                fc_feats, att_feats, labels, masks, attr_labels,
                subsequent_labels, subsequent_mask, subsequent_mat_all)
            SAP_loss = SAP_loss.mean()
            word_loss = word_loss.mean()
            MAD_loss = MAD_loss.mean()
            accumulate_iter = accumulate_iter + 1
            loss = (word_loss + 0.2 * SAP_loss +
                    0.2 * MAD_loss) / opt.accumulate_number
            loss.backward()
        else:
            st = time.time()
            sm = torch.zeros([num_gpu, 1]).cuda(
                device=0)  # indices for sampling by probabilities
            gen_result, sample_logprobs, _ = dp_model(fc_feats,
                                                      att_feats,
                                                      attr_labels,
                                                      subsequent_mat_all,
                                                      sm,
                                                      mode='sample')
            dp_model.eval()
            with torch.no_grad():
                greedy_res, _, _ = dp_model(fc_feats,
                                            att_feats,
                                            attr_labels,
                                            subsequent_mat_all,
                                            mode='sample')
            dp_model.train()
            ed = time.time()
            print('GPU time is : {}s'.format(ed - st))
            reward = get_self_critical_reward(gen_result, greedy_res,
                                              data['gts'])
            word_loss = dp_model(sample_logprobs,
                                 gen_result.data,
                                 torch.from_numpy(reward).float().cuda(),
                                 mode='scst_forward')
            word_loss = word_loss.mean()

            loss = word_loss

            #forward to minimize SAP loss and MAD loss
            SAP_loss, _, MAD_loss = dp_model(fc_feats, att_feats, labels,
                                             masks, attr_labels,
                                             subsequent_labels,
                                             subsequent_mask,
                                             subsequent_mat_all)
            SAP_loss = SAP_loss.mean()
            MAD_loss = MAD_loss.mean()
            loss = loss + 0.2 * SAP_loss + 0.2 * MAD_loss
            loss.backward()
            accumulate_iter = accumulate_iter + 1

        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()

            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            end = time.time()
            # you can record the training log here if needed
            #text_file = open(opt.checkpoint_path+'/logs/train_log_'+opt.id+'.txt', "aw")
            if not sc_flag:
                print("iter {} (epoch {}), SAP_loss = {:.3f}, word_loss = {:.3f}, MAD_loss = {:.3f} time/batch = {:.3f}" \
                      .format(iteration, epoch,SAP_loss, word_loss,MAD_loss, end - start))
                #text_file.write("iter {} (epoch {}),SAP_loss = {:.3f}, word_loss {:.3f}, MAD_loss {:.3f},time/batch = {:.3f}\n" \
                #      .format(iteration, epoch,SAP_loss, word_loss, MAD_loss, end - start))

            else:
                print("iter {} (epoch {}),SAP_loss = {:.3f}, avg_reward = {:.3f},MAD_loss = {:.3f} time/batch = {:.3f}" \
                      .format(iteration, epoch,SAP_loss,np.mean(reward[:, 0]),MAD_loss, end - start))
                #text_file.write("iter {} (epoch {}), avg_reward = {:.3f} MAD_loss ={:.3f}, time/batch = {:.3f}\n" \
                #      .format(iteration, epoch, np.mean(reward[:, 0]), MAD_loss, end - start))
            #text_file.close()
        torch.cuda.synchronize()

        # Update the iteration and epoch

        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every
                == 0) and (accumulate_iter % opt.accumulate_number == 0):
            add_summary_value(tb_summary_writer, 'word_loss', word_loss.item(),
                              iteration)
            add_summary_value(tb_summary_writer, 'MAD_loss', MAD_loss.item(),
                              iteration)
            add_summary_value(tb_summary_writer, 'SAP_loss', SAP_loss.item(),
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            word_loss_history[iteration] = word_loss.item()
            SAP_loss_history[iteration] = SAP_loss.item()
            MAD_loss_history[iteration] = MAD_loss.item()

            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every
                == 0) and (accumulate_iter % opt.accumulate_number == 0):
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'num_images': -1,
                'index_eval': 1,
                'id': opt.id,
                'beam': opt.beam,
                'verbose_loss': 1,
                'checkpoint_path': opt.checkpoint_path
            }
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats, precision, recall = eval_utils.eval_split(
                dp_model, loader, subsequent_mat_all, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            #save lang stats
            f_lang = open(
                opt.checkpoint_path + '/logs/lang_' + opt.id + '.txt', 'a')
            f_lang.write(
                str(iteration) + ' ' +
                str(iteration / opt.save_checkpoint_every) + '\n')
            f_lang.write('val loss ' + str(val_loss) + '\n')
            for key_lang in lang_stats:
                f_lang.write(key_lang + ' ' + str(lang_stats[key_lang]) + '\n')
            f_lang.write('precision ' + str(precision) + ' recall ' +
                         str(recall) + '\n')
            f_lang.close()

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            save_id = iteration / opt.save_checkpoint_every

            if best_val_score is None or current_score > best_val_score or current_score > threshold:
                best_val_score = current_score
                best_flag = True

                # only save improved models, or models whose CIDEr-D exceeds the given threshold
                checkpoint_path = os.path.join(
                    opt.checkpoint_path,
                    'model' + opt.id + format(int(save_id), '04') + '.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(
                    opt.checkpoint_path,
                    'optimizer' + opt.id + format(int(save_id), '04') + '.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # record the language stats for the saved model
                f_lang = open(
                    opt.checkpoint_path + '/logs/Best_lang_' + opt.id + '.txt',
                    'a')
                f_lang.write(
                    str(iteration) + ' ' +
                    str(iteration / opt.save_checkpoint_every) + '\n')
                f_lang.write('val loss ' + str(val_loss) + '\n')
                for key_lang in lang_stats:
                    f_lang.write(key_lang + ' ' + str(lang_stats[key_lang]) +
                                 '\n')
                f_lang.write('precision ' + str(precision) + ' recall ' +
                             str(recall) + '\n')
                f_lang.close()

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['word_loss_history'] = word_loss_history
            histories['MAD_loss_history'] = MAD_loss_history
            histories['SAP_loss_history'] = SAP_loss_history

            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history

            with open(
                    os.path.join(
                        opt.checkpoint_path, 'infos_' + opt.id +
                        format(int(save_id), '04') + '.pkl'), 'wb') as f:
                cPickle.dump(infos, f)
            with open(
                    os.path.join(
                        opt.checkpoint_path, 'histories_' + opt.id +
                        format(int(save_id), '04') + '.pkl'), 'wb') as f:
                cPickle.dump(histories, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
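
The reward above comes from get_self_critical_reward, whose body is not shown. Below is a minimal sketch of the usual SCST computation it performs, written against a generic score_fn (e.g. a CIDEr-D scorer), since the real scorer API is not visible here; the function name and score_fn parameter are illustrative.

import numpy as np

def get_self_critical_reward_sketch(gen_result, greedy_res, gts, score_fn):
    # gen_result / greedy_res: (batch, seq_len) arrays of sampled / greedy word ids
    # gts[i]: reference captions for sample i; score_fn(refs, hyp) -> float is assumed
    batch_size, seq_len = gen_result.shape
    reward = np.zeros((batch_size, seq_len), dtype=np.float32)
    for i in range(batch_size):
        # the greedy caption is the baseline: the reward is positive only
        # when the sampled caption scores higher than greedy decoding
        advantage = score_fn(gts[i], gen_result[i]) - score_fn(gts[i], greedy_res[i])
        reward[i, :] = advantage  # broadcast the sentence-level reward to every step
    return reward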
Exemplo n.º 18
def train(opt):
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    acc_steps = getattr(opt, 'acc_steps', 1)

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)
    else:
        infos['iter'] = 0
        infos['epoch'] = 0
        infos['iterators'] = loader.iterators
        infos['split_ix'] = loader.split_ix
        infos['vocab'] = loader.get_vocab()
    infos['opt'] = opt

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    opt.vocab = loader.get_vocab()
    model = models.setup(opt)
    model = load_para(model, os.path.join('./log/log_aoanet_rl', 'model.pth'))
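    # NOTE: the condition below is hard-coded to True, so the CUDA/DataParallel
    # path in the else branch is never taken.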
    if True:
        del opt.vocab
        dp_model = model
        lw_model = LossWrapper(model, opt)
        dp_lw_model = torch.nn.DataParallel(lw_model)

        epoch_done = True
        # Ensure the model is in training mode
        dp_lw_model.train()
    else:
        model = model.cuda()
        del opt.vocab
        dp_model = torch.nn.DataParallel(model)
        lw_model = LossWrapper(model, opt)
        dp_lw_model = torch.nn.DataParallel(lw_model)

        epoch_done = True
        # Ensure the model is in training mode
        dp_lw_model.train()

    if opt.noamopt:
        assert opt.caption_model in [
            'transformer', 'aoa'
        ], 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    def save_checkpoint(model, infos, optimizer, histories=None, append=''):
        if len(append) > 0:
            append = '-' + append
        # if checkpoint_path doesn't exist
        if not os.path.isdir(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        checkpoint_path = os.path.join(opt.checkpoint_path,
                                       'model%s.pth' % (append))
        torch.save(model.state_dict(), checkpoint_path)
        print("model saved to {}".format(checkpoint_path))
        optimizer_path = os.path.join(opt.checkpoint_path,
                                      'optimizer%s.pth' % (append))
        torch.save(optimizer.state_dict(), optimizer_path)
        with open(
                os.path.join(opt.checkpoint_path,
                             'infos_' + opt.id + '%s.pkl' % (append)),
                'wb') as f:
            utils.pickle_dump(infos, f)
        if histories:
            with open(
                    os.path.join(opt.checkpoint_path,
                                 'histories_' + opt.id + '%s.pkl' % (append)),
                    'wb') as f:
                utils.pickle_dump(histories, f)

    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                epoch_done = False

            start = time.time()
            if (opt.use_warmup == 1) and (iteration < opt.noamopt_warmup):
                opt.current_lr = opt.learning_rate * (iteration +
                                                      1) / opt.noamopt_warmup
                utils.set_lr(optimizer, opt.current_lr)
            # Load data from train split (0)
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            if (iteration % acc_steps == 0):
                optimizer.zero_grad()

            torch.cuda.synchronize()
            start = time.time()
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                    att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag)

            loss = model_out['loss'].mean()
            loss_sp = loss / acc_steps

            loss_sp.backward()
            if ((iteration + 1) % acc_steps == 0):
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
            torch.cuda.synchronize()
            train_loss = loss.item()
            end = time.time()
            if not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                                  iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                add_summary_value(tb_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tb_summary_writer, 'avg_reward',
                                      model_out['reward'].mean(), iteration)

                loss_history[
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                lr_history[iteration] = opt.current_lr
                ss_prob_history[iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix

            # evaluate on the validation set and save the model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {
                    'split': 'train',
                    'dataset': opt.input_json,
                    'num_images': 1
                }
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if lang_stats is not None and 'CIDEr' in lang_stats:
                        optimizer.scheduler_step(-lang_stats['CIDEr'])
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                add_summary_value(tb_summary_writer, 'validation loss',
                                  val_loss, iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        add_summary_value(tb_summary_writer, k, v, iteration)
                val_result_history[iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save the model if it is improving on the validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score
                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history

                save_checkpoint(model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    save_checkpoint(model,
                                    infos,
                                    optimizer,
                                    append=str(iteration))

                if best_flag:
                    save_checkpoint(model, infos, optimizer, append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break
    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        save_checkpoint(model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
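
Stripped of the captioning specifics, the acc_steps logic in this example is the standard gradient-accumulation pattern. A minimal sketch follows, with torch.nn.utils.clip_grad_value_ standing in for utils.clip_gradient (whose exact behaviour is not shown here); the function and parameter names are illustrative.

import torch

def train_with_accumulation(model, optimizer, batches, loss_fn, acc_steps=4, grad_clip=0.1):
    optimizer.zero_grad()
    for it, (inputs, targets) in enumerate(batches):
        # scale each micro-batch loss so the summed gradients match one
        # large batch of acc_steps times the size
        loss = loss_fn(model(inputs), targets) / acc_steps
        loss.backward()  # gradients accumulate across micro-batches
        if (it + 1) % acc_steps == 0:
            torch.nn.utils.clip_grad_value_(model.parameters(), grad_clip)
            optimizer.step()
            optimizer.zero_grad()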
def train(sketch_dataloader, shape_dataloader, model, criterion, optimizer, epoch, opt):
    """
    train for one epoch on the training set
    """
    batch_time = utils.AverageMeter()
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()
    tpl_losses = utils.AverageMeter()

    # training mode
    net_whole, net_bp, net_vp, net_ap, net_cls = model
    optim_sketch, optim_shape, optim_centers = optimizer
    crt_cls, crt_tlc, w1, w2 = criterion

    net_whole.train()
    net_bp.train()
    net_vp.train()
    net_ap.train()
    net_cls.train()

    end = time.time()
    # debug_here() 
    for i, ((sketches, k_labels), (shapes, p_labels)) in enumerate(zip(sketch_dataloader, shape_dataloader)):

        shapes = shapes.view(shapes.size(0)*shapes.size(1), shapes.size(2), shapes.size(3), shapes.size(4))

        # expanding: (bz * 12) x 3 x 224 x 224
        shapes = shapes.expand(shapes.size(0), 3, shapes.size(2), shapes.size(3))

        shapes_v = Variable(shapes.cuda())
        p_labels_v = Variable(p_labels.long().cuda())

        sketches_v = Variable(sketches.cuda())
        k_labels_v = Variable(k_labels.long().cuda())


        o_bp = net_bp(shapes_v)
        o_vp = net_vp(o_bp)
        shape_feat = net_ap(o_vp)
        sketch_feat = net_whole(sketches_v)
        feat = torch.cat([shape_feat, sketch_feat])
        target = torch.cat([p_labels_v, k_labels_v])
        score = net_cls(feat) 
        
        cls_loss = crt_cls(score, target)
        tpl_loss, _ = crt_tlc(score, target)
        # tpl_loss, _ = crt_tlc(feat, target)

        loss = w1 * cls_loss + w2 * tpl_loss

        ## measure accuracy
        prec1 = utils.accuracy(score.data, target.data, topk=(1,))[0]
        losses.update(cls_loss.data[0], score.size(0)) # batchsize
        tpl_losses.update(tpl_loss.data[0], score.size(0))
        top1.update(prec1[0], score.size(0))

        ## backward
        optim_sketch.zero_grad()
        optim_shape.zero_grad()
        optim_centers.zero_grad()

        loss.backward()
        utils.clip_gradient(optim_sketch, opt.gradient_clip)
        utils.clip_gradient(optim_shape, opt.gradient_clip)
        utils.clip_gradient(optim_centers, opt.gradient_clip)
        
        optim_sketch.step()
        optim_shape.step()
        optim_centers.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                'Prec@1 {top1.val:.3f} ({top1.avg:.3f})\t'
                'Trploss {triplet.val:.4f}({triplet.avg:.3f})'.format(
                epoch, i, len(sketch_dataloader), batch_time=batch_time,
                loss=losses, top1=top1, triplet=tpl_losses))
            # print('triplet loss: ', tpl_center_loss.data[0])
    print(' * Train Prec@1 {top1.avg:.3f}'.format(top1=top1))
    return top1.avg
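
The batch_time / losses / top1 / tpl_losses meters above come from utils.AverageMeter, which is not shown; here is a minimal version consistent with how .update(), .val and .avg are used in this loop.

class AverageMeter(object):
    """Tracks the latest value and a running average."""

    def __init__(self):
        self.val = 0.0   # most recent value
        self.avg = 0.0   # running average
        self.sum = 0.0
        self.count = 0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count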
Exemplo n.º 20
def train(opt):
    opt.use_att = utils.if_use_att(opt)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
        best_val_score_vse = infos.get('best_val_score_vse', None)

    model = models.JointModel(opt)
    model.cuda()

    update_lr_flag = True
    # Ensure the model is in training mode
    model.train()

    optimizer = optim.Adam([p for p in model.parameters() if p.requires_grad],
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, 'optimizer.pth')):
        state_dict = torch.load(os.path.join(opt.start_from, 'optimizer.pth'))
        if len(state_dict['state']) == len(optimizer.state_dict()['state']):
            optimizer.load_state_dict(state_dict)
        else:
            print(
                'Optimizer param group number not matched? There must be new parameters. Reinit the optimizer.'
            )

    init_scorer(opt.cached_tokens)
    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.caption_generator.ss_prob = opt.ss_prob
            # Assign retrieval loss weight
            if epoch > opt.retrieval_reward_weight_decay_start and opt.retrieval_reward_weight_decay_start >= 0:
                frac = (epoch - opt.retrieval_reward_weight_decay_start
                        ) // opt.retrieval_reward_weight_decay_every
                model.retrieval_reward_weight = opt.retrieval_reward_weight * (
                    opt.retrieval_reward_weight_decay_rate**frac)
            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['att_masks'],
            data['labels'], data['masks']
        ]
        tmp = utils.var_wrapper(tmp)
        fc_feats, att_feats, att_masks, labels, masks = tmp

        optimizer.zero_grad()

        loss = model(fc_feats, att_feats, att_masks, labels, masks, data)
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))
        prt_str = ""
        for k, v in model.loss().items():
            prt_str += "{} = {:.3f} ".format(k, v)
        print(prt_str)

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                tf_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                for k, v in model.loss().items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tf_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.caption_generator.ss_prob,
                                             iteration)
                tf_summary_writer.add_scalar('retrieval_reward_weight',
                                             model.retrieval_reward_weight,
                                             iteration)
                tf_summary_writer.file_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.caption_generator.ss_prob

        # evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            # Load the retrieval model for evaluation
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                model, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                for k, v in val_loss.items():
                    tf_summary_writer.add_scalar('validation ' + k, v,
                                                 iteration)
                for k, v in lang_stats.items():
                    tf_summary_writer.add_scalar(k, v, iteration)
                tf_summary_writer.add_text(
                    'Captions',
                    '.\n\n'.join([_['caption'] for _ in predictions[:100]]),
                    iteration)
                #tf_summary_writer.add_image('images', utils.make_summary_image(), iteration)
                #utils.make_html(opt.id, iteration)
                tf_summary_writer.file_writer.flush()

            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['SPICE'] * 100
            else:
                current_score = -val_loss['loss_cap']
            current_score_vse = val_loss.get(opt.vse_eval_criterion, 0) * 100

            best_flag = False
            best_flag_vse = False
            if True:  # always executed; placeholder condition kept from the original
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if best_val_score_vse is None or current_score_vse > best_val_score_vse:
                    best_val_score_vse = current_score_vse
                    best_flag_vse = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-%d.pth' % (iteration))
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['best_val_score_vse'] = best_val_score_vse
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(
                            opt.checkpoint_path,
                            'infos_' + opt.id + '-%d.pkl' % (iteration)),
                        'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)
                if best_flag_vse:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model_vse-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_vse_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
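
The epoch-based learning-rate decay that recurs in these examples reduces to a closed-form staircase schedule. A sketch follows, with set_lr assumed to simply overwrite the rate in every optimizer parameter group (the usual behaviour of utils.set_lr); decayed_lr is an illustrative helper, not part of the original code.

def set_lr(optimizer, lr):
    # assumed utils.set_lr behaviour: write the rate into every param group
    for group in optimizer.param_groups:
        group['lr'] = lr

def decayed_lr(base_lr, epoch, decay_start, decay_every, decay_rate):
    # staircase decay: multiply by decay_rate once per decay_every epochs
    if decay_start >= 0 and epoch > decay_start:
        frac = (epoch - decay_start) // decay_every
        return base_lr * (decay_rate ** frac)
    return base_lr

# e.g. base_lr=5e-4, decay_start=0, decay_every=3, decay_rate=0.8 gives
# 5e-4 for epochs 1-2, 4e-4 for epochs 3-5, 3.2e-4 for epochs 6-8, ...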
Exemplo n.º 21
def train(opt):
    # Deal with feature things before anything
    opt.use_att = utils.if_use_att(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(log_dir=opt.checkpoint_path)
    print(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        print(os.getcwd())
        with open(
                os.path.join(os.getcwd(), opt.start_from,
                             'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    # dp_model = torch.nn.DataParallel(model)
    dp_model = model

    update_lr_flag = True
    # Ensure the model is in training mode
    dp_model.train()

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        # # [added] reproduce straight line learning rate decay in supplementary
        # #      ---- the original paper used 60k iters
        # #      ---- if lr goes to zero just stay at the last lr
        # linear_lr = -(iteration+1)*opt.learning_rate/60000 + opt.learning_rate
        # if linear_lr <= 0:
        #     pass
        # else:
        #     opt.current_lr = linear_lr
        #     utils.set_lr(optimizer, opt.current_lr)

        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
            utils.set_lr(optimizer, opt.current_lr)

            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from train split (0)
        # [naxin] knn_data is the nearest-neighbour batch; its format is identical to data
        data, knn_data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        # tmp = [knn_data['fc_feats'], knn_data['att_feats'], knn_data['labels'], knn_data['masks'], knn_data['att_masks']]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs, eval_knn=opt.use_knn)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save the model if it is improving on the validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
        if True:  # always executed; placeholder condition kept from the original
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
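
add_summary_value, used for all the TensorBoard logging above, is typically nothing more than a None-guard around the writer; a sketch assuming a tensorboardX-style SummaryWriter with an add_scalar method:

def add_summary_value(writer, key, value, iteration):
    # no-op when TensorBoard is unavailable (writer is None)
    if writer is not None:
        writer.add_scalar(key, value, iteration)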
Exemplo n.º 22
def train(loader, model, crit, optimizer, lr_scheduler, opt, rl_crit=None):
    model.train()
    if torch.cuda.device_count() > 1:
        print("{} devices detected, switch to parallel model.".format(
            torch.cuda.device_count()))
        model = nn.DataParallel(model)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)

    for epoch in range(opt["epochs"]):
        lr_scheduler.step()

        iteration = 0
        # If start self crit training
        if opt["self_crit_after"] != -1 and epoch >= opt["self_crit_after"]:
            sc_flag = True
            init_cider_scorer(opt["cached_tokens"])
        else:
            sc_flag = False

        for data in loader:
            torch.cuda.synchronize()
            fc_feats = data['fc_feats'].to(device)
            labels = data['labels'].to(device)
            masks = data['masks'].to(device)

            if not sc_flag:
                seq_probs, _ = model(fc_feats, labels, 'train')
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            else:
                seq_probs, seq_preds = model(fc_feats,
                                             mode='inference',
                                             opt=opt)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  seq_preds)
                print(reward.shape)
                loss = rl_crit(
                    seq_probs, seq_preds,
                    Variable(torch.from_numpy(reward).float().cuda()))

            optimizer.zero_grad()
            loss.backward()
            utils.clip_gradient(optimizer, opt["grad_clip"])
            optimizer.step()
            train_loss = loss.data[0]
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.6f" %
                      (iteration, epoch, np.mean(reward[:, 0])))

        if epoch != 0 and epoch % opt["save_checkpoint_every"] == 0:
            model_path = os.path.join(opt["checkpoint_path"],
                                      'model_%d.pth' % (epoch))
            model_info_path = os.path.join(opt["checkpoint_path"],
                                           'model_score.txt')
            torch.save(model.state_dict(), model_path)
            print("model saved to %s" % (model_path))
            with open(model_info_path, 'a') as f:
                f.write("model_%d, loss: %.6f\n" % (epoch, train_loss))
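
Both loops in this example feed (log-probs, sampled sequence, reward) into an rl_crit built from utils.RewardCriterion. A sketch of the usual policy-gradient criterion behind that call, assuming the common convention that word id 0 marks padding after the end of the sequence:

import torch
import torch.nn as nn

class RewardCriterion(nn.Module):
    def forward(self, logprobs, seq, reward):
        # logprobs: (batch, seq_len) log-probs of the sampled words
        # seq:      (batch, seq_len) sampled word ids, 0 after the sequence ends
        # reward:   (batch, seq_len) per-word advantage from self-critical training
        logprobs = logprobs.contiguous().view(-1)
        reward = reward.contiguous().view(-1)
        mask = (seq > 0).float()
        # shift the mask right by one step so the EOS position still gets a gradient
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1).view(-1)
        return -(logprobs * reward * mask).sum() / mask.sum()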
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from,
                               'infos_' + opt.id + '.pkl')) as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size1", "rnn_size2",
                "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl')) as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    # loader.iterators = infos.get('iterators', loader.iterators)
    # loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Ensure the model is in training mode
    model.train()
    # model.set_mode('train')

    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = optim.Adam(model.parameters(),
                           lr=opt.learning_rate,
                           weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        model.train()
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_cider_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from the train+val split
        data = loader.get_batch('train+val')
        # print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['num_bbox'],
            data['labels'], data['masks']
        ]
        tmp = [
            Variable(torch.from_numpy(_).float(), requires_grad=False).cuda()
            for _ in tmp
        ]
        fc_feats, att_feats, num_bbox, labels, masks = tmp
        labels = labels.long()

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(model(fc_feats, att_feats, num_bbox, labels),
                        labels[:, 1:], masks[:, 1:])
            # loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:])
        else:
            gen_result, sample_logprobs = model.sample(fc_feats, att_feats,
                                                       num_bbox,
                                                       {'sample_max': 0})
            reward = get_self_critical_reward(model, fc_feats, att_feats,
                                              num_bbox, data, gen_result)
            loss = rl_crit(
                sample_logprobs, gen_result,
                Variable(torch.from_numpy(reward).float().cuda(),
                         requires_grad=False))

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            if (iteration % 100 == 0):
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}, lr = {}" \
                      .format(iteration, epoch, train_loss, end - start, opt.current_lr))
        else:
            if (iteration % 100 == 0):
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}, lr = {}" \
                      .format(iteration, epoch, np.mean(reward[:, 0]), end - start, opt.current_lr))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss,
                                  iteration)
                add_summary_value(tf_summary_writer, 'learning_rate',
                                  opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob',
                                  model.ss_prob, iteration)
                if sc_flag:
                    add_summary_value(tf_summary_writer, 'avg_reward',
                                      np.mean(reward[:, 0]), iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # evaluate on the validation set and save the model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'val_ref_path': opt.val_ref_path,
                'raw_val_anno_path': opt.raw_val_anno_path
            }
            eval_kwargs.update(vars(opt))
            # predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            best_flag = False
            if True:  # always executed; placeholder condition kept from the original
                # if best_val_score is None or current_score > best_val_score:
                # 	best_val_score = current_score
                # 	best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
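
The add_summary_value helper called throughout these examples is never shown. A minimal sketch, assuming the TF1-style tf.summary.FileWriter that the surrounding code creates (the helper name matches the call sites; the body is an illustration, not the repos' actual utility):

import tensorflow as tf

def add_summary_value(writer, key, value, iteration):
    # wrap one scalar in a TF1 Summary protobuf and append it to the event file
    summary = tf.Summary(value=[tf.Summary.Value(tag=key, simple_value=value)])
    writer.add_summary(summary, iteration)
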
Exemplo n.º 24
0
            if torch.cuda.is_available():
                optimizer.zero_grad()
                dt = {key: _.cuda() if isinstance(_, torch.Tensor) else _ for key, _ in dt.items()}  # a rather slick way to move every tensor to the GPU

            dt = collections.defaultdict(lambda: None, dt)

            if True:
                train_mode = 'train_rl' if sc_flag else 'train'   # train_rl trains with reinforcement learning

                loss, sample_score, greedy_score = model(dt, mode=train_mode, loader=train_loader)
                loss_sum[0] = loss_sum[0] + loss.item()  # store loss
                loss_sum[1] = loss_sum[1] + sample_score.mean().item()    # store sample_score
                loss_sum[2] = loss_sum[2] + greedy_score.mean().item()    # store greedy_score

                loss.backward()
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
                if torch.cuda.is_available():
                    torch.cuda.synchronize()

            losses_log_every = max(1, len(train_loader) // 5)  # guard against tiny loaders

            ######## loss logging ########
            if iteration % losses_log_every == 0:
                end = time.time()
                losses = np.round(loss_sum / losses_log_every, 3)
                logger.info(
                    "ID {} iter {} (epoch {}, lr {}), avg_iter_loss = {}, time/iter = {:.3f}, bad_vid = {:.3f}"
                        .format(opt.id, iteration, epoch, opt.current_lr, losses,
                                (end - start) / losses_log_every, bad_video_num))
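
The fragment above moves the whole batch dict to the GPU and then wraps it in a defaultdict so missing optional keys come back as None. Factored out, the same pattern looks like this (batch_to_device is a hypothetical name used only for illustration):

import collections

import torch

def batch_to_device(dt, device):
    # move every tensor value to the target device; leave non-tensors untouched
    dt = {k: v.to(device) if isinstance(v, torch.Tensor) else v
          for k, v in dt.items()}
    # missing keys (e.g. optional features) now resolve to None instead of raising
    return collections.defaultdict(lambda: None, dt)
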
Exemplo n.º 25
0
def train(opt):
    set_seed(opt.seed)
    save_folder = build_floder(opt)
    logger = create_logger(save_folder, 'train.log')
    tf_writer = SummaryWriter(os.path.join(save_folder, 'tf_summary'))

    if not opt.start_from:
        backup_envir(save_folder)
        logger.info('backup environment completed!')

    saved_info = {'best': {}, 'last': {}, 'history': {}, 'eval_history': {}}

    # continue training
    if opt.start_from:
        opt.pretrain = False
        infos_path = os.path.join(save_folder, 'info.json')
        with open(infos_path) as f:
            logger.info('Load info from {}'.format(infos_path))
            saved_info = json.load(f)
            prev_opt = saved_info[opt.start_from_mode[:4]]['opt']

            exclude_opt = ['start_from', 'start_from_mode', 'pretrain']
            for opt_name in prev_opt.keys():
                if opt_name not in exclude_opt:
                    vars(opt).update({opt_name: prev_opt.get(opt_name)})
                if prev_opt.get(opt_name) != vars(opt).get(opt_name):
                    logger.info('Change opt {} : {} --> {}'.format(opt_name, prev_opt.get(opt_name),
                                                                   vars(opt).get(opt_name)))
        opt.feature_dim = opt.raw_feature_dim

    train_dataset = PropSeqDataset(opt.train_caption_file,
                                   opt.visual_feature_folder,
                                   opt.dict_file, True, opt.train_proposal_type,
                                   logger, opt)

    val_dataset = PropSeqDataset(opt.val_caption_file,
                                 opt.visual_feature_folder,
                                 opt.dict_file, False, 'gt',
                                 logger, opt)

    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size,
                              shuffle=True, num_workers=opt.nthreads, collate_fn=collate_fn)

    val_loader = DataLoader(val_dataset, batch_size=opt.batch_size,
                            shuffle=False, num_workers=opt.nthreads, collate_fn=collate_fn)

    epoch = saved_info[opt.start_from_mode[:4]].get('epoch', 0)
    iteration = saved_info[opt.start_from_mode[:4]].get('iter', 0)
    best_val_score = saved_info[opt.start_from_mode[:4]].get('best_val_score', -1e5)
    val_result_history = saved_info['history'].get('val_result_history', {})
    loss_history = saved_info['history'].get('loss_history', {})
    lr_history = saved_info['history'].get('lr_history', {})
    opt.current_lr = vars(opt).get('current_lr', opt.lr)
    opt.vocab_size = train_loader.dataset.vocab_size

    # Build model
    model = EncoderDecoder(opt)
    model.train()

    # Recover the parameters
    if opt.start_from and (not opt.pretrain):
        if opt.start_from_mode == 'best':
            model_pth = torch.load(os.path.join(save_folder, 'model-best-CE.pth'))
        elif opt.start_from_mode == 'best-RL':
            model_pth = torch.load(os.path.join(save_folder, 'model-best-RL.pth'))
        elif opt.start_from_mode == 'last':
            model_pth = torch.load(os.path.join(save_folder, 'model-last.pth'))
        logger.info('Loading pth from {}, iteration:{}'.format(save_folder, iteration))
        model.load_state_dict(model_pth['model'])

    # Load the pre-trained model
    if opt.pretrain and (not opt.start_from):
        logger.info('Load pre-trained parameters from {}'.format(opt.pretrain_path))
        if torch.cuda.is_available():
            model_pth = torch.load(opt.pretrain_path)
        else:
            model_pth = torch.load(opt.pretrain_path, map_location=torch.device('cpu'))
        model.load_state_dict(model_pth['model'])

    if torch.cuda.is_available():
        model.cuda()

    if opt.optimizer_type == 'adam':
        optimizer = optim.Adam(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)
    else:
        optimizer = optim.SGD(model.parameters(), lr=opt.lr, weight_decay=opt.weight_decay)

    if opt.start_from:
        optimizer.load_state_dict(model_pth['optimizer'])

    # print the args for debugging
    print_opt(opt, model, logger)
    print_alert_message('Start training!', logger)

    loss_sum = np.zeros(3)
    bad_video_num = 0
    start = time.time()

    # Epoch-level iteration
    while True:
        if True:
            # lr decay
            if epoch > opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate ** frac
                opt.current_lr = opt.lr * decay_factor
            else:
                opt.current_lr = opt.lr
            utils.set_lr(optimizer, opt.current_lr)

            # scheduled sampling rate update
            if epoch > opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.basic_ss_prob + opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.caption_decoder.ss_prob = opt.ss_prob

            # self critical learning flag
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer()
                model.caption_decoder.ss_prob = 0
            else:
                sc_flag = False

        # Batch-level iteration
        for dt in tqdm(train_loader):
            if torch.cuda.is_available():
                torch.cuda.synchronize()
            if opt.debug:
                # each epoch contains less mini-batches for debugging
                if (iteration + 1) % 5 == 0:
                    iteration += 1
                    break
            elif epoch == 0:
                break
            iteration += 1

            if torch.cuda.is_available():
                optimizer.zero_grad()
                dt = {key: _.cuda() if isinstance(_, torch.Tensor) else _ for key, _ in dt.items()}

            dt = collections.defaultdict(lambda: None, dt)

            if True:
                train_mode = 'train_rl' if sc_flag else 'train'

                loss, sample_score, greedy_score = model(dt, mode=train_mode, loader=train_loader)
                loss_sum[0] = loss_sum[0] + loss.item()
                loss_sum[1] = loss_sum[1] + sample_score.mean().item()
                loss_sum[2] = loss_sum[2] + greedy_score.mean().item()

                loss.backward()
                utils.clip_gradient(optimizer, opt.grad_clip)
                optimizer.step()
                if torch.cuda.is_available():
                    torch.cuda.synchronize()

            losses_log_every = max(1, len(train_loader) // 5)

            if iteration % losses_log_every == 0:
                end = time.time()
                losses = np.round(loss_sum / losses_log_every, 3)
                logger.info(
                    "ID {} iter {} (epoch {}, lr {}), avg_iter_loss = {}, time/iter = {:.3f}, bad_vid = {:.3f}"
                        .format(opt.id, iteration, epoch, opt.current_lr, losses,
                                (end - start) / losses_log_every, bad_video_num))

                tf_writer.add_scalar('lr', opt.current_lr, iteration)
                tf_writer.add_scalar('ss_prob', model.caption_decoder.ss_prob, iteration)
                tf_writer.add_scalar('train_caption_loss', losses[0].item(), iteration)
                tf_writer.add_scalar('train_rl_sample_score', losses[1].item(), iteration)
                tf_writer.add_scalar('train_rl_greedy_score', losses[2].item(), iteration)

                loss_history[iteration] = losses.tolist()
                lr_history[iteration] = opt.current_lr
                loss_sum = 0 * loss_sum
                start = time.time()
                bad_video_num = 0
                torch.cuda.empty_cache()

        # evaluation
        if (epoch % opt.save_checkpoint_every == 0) and (epoch >= opt.min_epoch_when_save) and (epoch != 0):
            model.eval()

            dvc_json_path = os.path.join(save_folder, 'prediction',
                                         'num{}_epoch{}_score{}_nms{}_top{}.json'.format(
                                             len(val_dataset), epoch, opt.eval_score_threshold,
                                             opt.eval_nms_threshold, opt.eval_top_n))
            eval_score = evaluate(model, val_loader, dvc_json_path, './data/captiondata/val_1_for_tap.json',
                                  opt.eval_score_threshold, opt.eval_nms_threshold,
                                  opt.eval_top_n, logger=logger)
            current_score = np.array(eval_score['METEOR']).mean()

            # add to tf summary
            for key in eval_score.keys():
                tf_writer.add_scalar(key, np.array(eval_score[key]).mean(), iteration)
            _ = [item.append(np.array(item).mean()) for item in eval_score.values() if isinstance(item, list)]
            print_info = '\n'.join([key + ":" + str(eval_score[key]) for key in eval_score.keys()])
            logger.info('\nValidation results of iter {}:\n'.format(iteration) + print_info)

            # for name, param in model.named_parameters():
            #     tf_writer.add_histogram(name, param.clone().cpu().data.numpy(), iteration, bins=10)
            #     if param.grad is not None:
            #         tf_writer.add_histogram(name + '_grad', param.grad.clone().cpu().data.numpy(), iteration,
            #                                 bins=10)

            val_result_history[epoch] = {'eval_score': eval_score}

            # Save model
            saved_pth = {'epoch': epoch,
                         'model': model.state_dict(),
                         'optimizer': optimizer.state_dict(), }

            if opt.save_all_checkpoint:
                checkpoint_path = os.path.join(save_folder, 'model_iter_{}.pth'.format(iteration))
            else:
                checkpoint_path = os.path.join(save_folder, 'model-last.pth')

            torch.save(saved_pth, checkpoint_path)
            logger.info('Save model at iter {} to {}.'.format(iteration, checkpoint_path))

            # save the model parameters of the best epoch
            if current_score > best_val_score:
                best_val_score = current_score
                best_epoch = epoch
                saved_info['best'] = {'opt': vars(opt),
                                      'iter': iteration,
                                      'epoch': best_epoch,
                                      'best_val_score': best_val_score,
                                      'dvc_json_path': dvc_json_path,
                                      'METEOR': eval_score['METEOR'],
                                      'avg_proposal_num': eval_score['avg_proposal_number'],
                                      'Precision': eval_score['Precision'],
                                      'Recall': eval_score['Recall']
                                      }

                suffix = "RL" if sc_flag else "CE"
                torch.save(saved_pth, os.path.join(save_folder, 'model-best-{}.pth'.format(suffix)))
                logger.info('Save Best-model at iter {} to checkpoint file.'.format(iteration))

            saved_info['last'] = {'opt': vars(opt),
                                  'iter': iteration,
                                  'epoch': epoch,
                                  'best_val_score': best_val_score,
                                  }
            saved_info['history'] = {'val_result_history': val_result_history,
                                     'loss_history': loss_history,
                                     'lr_history': lr_history,
                                     }
            with open(os.path.join(save_folder, 'info.json'), 'w') as f:
                json.dump(saved_info, f)
            logger.info('Save info to info.json')

            model.train()

        epoch += 1
        torch.cuda.empty_cache()
        # Stop criterion
        if epoch >= opt.epoch:
            tf_writer.close()
            break

    return saved_info
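
The epoch loop above recomputes two schedules in place: an exponential learning-rate decay and a linearly increasing scheduled-sampling probability. Pulled out into pure functions (a sketch reusing the option names from this example), the schedules are:

def decayed_lr(epoch, opt):
    # multiply the base lr by decay_rate once per decay_every epochs after the start epoch
    if epoch > opt.learning_rate_decay_start >= 0:
        frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
        return opt.lr * (opt.learning_rate_decay_rate ** frac)
    return opt.lr

def scheduled_ss_prob(epoch, opt):
    # raise the sampling probability linearly, capped at the configured maximum
    if epoch > opt.scheduled_sampling_start >= 0:
        frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
        return min(opt.basic_ss_prob + opt.scheduled_sampling_increase_prob * frac,
                   opt.scheduled_sampling_max_prob)
    return opt.basic_ss_prob
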
Exemplo n.º 26
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    # log information
    folder_id = 'log_result'
    file_id = 'show_tell'
    log_file_name = os.path.join(folder_id, file_id + '.txt')
    log_file = open(log_file_name, 'w')

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'), 'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()

    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion() # define the loss criterion

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False
                
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp
        
        optimizer.zero_grad()

        loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:]) # compute using the defined criterion
        
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        
        # store the relevant values
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        print((time.time(), time.perf_counter()))  # time.clock() was removed in Python 3.8

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            if epoch == 25:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model_25_epochs_512_batch.pth')
                torch.save(model.state_dict(), checkpoint_path)
            elif epoch == 12:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model_12_epochs_64_batch.pth')
                torch.save(model.state_dict(), checkpoint_path)
            elif epoch == 75:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model_75_epochs_512_batch.pth')
                torch.save(model.state_dict(), checkpoint_path)
            elif epoch == 100:
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model_100_epochs_512_batch.pth')
                torch.save(model.state_dict(), checkpoint_path)
            
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob
            log_line = 'Epoch [%d], Step [%d], loss: %f, time %f' % (
                epoch, iteration, train_loss, time.perf_counter())
            log_file.write(log_line + '\n')

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                for k,v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
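
Every example relies on utils.set_lr and utils.clip_gradient. Implementations consistent with how they are called here (a sketch, not necessarily the repos' actual utils; note this clips by value, not by norm):

def set_lr(optimizer, lr):
    # overwrite the learning rate of every parameter group
    for group in optimizer.param_groups:
        group['lr'] = lr

def clip_gradient(optimizer, grad_clip):
    # clamp each parameter's gradient elementwise into [-grad_clip, grad_clip]
    for group in optimizer.param_groups:
        for param in group['params']:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
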
Exemplo n.º 27
0
def train(opt):
    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tf_summary_writer = tf and tf.summary.FileWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_'+opt.id+'.pkl'),'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same=["caption_model", "rnn_type", "rnn_size", "num_layers"]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(opt)[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl')):
            with open(os.path.join(opt.start_from, 'histories_'+opt.id+'.pkl'),'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt)
    model.cuda()
    if opt.multi_gpu:
        model=nn.DataParallel(model)
    update_lr_flag = True
    # Assure in training mode
    model.train()

    crit = utils.LanguageModelCriterion()

    optimizer = optim.Adam(model.parameters(), lr=opt.learning_rate, weight_decay=opt.weight_decay)

    # Load the optimizer
    if vars(opt).get('start_from', None) is not None:
        optimizer.load_state_dict(torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate  ** frac
                opt.current_lr = opt.learning_rate * decay_factor
                utils.set_lr(optimizer, opt.current_lr) # set the decayed rate
            else:
                opt.current_lr = opt.learning_rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob  * frac, opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
            update_lr_flag = False
                
        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [data['fc_feats'], data['att_feats'], data['labels'], data['masks']]
        tmp = [Variable(torch.from_numpy(_), requires_grad=False).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks = tmp
        
        optimizer.zero_grad()
        loss = crit(model(fc_feats, att_feats, labels), labels[:,1:], masks[:,1:])
        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.data[0]
        torch.cuda.synchronize()
        end = time.time()
        print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
            .format(iteration, epoch, train_loss, end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            if tf is not None:
                add_summary_value(tf_summary_writer, 'train_loss', train_loss, iteration)
                add_summary_value(tf_summary_writer, 'learning_rate', opt.current_lr, iteration)
                add_summary_value(tf_summary_writer, 'scheduled_sampling_prob', model.ss_prob, iteration)
                tf_summary_writer.flush()

            loss_history[iteration] = train_loss
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        if (iteration % opt.save_checkpoint_every == 0):
            # eval model
            eval_kwargs = {'split': 'val',
                            'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(model, crit, loader, eval_kwargs)

            # Write validation result into summary
            if tf is not None:
                add_summary_value(tf_summary_writer, 'validation loss', val_loss, iteration)
                for k,v in lang_stats.items():
                    add_summary_value(tf_summary_writer, k, v, iteration)
                tf_summary_writer.flush()
            val_result_history[iteration] = {'loss': val_loss, 'lang_stats': lang_stats, 'predictions': predictions}

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = - val_loss

            best_flag = False
            if True: # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path, 'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)
                with open(os.path.join(opt.checkpoint_path, 'histories_'+opt.id+'.pkl'), 'wb') as f:
                    cPickle.dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path, 'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(os.path.join(opt.checkpoint_path, 'infos_'+opt.id+'-best.pkl'), 'wb') as f:
                        cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
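
utils.LanguageModelCriterion, the captioning loss used in most of these examples, is a masked negative log-likelihood over padded sequences. A sketch, assuming the model returns per-word log-probabilities (log_softmax output):

import torch.nn as nn

class LanguageModelCriterion(nn.Module):
    def forward(self, logprobs, target, mask):
        # logprobs: (batch, seq_len, vocab); target and mask: (batch, seq_len)
        target = target[:, :logprobs.size(1)]
        mask = mask[:, :logprobs.size(1)].float()
        # pick the log-prob of each ground-truth word, zero out the padding
        nll = -logprobs.gather(2, target.unsqueeze(2)).squeeze(2) * mask
        return nll.sum() / mask.sum()
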
Exemplo n.º 28
0
def train(opt):
    # Deal with feature things before anything
    opt.use_fc, opt.use_att = utils.if_use_feat(opt.caption_model)
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    tb_summary_writer = tb and tb.SummaryWriter(opt.checkpoint_path)

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl'),
                  'rb') as f:
            infos = utils.pickle_load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(
                os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
            with open(
                    os.path.join(opt.start_from,
                                 'histories_' + opt.id + '.pkl'), 'rb') as f:
                histories = utils.pickle_load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    dp_model = torch.nn.DataParallel(model)

    epoch_done = True
    # Assure in training mode
    dp_model.train()

    if opt.label_smoothing > 0:
        crit = utils.LabelSmoothing(smoothing=opt.label_smoothing)
    else:
        crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
        optimizer._step = iteration
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
    # Load the optimizer
    if vars(opt).get('start_from', None) is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    total_loss = 0
    times = 0
    while True:
        if epoch_done:
            if not opt.noamopt and not opt.reduce_on_plateau:
                # Assign the learning rate
                if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                    frac = (epoch - opt.learning_rate_decay_start
                            ) // opt.learning_rate_decay_every
                    decay_factor = opt.learning_rate_decay_rate**frac
                    opt.current_lr = opt.learning_rate * decay_factor
                else:
                    opt.current_lr = opt.learning_rate
                utils.set_lr(optimizer, opt.current_lr)  # set the decayed rate
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            epoch_done = False

        start = time.time()
        # Load data from train split (0)
        data = loader.get_batch('train')
        print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        tmp = [
            data['fc_feats'], data['att_feats'], data['labels'], data['masks'],
            data['att_masks']
        ]
        tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
        fc_feats, att_feats, labels, masks, att_masks = tmp

        times += 1

        optimizer.zero_grad()
        if not sc_flag:
            loss = crit(dp_model(fc_feats, att_feats, labels, att_masks),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   att_feats,
                                                   att_masks,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, att_feats,
                                              att_masks, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        loss.backward()
        utils.clip_gradient(optimizer, opt.grad_clip)
        optimizer.step()
        train_loss = loss.item()
        total_loss = total_loss + train_loss
        torch.cuda.synchronize()
        end = time.time()
        if not sc_flag:
            print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, train_loss, end - start))
        else:
            print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                .format(iteration, epoch, np.mean(reward[:,0]), end - start))

        # Update the iteration and epoch
        iteration += 1
        if data['bounds']['wrapped']:
            # epoch += 1
            epoch_done = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            if opt.noamopt:
                opt.current_lr = optimizer.rate()
            elif opt.reduce_on_plateau:
                opt.current_lr = optimizer.current_lr
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        # if (iteration % opt.save_checkpoint_every == 0):
        if data['bounds']['wrapped']:
            epoch += 1
            # eval model
            eval_kwargs = {
                'split': 'val',
                'dataset': opt.input_json,
                'verbose': False
            }
            eval_kwargs.update(vars(opt))
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            if opt.reduce_on_plateau:
                if 'CIDEr' in lang_stats:
                    optimizer.scheduler_step(-lang_stats['CIDEr'])
                else:
                    optimizer.scheduler_step(val_loss)
            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats
                with open('train_log_%s.txt' % opt.id, 'a') as f:
                    f.write(
                        'Epoch {}: | Date: {} | TrainLoss: {} | ValLoss: {} | Score: {}\n'
                        .format(epoch, str(datetime.now()),
                                str(total_loss / times), str(val_loss),
                                str(current_score)))
                print('-------------------wrote to log file')
                total_loss = 0
                times = 0
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            if True:  # if true
                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True
                if not os.path.isdir(opt.checkpoint_path):
                    os.mkdir(opt.checkpoint_path)
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model.pth')
                torch.save(model.state_dict(), checkpoint_path)
                # print(str(infos['best_val_score']))
                print("model saved to {}".format(checkpoint_path))
                if opt.save_history_ckpt:
                    checkpoint_path = os.path.join(
                        opt.checkpoint_path, 'model-%d.pth' % (iteration))
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                optimizer_path = os.path.join(opt.checkpoint_path,
                                              'optimizer.pth')
                torch.save(optimizer.state_dict(), optimizer_path)

                # Dump miscellaneous information
                infos['iter'] = iteration
                infos['epoch'] = epoch
                infos['iterators'] = loader.iterators
                infos['split_ix'] = loader.split_ix
                infos['best_val_score'] = best_val_score
                infos['opt'] = opt
                infos['vocab'] = loader.get_vocab()

                histories['val_result_history'] = val_result_history
                histories['loss_history'] = loss_history
                histories['lr_history'] = lr_history
                histories['ss_prob_history'] = ss_prob_history
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'infos_' + opt.id + '.pkl'), 'wb') as f:
                    utils.pickle_dump(infos, f)
                if opt.save_history_ckpt:
                    with open(
                            os.path.join(
                                opt.checkpoint_path,
                                'infos_' + opt.id + '-%d.pkl' % (iteration)),
                            'wb') as f:
                        utils.pickle_dump(infos, f)
                with open(
                        os.path.join(opt.checkpoint_path,
                                     'histories_' + opt.id + '.pkl'),
                        'wb') as f:
                    utils.pickle_dump(histories, f)

                if best_flag:
                    checkpoint_path = os.path.join(opt.checkpoint_path,
                                                   'model-best.pth')
                    torch.save(model.state_dict(), checkpoint_path)
                    print("model saved to {}".format(checkpoint_path))
                    with open(
                            os.path.join(opt.checkpoint_path,
                                         'infos_' + opt.id + '-best.pkl'),
                            'wb') as f:
                        utils.pickle_dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
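
The self-critical branch above pairs sampled captions with a greedy baseline and feeds the resulting advantage into utils.RewardCriterion. A plausible sketch, assuming sample_logprobs already holds the log-probability of each sampled word with shape (batch, seq_len):

import torch
import torch.nn as nn

class RewardCriterion(nn.Module):
    def forward(self, logprobs, seq, reward):
        # policy gradient: weight each word's log-prob by its reward
        logprobs = logprobs.reshape(-1)
        reward = reward.reshape(-1)
        # score words up to and including the first end token (index 0)
        mask = (seq > 0).float()
        mask = torch.cat([mask.new_ones(mask.size(0), 1), mask[:, :-1]], 1).reshape(-1)
        return -(logprobs * reward * mask).sum() / mask.sum()
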
Exemplo n.º 29
0
def train(train_loader, model, criterion, optimizer, epoch, opt):
    """
    train for one epoch on the training set
    """
    batch_time = utils.AverageMeter()
    losses = utils.AverageMeter()
    top1 = utils.AverageMeter()

    # training mode
    model.train()

    end = time.time()
    for i, (input_points, labels) in enumerate(train_loader):
        # bz x 2048 x 3
        input_points = Variable(input_points)
        input_points = input_points.transpose(2, 1)
        labels = Variable(labels[:, 0])

        # print(points.size())
        # print(labels.size())
        # shift data to GPU
        if opt.cuda:
            input_points = input_points.cuda()
            labels = labels.long().cuda()  # must be long cuda tensor

        # forward, backward optimize
        output, _ = model(input_points)
        # debug_here()
        loss = criterion(output, labels)
        ##############################
        # measure accuracy
        ##############################
        prec1 = utils.accuracy(output.data, labels.data, topk=(1, ))[0]
        losses.update(loss.data[0], input_points.size(0))
        top1.update(prec1[0], input_points.size(0))

        ##############################
        # compute gradient and do sgd
        ##############################
        optimizer.zero_grad()
        loss.backward()
        ##############################
        # gradient clip stuff
        ##############################
        utils.clip_gradient(optimizer, opt.gradient_clip)

        optimizer.step()

        # measure elapsed time
        batch_time.update(time.time() - end)
        end = time.time()
        if i % opt.print_freq == 0:
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'
                  'Prec@1 {top1.val:.3f} ({top1.avg:.3f})'.format(
                      epoch,
                      i,
                      len(train_loader),
                      batch_time=batch_time,
                      loss=losses,
                      top1=top1))
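
The point-cloud example depends on utils.AverageMeter and utils.accuracy. Both follow the classic pattern from the torchvision ImageNet example; a sketch consistent with the call sites above:

class AverageMeter(object):
    # tracks the most recent value and a running average
    def __init__(self):
        self.val = 0.0
        self.sum = 0.0
        self.count = 0
        self.avg = 0.0

    def update(self, val, n=1):
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

def accuracy(output, target, topk=(1,)):
    # precision@k: fraction of samples whose label is in the top-k predictions
    maxk = max(topk)
    batch_size = target.size(0)
    _, pred = output.topk(maxk, 1, True, True)
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    res = []
    for k in topk:
        correct_k = correct[:k].reshape(-1).float().sum(0)
        res.append(correct_k.mul_(100.0 / batch_size))
    return res
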
Exemplo n.º 30
0
Arquivo: train.py Projeto: cxqj/ECHR
def train(opt):
    exclude_opt = [
        'training_mode', 'tap_epochs', 'cg_epochs', 'tapcg_epochs', 'lr',
        'learning_rate_decay_start', 'learning_rate_decay_every',
        'learning_rate_decay_rate', 'self_critical_after',
        'save_checkpoint_every', 'id', "pretrain", "pretrain_path", "debug",
        "save_all_checkpoint", "min_epoch_when_save"
    ]

    save_folder, logger, tf_writer = build_floder_and_create_logger(opt)
    saved_info = {'best': {}, 'last': {}, 'history': {}}
    is_continue = opt.start_from is not None

    if is_continue:
        infos_path = os.path.join(save_folder, 'info.pkl')
        with open(infos_path, 'rb') as f:
            logger.info('load info from {}'.format(infos_path))
            saved_info = cPickle.load(f)
            pre_opt = saved_info[opt.start_from_mode]['opt']
            if vars(opt).get("no_exclude_opt", False):
                exclude_opt = []
            for opt_name in vars(pre_opt).keys():
                if opt_name not in exclude_opt:
                    vars(opt).update({opt_name: vars(pre_opt).get(opt_name)})
                if vars(pre_opt).get(opt_name) != vars(opt).get(opt_name):
                    print('change opt: {} from {} to {}'.format(
                        opt_name,
                        vars(pre_opt).get(opt_name),
                        vars(opt).get(opt_name)))

    opt.use_att = utils.if_use_att(opt.caption_model)
    loader = DataLoader(opt)
    opt.CG_vocab_size = loader.vocab_size
    opt.CG_seq_length = loader.seq_length

    # init training option
    epoch = saved_info[opt.start_from_mode].get('epoch', 0)
    iteration = saved_info[opt.start_from_mode].get('iter', 0)
    best_val_score = saved_info[opt.start_from_mode].get('best_val_score', 0)
    val_result_history = saved_info['history'].get('val_result_history', {})
    loss_history = saved_info['history'].get('loss_history', {})
    lr_history = saved_info['history'].get('lr_history', {})
    loader.iterators = saved_info[opt.start_from_mode].get(
        'iterators', loader.iterators)
    loader.split_ix = saved_info[opt.start_from_mode].get(
        'split_ix', loader.split_ix)
    opt.current_lr = vars(opt).get('current_lr', opt.lr)
    opt.m_batch = vars(opt).get('m_batch', 1)

    # create a tap_model,fusion_model,cg_model

    tap_model = models.setup_tap(opt)
    lm_model = CaptionGenerator(opt)
    cg_model = lm_model

    if is_continue:
        if opt.start_from_mode == 'best':
            model_pth = torch.load(os.path.join(save_folder, 'model-best.pth'))
        elif opt.start_from_mode == 'last':
            model_pth = torch.load(
                os.path.join(save_folder,
                             'model_iter_{}.pth'.format(iteration)))
        assert model_pth['iteration'] == iteration
        logger.info('Loading pth from {}, iteration:{}'.format(
            save_folder, iteration))
        tap_model.load_state_dict(model_pth['tap_model'])
        cg_model.load_state_dict(model_pth['cg_model'])

    elif opt.pretrain:
        print('pretrain {} from {}'.format(opt.pretrain, opt.pretrain_path))
        model_pth = torch.load(opt.pretrain_path)
        if opt.pretrain == 'tap':
            tap_model.load_state_dict(model_pth['tap_model'])
        elif opt.pretrain == 'cg':
            cg_model.load_state_dict(model_pth['cg_model'])
        elif opt.pretrain == 'tap_cg':
            tap_model.load_state_dict(model_pth['tap_model'])
            cg_model.load_state_dict(model_pth['cg_model'])
        else:
            raise ValueError('unknown opt.pretrain: {}'.format(opt.pretrain))

    tap_model.cuda()
    tap_model.train()  # Assure in training mode

    tap_crit = utils.TAPModelCriterion()

    tap_optimizer = optim.Adam(tap_model.parameters(),
                               lr=opt.lr,
                               weight_decay=opt.weight_decay)

    cg_model.cuda()
    cg_model.train()
    cg_optimizer = optim.Adam(cg_model.parameters(),
                              lr=opt.lr,
                              weight_decay=opt.weight_decay)
    cg_crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    allmodels = [tap_model, cg_model]
    optimizers = [tap_optimizer, cg_optimizer]

    if is_continue:
        tap_optimizer.load_state_dict(model_pth['tap_optimizer'])
        cg_optimizer.load_state_dict(model_pth['cg_optimizer'])

    update_lr_flag = True
    loss_sum = np.zeros(5)
    bad_video_num = 0
    best_epoch = epoch
    start = time.time()

    print_opt(opt, allmodels, logger)
    logger.info('\nStart training')

    # set a var to indicate what to train in current iteration: "tap", "cg" or "tap_cg"
    flag_training_whats = get_training_list(opt, logger)

    # Iteration begin
    while True:
        if update_lr_flag:
            if (epoch > opt.learning_rate_decay_start
                    and opt.learning_rate_decay_start >= 0):
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.lr * decay_factor
            else:
                opt.current_lr = opt.lr
            for optimizer in optimizers:
                utils.set_lr(optimizer, opt.current_lr)
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(None)
            else:
                sc_flag = False
            update_lr_flag = False

        flag_training_what = flag_training_whats[epoch]
        if opt.training_mode == "alter2":
            flag_training_what = flag_training_whats[iteration]

        # get data
        data = loader.get_batch('train')

        if opt.debug:
            print('vid:', data['vid'])
            print('info:', data['infos'])

        torch.cuda.synchronize()

        if (data["proposal_num"] <= 0) or (data['fc_feats'].shape[0] <= 1):
            bad_video_num += 1  # print('vid:{} has no good proposal.'.format(data['vid']))
            continue

        ind_select_list = data['ind_select_list']
        soi_select_list = data['soi_select_list']
        cg_select_list = data['cg_select_list']
        sampled_ids = data['sampled_ids']

        if flag_training_what == 'cg' or flag_training_what == 'gt_tap_cg':
            ind_select_list = data['gts_ind_select_list']
            soi_select_list = data['gts_soi_select_list']
            cg_select_list = data['gts_cg_select_list']

        tmp = [
            data['fc_feats'], data['att_feats'], data['lda_feats'],
            data['tap_labels'], data['tap_masks_for_loss'],
            data['cg_labels'][cg_select_list],
            data['cg_masks'][cg_select_list], data['w1']
        ]

        tmp = [
            Variable(torch.from_numpy(_), requires_grad=False).cuda()
            for _ in tmp
        ]

        c3d_feats, att_feats, lda_feats, tap_labels, tap_masks_for_loss, cg_labels, cg_masks, w1 = tmp

        if (iteration - 1) % opt.m_batch == 0:
            tap_optimizer.zero_grad()
            cg_optimizer.zero_grad()

        tap_feats, pred_proposals = tap_model(c3d_feats)
        tap_loss = tap_crit(pred_proposals, tap_masks_for_loss, tap_labels, w1)

        loss_sum[0] = loss_sum[0] + tap_loss.item()

        # Backward Propagation
        if flag_training_what == 'tap':
            tap_loss.backward()
            utils.clip_gradient(tap_optimizer, opt.grad_clip)
            if iteration % opt.m_batch == 0:
                tap_optimizer.step()
        else:
            if not sc_flag:
                pred_captions = cg_model(tap_feats,
                                         c3d_feats,
                                         lda_feats,
                                         cg_labels,
                                         ind_select_list,
                                         soi_select_list,
                                         mode='train')
                cg_loss = cg_crit(pred_captions, cg_labels[:, 1:],
                                  cg_masks[:, 1:])

            else:
                gen_result, sample_logprobs, greedy_res = cg_model(
                    tap_feats,
                    c3d_feats,
                    lda_feats,
                    cg_labels,
                    ind_select_list,
                    soi_select_list,
                    mode='train_rl')
                sentence_info = data['sentences_batch'] if (
                    flag_training_what != 'cg'
                    and flag_training_what != 'gt_tap_cg'
                ) else data['gts_sentences_batch']

                reward = get_self_critical_reward2(
                    greedy_res, (data['vid'], sentence_info),
                    gen_result,
                    vocab=loader.get_vocab(),
                    opt=opt)
                cg_loss = rl_crit(sample_logprobs, gen_result,
                                  torch.from_numpy(reward).float().cuda())

            loss_sum[1] = loss_sum[1] + cg_loss.item()

            if flag_training_what == 'cg' or flag_training_what == 'gt_tap_cg' or flag_training_what == 'LP_cg':
                cg_loss.backward()

                utils.clip_gradient(cg_optimizer, opt.grad_clip)
                if iteration % opt.m_batch == 0:
                    cg_optimizer.step()
                if flag_training_what == 'gt_tap_cg':
                    utils.clip_gradient(tap_optimizer, opt.grad_clip)
                    if iteration % opt.m_batch == 0:
                        tap_optimizer.step()
            elif flag_training_what == 'tap_cg':
                total_loss = opt.lambda1 * tap_loss + opt.lambda2 * cg_loss
                total_loss.backward()
                utils.clip_gradient(tap_optimizer, opt.grad_clip)
                utils.clip_gradient(cg_optimizer, opt.grad_clip)
                if iteration % opt.m_batch == 0:
                    tap_optimizer.step()
                    cg_optimizer.step()

                loss_sum[2] = loss_sum[2] + total_loss.item()

        torch.cuda.synchronize()

        # Updating epoch num
        iteration += 1
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Print losses, Add to summary
        if iteration % opt.losses_log_every == 0:
            end = time.time()
            losses = np.round(loss_sum / opt.losses_log_every, 3)
            logger.info(
                "iter {} (epoch {}, lr {}), avg_iter_loss({}) = {}, time/batch = {:.3f}, bad_vid = {:.3f}" \
                    .format(iteration, epoch, opt.current_lr, flag_training_what, losses,
                            (end - start) / opt.losses_log_every,
                            bad_video_num))

            tf_writer.add_scalar('lr', opt.current_lr, iteration)
            tf_writer.add_scalar('train_tap_loss', losses[0], iteration)
            tf_writer.add_scalar('train_tap_prop_loss', losses[3], iteration)
            tf_writer.add_scalar('train_tap_bound_loss', losses[4], iteration)
            tf_writer.add_scalar('train_cg_loss', losses[1], iteration)
            tf_writer.add_scalar('train_total_loss', losses[2], iteration)
            if sc_flag and flag_training_what != 'tap':
                tf_writer.add_scalar('avg_reward', np.mean(reward[:, 0]),
                                     iteration)
            loss_history[iteration] = losses
            lr_history[iteration] = opt.current_lr
            loss_sum = np.zeros(5)
            start = time.time()
            bad_video_num = 0

        # Evaluation, and save model
        if (iteration % opt.save_checkpoint_every
                == 0) and (epoch >= opt.min_epoch_when_save):
            eval_kwargs = {
                'split': 'val',
                'val_all_metrics': 0,
                'topN': 100,
            }

            eval_kwargs.update(vars(opt))

            # eval_kwargs['num_vids_eval'] = int(491)
            eval_kwargs['topN'] = 100

            eval_kwargs2 = {
                'split': 'val',
                'val_all_metrics': 1,
                'num_vids_eval': 4917,
            }
            eval_kwargs2.update(vars(opt))

            if not opt.num_vids_eval:
                eval_kwargs['num_vids_eval'] = 4917
                eval_kwargs2['num_vids_eval'] = 4917

            crits = [tap_crit, cg_crit]
            pred_json_path_T = os.path.join(save_folder, 'pred_sent',
                                            'pred_num{}_iter{}.json')

            # if 'alter' in opt.training_mode:
            if flag_training_what == 'tap':
                eval_kwargs['topN'] = 1000
                predictions, eval_score, val_loss = eval_utils.eval_split(
                    allmodels,
                    crits,
                    loader,
                    pred_json_path_T.format(eval_kwargs['num_vids_eval'],
                                            iteration),
                    eval_kwargs,
                    flag_eval_what='tap')
            else:
                if not vars(opt).get('fast_eval_cg', False):
                    predictions, eval_score, val_loss = eval_utils.eval_split(
                        allmodels,
                        crits,
                        loader,
                        pred_json_path_T.format(eval_kwargs['num_vids_eval'],
                                                iteration),
                        eval_kwargs,
                        flag_eval_what='tap_cg')

                predictions2, eval_score2, val_loss2 = eval_utils.eval_split(
                    allmodels,
                    crits,
                    loader,
                    pred_json_path_T.format(eval_kwargs2['num_vids_eval'],
                                            iteration),
                    eval_kwargs2,
                    flag_eval_what='cg')

                if vars(opt).get('fast_eval_cg', False) or vars(opt).get(
                        'fast_eval_cg_top10', False):
                    eval_score = eval_score2
                    val_loss = val_loss2
                    predictions = predictions2

            # else:
            #    predictions, eval_score, val_loss = eval_utils.eval_split(allmodels, crits, loader, pred_json_path,
            #                                                              eval_kwargs,
            #                                                              flag_eval_what=flag_training_what)

            f_f1 = lambda x, y: 2 * x * y / (x + y)
            f1 = f_f1(eval_score['Recall'], eval_score['Precision']).mean()
            if flag_training_what != 'tap':  # when training the captioner, use mean METEOR as the final score
                current_score = np.array(eval_score['METEOR']).mean() * 100
            else:  # when training only the proposal module (tap), use the F1 of precision and recall
                current_score = f1
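            # Worked example for the F1 above: precision 0.6 and recall 0.4
            # give f1 = 2 * 0.6 * 0.4 / (0.6 + 0.4) = 0.48.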

            for model in allmodels:
                for name, param in model.named_parameters():
                    tf_writer.add_histogram(name,
                                            param.clone().cpu().data.numpy(),
                                            iteration,
                                            bins=10)
                    if param.grad is not None:
                        tf_writer.add_histogram(
                            name + '_grad',
                            param.grad.clone().cpu().data.numpy(),
                            iteration,
                            bins=10)

            tf_writer.add_scalar('val_tap_loss', val_loss[0], iteration)
            tf_writer.add_scalar('val_cg_loss', val_loss[1], iteration)
            tf_writer.add_scalar('val_tap_prop_loss', val_loss[3], iteration)
            tf_writer.add_scalar('val_tap_bound_loss', val_loss[4], iteration)
            tf_writer.add_scalar('val_total_loss', val_loss[2], iteration)
            tf_writer.add_scalar('val_score', current_score, iteration)
            if flag_training_what != 'tap':
                tf_writer.add_scalar('val_score_gt_METEOR',
                                     np.array(eval_score2['METEOR']).mean(),
                                     iteration)
                tf_writer.add_scalar('val_score_gt_Bleu_4',
                                     np.array(eval_score2['Bleu_4']).mean(),
                                     iteration)
                tf_writer.add_scalar('val_score_gt_CIDEr',
                                     np.array(eval_score2['CIDEr']).mean(),
                                     iteration)
            tf_writer.add_scalar('val_recall', eval_score['Recall'].mean(),
                                 iteration)
            tf_writer.add_scalar('val_precision',
                                 eval_score['Precision'].mean(), iteration)
            tf_writer.add_scalar('f1', f1, iteration)

            val_result_history[iteration] = {
                'val_loss': val_loss,
                'eval_score': eval_score
            }

            if flag_training_what == 'tap':
                logger.info(
                    'Validation the result of iter {}, score(f1/meteor):{},\n all:{}'
                    .format(iteration, current_score, eval_score))
            else:
                mean_score = {
                    k: np.array(v).mean()
                    for k, v in eval_score.items()
                }
                gt_mean_score = {
                    k: np.array(v).mean()
                    for k, v in eval_score2.items()
                }

                metrics = ['Bleu_4', 'CIDEr', 'METEOR', 'ROUGE_L']
                gt_avg_score = np.array([
                    v for metric, v in gt_mean_score.items()
                    if metric in metrics
                ]).sum()
                logger.info(
                    'Validation the result of iter {}, score(f1/meteor):{},\n all:{}\n mean:{} \n\n gt:{} \n mean:{}\n avg_score: {}'
                    .format(iteration, current_score, eval_score, mean_score,
                            eval_score2, gt_mean_score, gt_avg_score))

            # Save model .pth
            saved_pth = {
                'iteration': iteration,
                'cg_model': cg_model.state_dict(),
                'tap_model': tap_model.state_dict(),
                'cg_optimizer': cg_optimizer.state_dict(),
                'tap_optimizer': tap_optimizer.state_dict(),
            }

            if opt.save_all_checkpoint:
                checkpoint_path = os.path.join(
                    save_folder, 'model_iter_{}.pth'.format(iteration))
            else:
                checkpoint_path = os.path.join(save_folder, 'model.pth')
            torch.save(saved_pth, checkpoint_path)
            logger.info('Save model at iter {} to checkpoint file {}.'.format(
                iteration, checkpoint_path))

            # save info.pkl
            if current_score > best_val_score:
                best_val_score = current_score
                best_epoch = epoch
                saved_info['best'] = {
                    'opt': opt,
                    'iter': iteration,
                    'epoch': epoch,
                    'iterators': loader.iterators,
                    'flag_training_what': flag_training_what,
                    'split_ix': loader.split_ix,
                    'best_val_score': best_val_score,
                    'vocab': loader.get_vocab(),
                }

                best_checkpoint_path = os.path.join(save_folder,
                                                    'model-best.pth')
                torch.save(saved_pth, best_checkpoint_path)
                logger.info(
                    'Save Best-model at iter {} to checkpoint file.'.format(
                        iteration))

            saved_info['last'] = {
                'opt': opt,
                'iter': iteration,
                'epoch': epoch,
                'iterators': loader.iterators,
                'flag_training_what': flag_training_what,
                'split_ix': loader.split_ix,
                'best_val_score': best_val_score,
                'vocab': loader.get_vocab(),
            }
            saved_info['history'] = {
                'val_result_history': val_result_history,
                'loss_history': loss_history,
                'lr_history': lr_history,
            }
            with open(os.path.join(save_folder, 'info.pkl'), 'wb') as f:
                cPickle.dump(saved_info, f)
                logger.info('Save info to info.pkl')

            # Stop criterion
            if epoch >= len(flag_training_whats):
                tf_writer.close()
                break
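
The loop above steps its optimizers only every opt.m_batch iterations, i.e.
gradient accumulation. A minimal self-contained sketch of that pattern
(model, optimizer, and accum_steps are illustrative stand-ins, not names from
the repo):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
accum_steps = 4  # plays the role of opt.m_batch

optimizer.zero_grad()
for iteration in range(1, 17):
    x, y = torch.randn(8, 4), torch.randn(8, 2)
    loss = torch.nn.functional.mse_loss(model(x), y)
    (loss / accum_steps).backward()  # scale so summed grads match one big batch
    if iteration % accum_steps == 0:
        torch.nn.utils.clip_grad_value_(model.parameters(), 0.1)  # cf. utils.clip_gradient
        optimizer.step()
        optimizer.zero_grad()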

def train(opt):
    if vars(opt).get('start_from', None) is not None:
        opt.checkpoint_path = opt.start_from
        opt.id = opt.checkpoint_path.split('/')[-1]
        print('Point to folder: {}'.format(opt.checkpoint_path))
    else:
        opt.id = datetime.datetime.now().strftime(
            '%Y%m%d_%H%M%S') + '_' + opt.caption_model
        opt.checkpoint_path = os.path.join(opt.checkpoint_path, opt.id)

        if not os.path.exists(opt.checkpoint_path):
            os.makedirs(opt.checkpoint_path)
        print('Create folder: {}'.format(opt.checkpoint_path))

    # Deal with feature-related options before anything else
    opt.use_att = utils.if_use_att(opt.caption_model)
    # opt.use_att = False
    if opt.use_box: opt.att_feat_size = opt.att_feat_size + 5

    loader = DataLoader_UP(opt)
    opt.vocab_size = loader.vocab_size
    if opt.use_rela == 1:
        opt.rela_dict_size = loader.rela_dict_size
    opt.seq_length = loader.seq_length
    use_rela = getattr(opt, 'use_rela', 0)

    try:
        tb_summary_writer = tf and tf.compat.v1.summary.FileWriter(
            opt.checkpoint_path)
    except Exception:
        print('Failed to set up the tensorboard writer!')
        pdb.set_trace()

    infos = {}
    histories = {}
    if opt.start_from is not None:
        # open old infos and check if models are compatible
        with open(os.path.join(opt.checkpoint_path, 'infos.pkl'), 'rb') as f:
            infos = cPickle.load(f)
            saved_model_opt = infos['opt']
            need_be_same = [
                "caption_model", "rnn_type", "rnn_size", "num_layers"
            ]
            for checkme in need_be_same:
                assert vars(saved_model_opt)[checkme] == vars(
                    opt
                )[checkme], "Command line argument and saved model disagree on '%s' " % checkme

        if os.path.isfile(os.path.join(opt.checkpoint_path, 'histories.pkl')):
            with open(os.path.join(opt.checkpoint_path, 'histories.pkl'),
                      'rb') as f:
                histories = cPickle.load(f)

    iteration = infos.get('iter', 0)
    epoch = infos.get('epoch', 0)

    val_result_history = histories.get('val_result_history', {})
    loss_history = histories.get('loss_history', {})
    lr_history = histories.get('lr_history', {})
    ss_prob_history = histories.get('ss_prob_history', {})

    loader.iterators = infos.get('iterators', loader.iterators)
    loader.split_ix = infos.get('split_ix', loader.split_ix)
    best_val_score = None  # ensure it is defined even when not loading
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)

    model = models.setup(opt).cuda()
    # dp_model = torch.nn.DataParallel(model)
    # dp_model = torch.nn.DataParallel(model, [0,2,3])
    dp_model = model

    print('### Model summary below ###\n {}\n'.format(str(model)))
    model_params = sum(p.numel() for p in model.parameters()
                       if p.requires_grad)
    print('model parameters: {}'.format(model_params))

    update_lr_flag = True
    # Assure in training mode
    dp_model.train()
    crit = utils.LanguageModelCriterion()
    rl_crit = utils.RewardCriterion()

    optimizer = utils.build_optimizer(
        filter(lambda p: p.requires_grad, model.parameters()), opt)

    optimizer.zero_grad()
    accumulate_iter = 0
    train_loss = 0
    reward = np.zeros([1, 1])

    while True:
        if update_lr_flag:
            # Assign the learning rate
            if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                frac = (epoch - opt.learning_rate_decay_start
                        ) // opt.learning_rate_decay_every
                decay_factor = opt.learning_rate_decay_rate**frac
                opt.current_lr = opt.learning_rate * decay_factor
            else:
                opt.current_lr = opt.learning_rate
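            # Example: learning_rate=5e-4, learning_rate_decay_start=0,
            # decay_every=3, decay_rate=0.8 -> at epoch 7, frac=(7-0)//3=2,
            # so current_lr = 5e-4 * 0.8**2 = 3.2e-4.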
            utils.set_lr(optimizer, opt.current_lr)
            # Assign the scheduled sampling prob
            if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                frac = (epoch - opt.scheduled_sampling_start
                        ) // opt.scheduled_sampling_increase_every
                opt.ss_prob = min(opt.scheduled_sampling_increase_prob * frac,
                                  opt.scheduled_sampling_max_prob)
                model.ss_prob = opt.ss_prob
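            # Example: scheduled_sampling_start=0, increase_every=5,
            # increase_prob=0.05, max_prob=0.25 -> at epoch 12,
            # frac=(12-0)//5=2, so ss_prob = 0.10.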

            # If start self critical training
            if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                sc_flag = True
                init_scorer(opt.cached_tokens)
            else:
                sc_flag = False

            update_lr_flag = False

        start = time.time()
        # Load data from the train split
        data = loader.get_batch(opt.train_split)
        # print('Read data:', time.time() - start)

        torch.cuda.synchronize()
        start = time.time()

        fc_feats = None
        att_feats = None
        att_masks = None
        isg_data = None  # must exist even when use_ssg == 0
        ssg_data = None
        rela_data = None

        if getattr(opt, 'use_ssg', 0) == 1:
            if getattr(opt, 'use_isg', 0) == 1:
                tmp = [
                    data['fc_feats'], data['labels'], data['masks'],
                    data['att_feats'], data['att_masks'],
                    data['isg_rela_matrix'], data['isg_rela_masks'],
                    data['isg_obj'], data['isg_obj_masks'], data['isg_attr'],
                    data['isg_attr_masks'], data['ssg_rela_matrix'],
                    data['ssg_rela_masks'], data['ssg_obj'],
                    data['ssg_obj_masks'], data['ssg_attr'],
                    data['ssg_attr_masks']
                ]

                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, labels, masks, att_feats, att_masks, \
                isg_rela_matrix, isg_rela_masks, isg_obj, isg_obj_masks, isg_attr, isg_attr_masks, \
                ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks = tmp

                # image graph domain
                isg_data = {}
                isg_data['att_feats'] = att_feats
                isg_data['att_masks'] = att_masks

                isg_data['isg_rela_matrix'] = isg_rela_matrix
                isg_data['isg_rela_masks'] = isg_rela_masks
                isg_data['isg_obj'] = isg_obj
                isg_data['isg_obj_masks'] = isg_obj_masks
                isg_data['isg_attr'] = isg_attr
                isg_data['isg_attr_masks'] = isg_attr_masks
                # text graph domain
                ssg_data = {}
                ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
                ssg_data['ssg_rela_masks'] = ssg_rela_masks
                ssg_data['ssg_obj'] = ssg_obj
                ssg_data['ssg_obj_masks'] = ssg_obj_masks
                ssg_data['ssg_attr'] = ssg_attr
                ssg_data['ssg_attr_masks'] = ssg_attr_masks
            else:
                tmp = [
                    data['fc_feats'], data['ssg_rela_matrix'],
                    data['ssg_rela_masks'], data['ssg_obj'],
                    data['ssg_obj_masks'], data['ssg_attr'],
                    data['ssg_attr_masks'], data['labels'], data['masks']
                ]
                tmp = [
                    _ if _ is None else torch.from_numpy(_).cuda() for _ in tmp
                ]
                fc_feats, ssg_rela_matrix, ssg_rela_masks, ssg_obj, ssg_obj_masks, ssg_attr, ssg_attr_masks, labels, masks = tmp
                ssg_data = {}
                ssg_data['ssg_rela_matrix'] = ssg_rela_matrix
                ssg_data['ssg_rela_masks'] = ssg_rela_masks
                ssg_data['ssg_obj'] = ssg_obj
                ssg_data['ssg_obj_masks'] = ssg_obj_masks
                ssg_data['ssg_attr'] = ssg_attr
                ssg_data['ssg_attr_masks'] = ssg_attr_masks

                isg_data = None
        else:
            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else torch.from_numpy(_).cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

        if not sc_flag:
            loss = crit(dp_model(fc_feats, labels, isg_data, ssg_data),
                        labels[:, 1:], masks[:, 1:])
        else:
            gen_result, sample_logprobs = dp_model(fc_feats,
                                                   isg_data,
                                                   ssg_data,
                                                   opt={'sample_max': 0},
                                                   mode='sample')
            reward = get_self_critical_reward(dp_model, fc_feats, isg_data,
                                              ssg_data, data, gen_result, opt)
            loss = rl_crit(sample_logprobs, gen_result.data,
                           torch.from_numpy(reward).float().cuda())

        accumulate_iter = accumulate_iter + 1
        loss = loss / opt.accumulate_number
        loss.backward()
        if accumulate_iter % opt.accumulate_number == 0:
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            optimizer.zero_grad()
            iteration += 1
            accumulate_iter = 0
            train_loss = loss.item() * opt.accumulate_number
            end = time.time()

            if not sc_flag:
                print("{}/{}/{}|train_loss={:.3f}|time/batch={:.3f}" \
                      .format(opt.id, iteration, epoch, train_loss, end - start))
            else:
                print("{}/{}/{}|avg_reward={:.3f}|time/batch={:.3f}" \
                      .format(opt.id, iteration, epoch, np.mean(reward[:, 0]), end - start))

        torch.cuda.synchronize()

        # Update the iteration and epoch
        if data['bounds']['wrapped']:
            epoch += 1
            update_lr_flag = True

        # Write the training loss summary
        if (iteration % opt.losses_log_every == 0) and (iteration != 0):
            add_summary_value(tb_summary_writer, 'train_loss', train_loss,
                              iteration)
            add_summary_value(tb_summary_writer, 'learning_rate',
                              opt.current_lr, iteration)
            add_summary_value(tb_summary_writer, 'scheduled_sampling_prob',
                              model.ss_prob, iteration)
            if sc_flag:
                add_summary_value(tb_summary_writer, 'avg_reward',
                                  np.mean(reward[:, 0]), iteration)

            loss_history[iteration] = train_loss if not sc_flag else np.mean(
                reward[:, 0])
            lr_history[iteration] = opt.current_lr
            ss_prob_history[iteration] = model.ss_prob

        # make evaluation on validation set, and save model
        # if (iteration %2 == 0) and (iteration != 0):
        if (iteration % opt.save_checkpoint_every == 0) and (iteration != 0):
            # eval model
            if use_rela:
                eval_kwargs = {
                    'split': 'val',
                    'dataset': opt.input_json,
                    'use_real': 1
                }
            else:
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
            eval_kwargs.update(vars(opt))
            # val_loss, predictions, lang_stats = eval_utils.eval_split(model_zh,model_en,itow_zh,itow, dp_model, crit, loader, eval_kwargs)
            val_loss, predictions, lang_stats = eval_utils.eval_split(
                dp_model, crit, loader, eval_kwargs)

            # Write validation result into summary
            add_summary_value(tb_summary_writer, 'validation loss', val_loss,
                              iteration)
            if lang_stats is not None:
                for k, v in lang_stats.items():
                    add_summary_value(tb_summary_writer, k, v, iteration)
            val_result_history[iteration] = {
                'loss': val_loss,
                'lang_stats': lang_stats,
                'predictions': predictions
            }

            # Save model if is improving on validation result
            if opt.language_eval == 1:
                current_score = lang_stats['CIDEr']
            else:
                current_score = -val_loss

            best_flag = False
            save_id = iteration / opt.save_checkpoint_every
            if best_val_score is None or current_score > best_val_score:
                best_val_score = current_score
                best_flag = True
            checkpoint_path = os.path.join(opt.checkpoint_path, 'model.pth')
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to {}".format(checkpoint_path))
            optimizer_path = os.path.join(opt.checkpoint_path,
                                          'optimizer.pth')
            torch.save(optimizer.state_dict(), optimizer_path)

            # Dump miscellaneous information
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['iterators'] = loader.iterators
            infos['split_ix'] = loader.split_ix
            infos['best_val_score'] = best_val_score
            infos['opt'] = opt
            infos['vocab'] = loader.get_vocab()

            histories['val_result_history'] = val_result_history
            histories['loss_history'] = loss_history
            histories['lr_history'] = lr_history
            histories['ss_prob_history'] = ss_prob_history
            with open(os.path.join(opt.checkpoint_path, 'infos.pkl'),
                      'wb') as f:
                cPickle.dump(infos, f)
            with open(os.path.join(opt.checkpoint_path, 'histories.pkl'),
                      'wb') as f:
                cPickle.dump(histories, f)

            if best_flag:
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model-best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("model saved to {}".format(checkpoint_path))
                with open(os.path.join(opt.checkpoint_path,
                                       'infos-best.pkl'), 'wb') as f:
                    cPickle.dump(infos, f)

        # Stop if reaching max epochs
        if epoch >= opt.max_epochs and opt.max_epochs != -1:
            break
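
In the self-critical branch above, the reward is typically the sampled
caption's score minus the greedy baseline's score, and rl_crit weights the
token log-probabilities by that advantage. A minimal sketch of such a
reward-weighted loss (the shapes and masking convention are assumptions, not
the repo's utils.RewardCriterion):

import torch

def reward_weighted_nll(sample_logprobs, gen_result, reward):
    # sample_logprobs: (B, T) log-probs of the sampled tokens
    # gen_result:      (B, T) sampled token ids, assumed 0-padded after EOS
    # reward:          (B, T) per-token advantage, e.g. CIDEr(sample) - CIDEr(greedy)
    mask = (gen_result > 0).float()
    loss = -sample_logprobs * reward * mask
    return loss.sum() / mask.sum().clamp(min=1.0)
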
Example No. 32
0
def train(opt):
    # on node-13 this line causes a bug
    from torch.utils.tensorboard import SummaryWriter

    ################################
    # Build dataloader
    ################################

    # The loader still needs fixing so that data loading is correct;
    # opt has to be modified here so everything downstream stays consistent.

    loader = DataLoader(opt)
    opt.vocab_size = loader.vocab_size
    opt.seq_length = loader.seq_length

    ##########################
    # Initialize infos
    ##########################
    infos = {
        'iter': 0,
        'epoch': 0,
        'loader_state_dict': None,
        'vocab': loader.get_vocab(),
    }
    # Load old infos (if any) and check that the models are compatible
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'infos_' + opt.id + '.pkl')):
        raise Exception("not implemented")

    infos['opt'] = opt

    #########################
    # Build logger
    #########################
    # naive dict logger
    histories = defaultdict(dict)
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl')):
        with open(os.path.join(opt.start_from, 'histories_' + opt.id + '.pkl'),
                  'rb') as f:
            histories.update(utils.pickle_load(f))

    # tensorboard logger
    tb_summary_writer = SummaryWriter(opt.checkpoint_path)

    ##########################
    # Build model
    ##########################
    opt.vocab = loader.get_vocab()
    model = TransformerLM(opt).cuda()  # only set up the language model
    del opt.vocab  # the vocab is only needed while constructing the model
    # Load pretrained weights:
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, 'model.pth')):
        model.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'model.pth')))

    # Wrap the generation model with the loss function (used for training).
    # This allows the loss to be computed separately on each device.
    lw_model = LossWrapper(model, opt)
    # Wrap with dataparallel
    dp_model = torch.nn.DataParallel(model)
    dp_lw_model = torch.nn.DataParallel(lw_model)

    ##########################
    #  Build optimizer
    ##########################
    if opt.noamopt:
        assert opt.caption_model == 'transformer', 'noamopt can only work with transformer'
        optimizer = utils.get_std_opt(model,
                                      factor=opt.noamopt_factor,
                                      warmup=opt.noamopt_warmup)
    elif opt.reduce_on_plateau:
        optimizer = utils.build_optimizer(model.parameters(), opt)
        optimizer = utils.ReduceLROnPlateau(optimizer, factor=0.5, patience=3)
    else:
        optimizer = utils.build_optimizer(model.parameters(), opt)
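    # For reference, the Noam schedule (from "Attention Is All You Need"),
    # which utils.get_std_opt presumably implements:
    #     lr = factor * d_model**-0.5 * min(step**-0.5, step * warmup**-1.5)
    # i.e. linear warmup for `warmup` steps, then inverse-sqrt decay.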
    # Load the optimizer
    if opt.start_from is not None and os.path.isfile(
            os.path.join(opt.start_from, "optimizer.pth")):
        optimizer.load_state_dict(
            torch.load(os.path.join(opt.start_from, 'optimizer.pth')))

    #########################
    # Get ready to start
    #########################
    iteration = infos['iter']
    epoch = infos['epoch']
    # For backward compatibility
    if 'iterators' in infos:
        infos['loader_state_dict'] = {
            split: {
                'index_list': infos['split_ix'][split],
                'iter_counter': infos['iterators'][split]
            }
            for split in ['train', 'val', 'test']
        }
    loader.load_state_dict(infos['loader_state_dict'])
    if opt.load_best_score == 1:
        best_val_score = infos.get('best_val_score', None)
    if opt.noamopt:
        optimizer._step = iteration
    # Flag indicating the end of an epoch; always True at the start
    # so that the learning rate etc. get initialized.
    epoch_done = True
    # Assure in training mode
    dp_lw_model.train()

    # Start training
    try:
        while True:
            if epoch_done:
                if not opt.noamopt and not opt.reduce_on_plateau:
                    # Assign the learning rate
                    if epoch > opt.learning_rate_decay_start and opt.learning_rate_decay_start >= 0:
                        frac = (epoch - opt.learning_rate_decay_start
                                ) // opt.learning_rate_decay_every
                        decay_factor = opt.learning_rate_decay_rate**frac
                        opt.current_lr = opt.learning_rate * decay_factor
                    else:
                        opt.current_lr = opt.learning_rate
                    utils.set_lr(optimizer,
                                 opt.current_lr)  # set the decayed rate
                # Assign the scheduled sampling prob
                if epoch > opt.scheduled_sampling_start and opt.scheduled_sampling_start >= 0:
                    frac = (epoch - opt.scheduled_sampling_start
                            ) // opt.scheduled_sampling_increase_every
                    opt.ss_prob = min(
                        opt.scheduled_sampling_increase_prob * frac,
                        opt.scheduled_sampling_max_prob)
                    model.ss_prob = opt.ss_prob

                # If start self critical training
                if opt.self_critical_after != -1 and epoch >= opt.self_critical_after:
                    sc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    sc_flag = False

                # If start structure loss training
                if opt.structure_after != -1 and epoch >= opt.structure_after:
                    struc_flag = True
                    init_scorer(opt.cached_tokens)
                else:
                    struc_flag = False

                epoch_done = False

            start = time.time()
            # Load data from the train split
            data = loader.get_batch('train')
            print('Read data:', time.time() - start)

            torch.cuda.synchronize()
            start = time.time()

            tmp = [
                data['fc_feats'], data['att_feats'], data['labels'],
                data['masks'], data['att_masks']
            ]
            tmp = [_ if _ is None else _.cuda() for _ in tmp]
            fc_feats, att_feats, labels, masks, att_masks = tmp

            optimizer.zero_grad()
            model_out = dp_lw_model(fc_feats, att_feats, labels, masks,
                                    att_masks, data['gts'],
                                    torch.arange(0, len(data['gts'])), sc_flag,
                                    struc_flag)

            loss = model_out['loss'].mean()

            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()
            torch.cuda.synchronize()
            end = time.time()
            if struc_flag:
                print(
                    "iter {} (epoch {}), train_loss = {:.3f}, lm_loss = {:.3f}, struc_loss = {:.3f}, time/batch = {:.3f}" \
                    .format(iteration, epoch, train_loss, model_out['lm_loss'].mean().item(),
                            model_out['struc_loss'].mean().item(), end - start))
            elif not sc_flag:
                print("iter {} (epoch {}), train_loss = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, train_loss, end - start))
            else:
                print("iter {} (epoch {}), avg_reward = {:.3f}, time/batch = {:.3f}" \
                      .format(iteration, epoch, model_out['reward'].mean(), end - start))

            # Update the iteration and epoch
            iteration += 1
            if data['bounds']['wrapped']:
                epoch += 1
                epoch_done = True

            # Write the training loss summary
            if (iteration % opt.losses_log_every == 0):
                tb_summary_writer.add_scalar('train_loss', train_loss,
                                             iteration)
                if opt.noamopt:
                    opt.current_lr = optimizer.rate()
                elif opt.reduce_on_plateau:
                    opt.current_lr = optimizer.current_lr
                tb_summary_writer.add_scalar('learning_rate', opt.current_lr,
                                             iteration)
                tb_summary_writer.add_scalar('scheduled_sampling_prob',
                                             model.ss_prob, iteration)
                if sc_flag:
                    tb_summary_writer.add_scalar('avg_reward',
                                                 model_out['reward'].mean(),
                                                 iteration)
                elif struc_flag:
                    tb_summary_writer.add_scalar(
                        'lm_loss', model_out['lm_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'struc_loss', model_out['struc_loss'].mean().item(),
                        iteration)
                    tb_summary_writer.add_scalar(
                        'reward', model_out['reward'].mean().item(), iteration)

                histories['loss_history'][
                    iteration] = train_loss if not sc_flag else model_out[
                        'reward'].mean()
                histories['lr_history'][iteration] = opt.current_lr
                histories['ss_prob_history'][iteration] = model.ss_prob

            # update infos
            infos['iter'] = iteration
            infos['epoch'] = epoch
            infos['loader_state_dict'] = loader.state_dict()

            # make evaluation on validation set, and save model
            if (iteration % opt.save_checkpoint_every == 0):
                # eval model
                eval_kwargs = {'split': 'val', 'dataset': opt.input_json}
                eval_kwargs.update(vars(opt))
                val_loss, predictions, lang_stats = eval_utils.eval_split(
                    dp_model, lw_model.crit, loader, eval_kwargs)

                if opt.reduce_on_plateau:
                    if lang_stats is not None:
                        if 'CIDEr' in lang_stats:
                            optimizer.scheduler_step(-lang_stats['CIDEr'])
                        else:
                            optimizer.scheduler_step(val_loss)
                    else:
                        optimizer.scheduler_step(val_loss)
                # Write validation result into summary
                tb_summary_writer.add_scalar('validation loss', val_loss,
                                             iteration)
                if lang_stats is not None:
                    for k, v in lang_stats.items():
                        tb_summary_writer.add_scalar(k, v, iteration)
                histories['val_result_history'][iteration] = {
                    'loss': val_loss,
                    'lang_stats': lang_stats,
                    'predictions': predictions
                }

                # Save model if is improving on validation result
                if opt.language_eval == 1:
                    current_score = lang_stats['CIDEr']
                else:
                    current_score = -val_loss

                best_flag = False

                if best_val_score is None or current_score > best_val_score:
                    best_val_score = current_score
                    best_flag = True

                # Dump miscellaneous information
                infos['best_val_score'] = best_val_score

                utils.save_checkpoint(opt, model, infos, optimizer, histories)
                if opt.save_history_ckpt:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append=str(iteration))

                if best_flag:
                    utils.save_checkpoint(opt,
                                          model,
                                          infos,
                                          optimizer,
                                          append='best')

            # Stop if reaching max epochs
            if epoch >= opt.max_epochs and opt.max_epochs != -1:
                break

    except (RuntimeError, KeyboardInterrupt):
        print('Save ckpt on exception ...')
        utils.save_checkpoint(opt, model, infos, optimizer)
        print('Save ckpt done.')
        stack_trace = traceback.format_exc()
        print(stack_trace)
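
One pitfall with the DataParallel wrappers used above: only forward() is
dispatched through the wrapper, so custom methods such as sample() must be
reached via .module. A toy illustration (Captioner is a stand-in class, not
from any of these repos):

import torch
import torch.nn as nn

class Captioner(nn.Module):
    def forward(self, x):
        return x * 2

    def sample(self, x):  # custom method, unknown to nn.Module
        return x + 1

dp = nn.DataParallel(Captioner())
x = torch.ones(2, 3)
y = dp(x)                # OK: forward() goes through the wrapper
z = dp.module.sample(x)  # custom methods live on the wrapped module
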
def train(train_loader,
          val_loader,
          model,
          crit,
          optimizer,
          lr_scheduler,
          opt,
          rl_crit=None):
    model.train()
    model = nn.DataParallel(model)
    # lowest val loss
    best_loss = None
    for epoch in range(opt.epochs):
        lr_scheduler.step()
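        # NOTE: PyTorch >= 1.1 expects lr_scheduler.step() to be called
        # after optimizer.step(); stepping at the top of the epoch matches
        # the older scheduler API.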

        iteration = 0
        # If start self crit training
        if opt.self_crit_after != -1 and epoch >= opt.self_crit_after:
            sc_flag = True
            init_cider_scorer(opt.cached_tokens)
        else:
            sc_flag = False

        for data in train_loader:
            torch.cuda.synchronize()
            fc_feats = Variable(data['fc_feats']).cuda()
            labels = Variable(data['labels']).long().cuda()
            masks = Variable(data['masks']).cuda()
            if not sc_flag:
                seq_probs, predicts = model(fc_feats, labels)
                loss = crit(seq_probs, labels[:, 1:], masks[:, 1:])
            else:
                # model was wrapped in nn.DataParallel above, so the custom
                # sample() method must be reached through .module
                gen_result, sample_logprobs = model.module.sample(
                    fc_feats, vars(opt))
                # print(gen_result)
                reward = get_self_critical_reward(model, fc_feats, data,
                                                  gen_result)
                loss = rl_crit(
                    sample_logprobs, gen_result,
                    Variable(torch.from_numpy(reward).float().cuda()))

            optimizer.zero_grad()
            loss.backward()
            utils.clip_gradient(optimizer, opt.grad_clip)
            optimizer.step()
            train_loss = loss.item()  # .data[0] fails on 0-dim tensors in modern PyTorch
            torch.cuda.synchronize()
            iteration += 1

            if not sc_flag:
                print("iter %d (epoch %d), train_loss = %.6f" %
                      (iteration, epoch, train_loss))
            else:
                print("iter %d (epoch %d), avg_reward = %.3f" %
                      (iteration, epoch, np.mean(reward[:, 0])))


        if epoch % opt.save_checkpoint_every == 0:
            checkpoint_path = os.path.join(opt.checkpoint_path,
                                           'model_%d.pth' % (epoch))
            torch.save(model.state_dict(), checkpoint_path)
            print("model saved to %s" % (checkpoint_path))
            val_loss = val(val_loader, model, crit)
            print("Val loss is: %.6f" % (val_loss))
            model.train()
            if best_loss is None or val_loss < best_loss:
                print("(epoch %d), now lowest val loss is %.6f" %
                      (epoch, val_loss))
                checkpoint_path = os.path.join(opt.checkpoint_path,
                                               'model_best.pth')
                torch.save(model.state_dict(), checkpoint_path)
                print("best model saved to %s" % (checkpoint_path))
                best_loss = val_loss
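
The epoch loop above keeps both the latest and the best checkpoint, where
"best" means lowest validation loss. The guard in isolation (the values are
made up):

best_loss = None
for epoch, val_loss in enumerate([1.2, 0.9, 1.0, 0.8]):
    if best_loss is None or val_loss < best_loss:
        best_loss = val_loss
        print("(epoch %d), now lowest val loss is %.6f" % (epoch, val_loss))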