# Exemplo n.º 1
# 0
def build_model(C, vocab):
    """Build the caption-generation model (decoder + optional reconstructor).

    Constructs the attention decoder, optionally a global or local
    reconstructor depending on ``C.reconstructor``, loads any pretrained
    weights, and returns the combined model moved to GPU.

    Args:
        C: config object with ``decoder``, ``reconstructor``, ``feat``,
            ``loader`` and ``vocab`` sections plus the
            ``pretrained_*_fpath`` checkpoint paths.
        vocab: vocabulary object exposing ``n_vocabs``.

    Returns:
        A ``CaptionGenerator`` on CUDA.
    """
    decoder = Decoder(rnn_type=C.decoder.rnn_type,
                      num_layers=C.decoder.rnn_num_layers,
                      num_directions=C.decoder.rnn_num_directions,
                      feat_size=C.feat.size,
                      feat_len=C.loader.frame_sample_len,
                      embedding_size=C.vocab.embedding_size,
                      hidden_size=C.decoder.rnn_hidden_size,
                      attn_size=C.decoder.rnn_attn_size,
                      output_size=vocab.n_vocabs,
                      rnn_dropout=C.decoder.rnn_dropout)
    if C.pretrained_decoder_fpath is not None:
        decoder.load_state_dict(
            torch.load(C.pretrained_decoder_fpath)['decoder'])
        print("Pretrained decoder is loaded from {}".format(
            C.pretrained_decoder_fpath))

    # Global or local reconstructor; any type other than 'global' falls
    # through to the local variant (matches original behavior).
    if C.reconstructor is None:
        reconstructor = None
    elif C.reconstructor.type == 'global':
        reconstructor = GlobalReconstructor(
            rnn_type=C.reconstructor.rnn_type,
            num_layers=C.reconstructor.rnn_num_layers,
            num_directions=C.reconstructor.rnn_num_directions,
            decoder_size=C.decoder.rnn_hidden_size,
            hidden_size=C.reconstructor.rnn_hidden_size,
            rnn_dropout=C.reconstructor.rnn_dropout)
    else:
        reconstructor = LocalReconstructor(
            rnn_type=C.reconstructor.rnn_type,
            num_layers=C.reconstructor.rnn_num_layers,
            num_directions=C.reconstructor.rnn_num_directions,
            decoder_size=C.decoder.rnn_hidden_size,
            hidden_size=C.reconstructor.rnn_hidden_size,
            attn_size=C.reconstructor.rnn_attn_size,
            rnn_dropout=C.reconstructor.rnn_dropout)

    # BUGFIX: only attempt to load reconstructor weights when a
    # reconstructor was actually built. Previously, a pretrained path
    # combined with C.reconstructor=None crashed with an AttributeError
    # on None.load_state_dict(...).
    if reconstructor is not None and C.pretrained_reconstructor_fpath is not None:
        reconstructor.load_state_dict(
            torch.load(C.pretrained_reconstructor_fpath)['reconstructor'])
        print("Pretrained reconstructor is loaded from {}".format(
            C.pretrained_reconstructor_fpath))

    model = CaptionGenerator(decoder, reconstructor, C.loader.max_caption_len,
                             vocab)
    model.cuda()
    return model
# Exemplo n.º 2
# 0
def build_reconstructor():
    """Build the reconstructor model, loss, optimizer, and reg. weight.

    Reads all hyper-parameters from the module-level config ``C``.

    Returns:
        dict with keys ``'model'``, ``'loss'``, ``'optimizer'`` and
        ``'lambda_reg'``.

    Raises:
        NotImplementedError: if ``C.reconstructor_type`` is neither
            "local" nor "global".
    """
    if C.reconstructor_type == "local":
        model = LocalReconstructor(
            model_name=C.reconstructor_model,
            n_layers=C.reconstructor_n_layers,
            decoder_hidden_size=C.decoder_hidden_size,
            hidden_size=C.reconstructor_hidden_size,
            dropout=C.reconstructor_dropout,
            decoder_dropout=C.reconstructor_decoder_dropout,
            attn_size=C.reconstructor_attn_size)
    elif C.reconstructor_type == "global":
        model = GlobalReconstructor(
            model_name=C.reconstructor_model,
            n_layers=C.reconstructor_n_layers,
            decoder_hidden_size=C.decoder_hidden_size,
            hidden_size=C.reconstructor_hidden_size,
            dropout=C.reconstructor_dropout,
            decoder_dropout=C.reconstructor_decoder_dropout,
            caption_max_len=C.caption_max_len)
    else:
        raise NotImplementedError("Unknown reconstructor: {}".format(
            C.reconstructor_type))
    model = model.to(C.device)
    loss = torch.nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=C.reconstructor_learning_rate,
                                 weight_decay=C.reconstructor_weight_decay,
                                 amsgrad=C.reconstructor_use_amsgrad)
    # torch.autograd.Variable is deprecated; create the regularization
    # weight as a leaf tensor directly on the target device. (The old
    # `Variable(...).to(device)` additionally produced a NON-leaf tensor,
    # so its requires_grad=True had no practical effect.)
    lambda_reg = torch.tensor(0.01, requires_grad=True, device=C.device)

    reconstructor = {
        'model': model,
        'loss': loss,
        'optimizer': optimizer,
        'lambda_reg': lambda_reg,
    }
    return reconstructor
def main():
    """Train the caption decoder (optionally with a reconstructor) on MSVD.

    Command-line flags:
        --debug / -D: log on every iteration and skip TensorBoard writing.
        --loss_only / -L: skip the periodic test-set caption scoring.
    """
    a = argparse.ArgumentParser()
    a.add_argument("--debug", "-D", action="store_true")
    a.add_argument("--loss_only", "-L", action="store_true")
    args = a.parse_args()

    print("MODEL ID: {}".format(C.id))
    print("DEBUG MODE: {}".format(['OFF', 'ON'][args.debug]))

    # TensorBoard writing is disabled entirely in debug mode.
    if not args.debug:
        summary_writer = SummaryWriter(C.log_dpath)
    """ Load DataLoader """
    MSVD = _MSVD(C)
    vocab = MSVD.vocab
    # cycle() makes both loaders infinite; termination is explicit below
    # (train: C.train_n_iteration iterations, val: C.n_val captions).
    train_data_loader = iter(cycle(MSVD.train_data_loader))
    val_data_loader = iter(cycle(MSVD.val_data_loader))

    print('n_vocabs: {} ({}), n_words: {} ({}). MIN_COUNT: {}'.format(
        vocab.n_vocabs, vocab.n_vocabs_untrimmed, vocab.n_words,
        vocab.n_words_untrimmed, C.min_count))
    """ Build Decoder """
    decoder = Decoder(model_name=C.decoder_model,
                      n_layers=C.decoder_n_layers,
                      encoder_size=C.encoder_output_size,
                      embedding_size=C.embedding_size,
                      embedding_scale=C.embedding_scale,
                      hidden_size=C.decoder_hidden_size,
                      attn_size=C.decoder_attn_size,
                      output_size=vocab.n_vocabs,
                      embedding_dropout=C.embedding_dropout,
                      dropout=C.decoder_dropout,
                      out_dropout=C.decoder_out_dropout)
    decoder = decoder.to(C.device)
    decoder_loss_func = nn.CrossEntropyLoss()
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=C.decoder_learning_rate,
                                   weight_decay=C.decoder_weight_decay,
                                   amsgrad=C.decoder_use_amsgrad)
    # Regularization weight passed into dec_step — presumably scales a
    # regularization term inside the step function; confirm in dec_step.
    # NOTE(review): .to(device) on a leaf tensor returns a NON-leaf tensor,
    # so requires_grad=True here likely has no effect — confirm intent.
    decoder_lambda = torch.autograd.Variable(torch.tensor(0.001),
                                             requires_grad=True)
    decoder_lambda = decoder_lambda.to(C.device)
    """ Build Reconstructor """
    if C.use_recon:
        if C.reconstructor_type == "local":
            reconstructor = LocalReconstructor(
                model_name=C.reconstructor_model,
                n_layers=C.reconstructor_n_layers,
                decoder_hidden_size=C.decoder_hidden_size,
                hidden_size=C.reconstructor_hidden_size,
                dropout=C.reconstructor_dropout,
                decoder_dropout=C.reconstructor_decoder_dropout,
                attn_size=C.reconstructor_attn_size)
        elif C.reconstructor_type == "global":
            reconstructor = GlobalReconstructor(
                model_name=C.reconstructor_model,
                n_layers=C.reconstructor_n_layers,
                decoder_hidden_size=C.decoder_hidden_size,
                hidden_size=C.reconstructor_hidden_size,
                dropout=C.reconstructor_dropout,
                decoder_dropout=C.reconstructor_decoder_dropout,
                caption_max_len=C.caption_max_len)
        else:
            raise NotImplementedError("Unknown reconstructor: {}".format(
                C.reconstructor_type))
        reconstructor = reconstructor.to(C.device)
        reconstructor_loss_func = nn.MSELoss()
        reconstructor_optimizer = optim.Adam(
            reconstructor.parameters(),
            lr=C.reconstructor_learning_rate,
            weight_decay=C.reconstructor_weight_decay,
            amsgrad=C.reconstructor_use_amsgrad)
        reconstructor_lambda = torch.autograd.Variable(torch.tensor(0.01),
                                                       requires_grad=True)
        reconstructor_lambda = reconstructor_lambda.to(C.device)
        # loss_lambda presumably balances decoder vs. reconstructor loss
        # inside dec_rec_step — confirm against that function.
        loss_lambda = torch.autograd.Variable(torch.tensor(1.),
                                              requires_grad=True)
        loss_lambda = loss_lambda.to(C.device)
    """ Train """
    # Running-loss accumulators; reset after each log interval below.
    train_loss = 0
    if C.use_recon:
        train_dec_loss = 0
        train_rec_loss = 0
    # train_data_loader is infinite (cycle); loop ends via the explicit
    # break at iteration == C.train_n_iteration at the bottom.
    for iteration, batch in enumerate(train_data_loader, 1):
        if C.use_recon:
            loss, decoder_loss, _, recon_loss = dec_rec_step(
                batch,
                decoder,
                decoder_loss_func,
                decoder_lambda,
                decoder_optimizer,
                reconstructor,
                reconstructor_loss_func,
                reconstructor_lambda,
                reconstructor_optimizer,
                loss_lambda,
                is_train=True)
            train_dec_loss += decoder_loss
            train_rec_loss += recon_loss
        else:
            loss, _ = dec_step(batch,
                               decoder,
                               decoder_loss_func,
                               decoder_lambda,
                               decoder_optimizer,
                               is_train=True)
        train_loss += loss
        """ Log Train Progress """
        if args.debug or iteration % C.log_every == 0:
            # Average accumulated losses over the log interval.
            # NOTE(review): in debug mode this divides every iteration,
            # so the "average" is loss/C.log_every — confirm intended.
            train_loss /= C.log_every
            if C.use_recon:
                train_dec_loss /= C.log_every
                train_rec_loss /= C.log_every

            if not args.debug:
                summary_writer.add_scalar(C.tx_train_loss, train_loss,
                                          iteration)
                summary_writer.add_scalar(C.tx_lambda_decoder,
                                          decoder_lambda.item(), iteration)
                if C.use_recon:
                    summary_writer.add_scalar(C.tx_train_loss_decoder,
                                              train_dec_loss, iteration)
                    summary_writer.add_scalar(C.tx_train_loss_reconstructor,
                                              train_rec_loss, iteration)
                    summary_writer.add_scalar(C.tx_lambda_reconstructor,
                                              reconstructor_lambda.item(),
                                              iteration)
                    summary_writer.add_scalar(C.tx_lambda, loss_lambda.item(),
                                              iteration)

            if C.use_recon:
                print(
                    "Iter {} / {} ({:.1f}%): loss {:.5f} (dec {:.5f} + rec {:.5f})"
                    .format(iteration, C.train_n_iteration,
                            iteration / C.train_n_iteration * 100, train_loss,
                            train_dec_loss, train_rec_loss))
            else:
                print("Iter {} / {} ({:.1f}%): loss {:.5f}".format(
                    iteration, C.train_n_iteration,
                    iteration / C.train_n_iteration * 100, train_loss))

            # Reset accumulators for the next logging interval.
            train_loss = 0
            if C.use_recon:
                train_dec_loss = 0
                train_rec_loss = 0
        """ Log Validation Progress """
        if args.debug or iteration % C.validate_every == 0:
            val_loss = 0
            val_dec_loss = 0
            val_rec_loss = 0
            gt_captions = []
            pd_captions = []
            # val_data_loader is also infinite; the loop exits via the
            # break once C.n_val captions have been collected.
            for batch in val_data_loader:
                if C.use_recon:
                    loss, decoder_loss, decoder_output_indices, recon_loss = dec_rec_step(
                        batch,
                        decoder,
                        decoder_loss_func,
                        decoder_lambda,
                        decoder_optimizer,
                        reconstructor,
                        reconstructor_loss_func,
                        reconstructor_lambda,
                        reconstructor_optimizer,
                        loss_lambda,
                        is_train=False)
                    # Weight per-batch mean losses by batch size so the
                    # division by C.n_val below yields a per-caption mean.
                    val_dec_loss += decoder_loss * C.batch_size
                    val_rec_loss += recon_loss * C.batch_size
                else:
                    loss, decoder_output_indices = dec_step(batch,
                                                            decoder,
                                                            decoder_loss_func,
                                                            decoder_lambda,
                                                            decoder_optimizer,
                                                            is_train=False)
                val_loss += loss * C.batch_size

                # Convert ground-truth and predicted index sequences to
                # sentences, truncating at the <EOS> token.
                _, _, targets = batch
                gt_idxs = targets.cpu().numpy()
                pd_idxs = decoder_output_indices.cpu().numpy()
                gt_captions += convert_idxs_to_sentences(
                    gt_idxs, vocab.idx2word, vocab.word2idx['<EOS>'])
                pd_captions += convert_idxs_to_sentences(
                    pd_idxs, vocab.idx2word, vocab.word2idx['<EOS>'])

                if len(pd_captions) >= C.n_val:
                    assert len(gt_captions) == len(pd_captions)
                    gt_captions = gt_captions[:C.n_val]
                    pd_captions = pd_captions[:C.n_val]
                    break
            val_loss /= C.n_val
            # Harmless when use_recon is False: both stay 0 from above.
            val_dec_loss /= C.n_val
            val_rec_loss /= C.n_val

            if C.use_recon:
                print(
                    "[Validation] Iter {} / {} ({:.1f}%): loss {:.5f} (dec {:.5f} + rec {:5f})"
                    .format(iteration, C.train_n_iteration,
                            iteration / C.train_n_iteration * 100, val_loss,
                            val_dec_loss, val_rec_loss))
            else:
                print(
                    "[Validation] Iter {} / {} ({:.1f}%): loss {:.5f}".format(
                        iteration, C.train_n_iteration,
                        iteration / C.train_n_iteration * 100, val_loss))

            # Sample a few (ground truth, prediction) pairs for the text log.
            caption_pairs = [(gt, pred)
                             for gt, pred in zip(gt_captions, pd_captions)]
            caption_pairs = sample_n(caption_pairs,
                                     min(C.n_val_logs, C.batch_size))
            caption_log = "\n\n".join([
                "[GT] {}  \n[PD] {}".format(gt, pd) for gt, pd in caption_pairs
            ])

            if not args.debug:
                summary_writer.add_scalar(C.tx_val_loss, val_loss, iteration)
                if C.use_recon:
                    summary_writer.add_scalar(C.tx_val_loss_decoder,
                                              val_dec_loss, iteration)
                    summary_writer.add_scalar(C.tx_val_loss_reconstructor,
                                              val_rec_loss, iteration)
                summary_writer.add_text(C.tx_predicted_captions, caption_log,
                                        iteration)
        """ Log Test Progress """
        if not args.loss_only and (args.debug
                                   or iteration % C.test_every == 0):
            pd_vid_caption_pairs = []
            score_data_loader = MSVD.score_data_loader
            print("[Test] Iter {} / {} ({:.1f}%)".format(
                iteration, C.train_n_iteration,
                iteration / C.train_n_iteration * 100))
            # A search method is either a plain name (str) or a tuple of
            # (name, hyperparameters) joined into an id like "beam-5".
            for search_method in C.search_methods:
                if isinstance(search_method, str):
                    method = search_method
                    search_method_id = search_method
                if isinstance(search_method, tuple):
                    method = search_method[0]
                    search_method_id = "-".join(
                        (str(s) for s in search_method))
                scores = evaluate(C, MSVD, score_data_loader, decoder,
                                  search_method)
                score_summary = " ".join([
                    "{}: {:.3f}".format(score, scores[score])
                    for score in C.scores
                ])
                print("\t{}: {}".format(search_method_id, score_summary))
                if not args.debug:
                    for score in C.scores:
                        summary_writer.add_scalar(
                            C.tx_score[search_method_id][score], scores[score],
                            iteration)
        """ Save checkpoint """
        if iteration % C.save_every == 0:
            if not os.path.exists(C.save_dpath):
                os.makedirs(C.save_dpath)
            fpath = os.path.join(C.save_dpath,
                                 "{}_checkpoint.tar".format(iteration))

            # Persist model + optimizer states (and config) so training
            # can be resumed; reconstructor state only when enabled.
            if C.use_recon:
                torch.save(
                    {
                        'iteration': iteration,
                        'dec': decoder.state_dict(),
                        'rec': reconstructor.state_dict(),
                        'dec_opt': decoder_optimizer.state_dict(),
                        'rec_opt': reconstructor_optimizer.state_dict(),
                        'loss': loss,
                        'config': C,
                    }, fpath)
            else:
                torch.save(
                    {
                        'iteration': iteration,
                        'dec': decoder.state_dict(),
                        'dec_opt': decoder_optimizer.state_dict(),
                        'loss': loss,
                        'config': C,
                    }, fpath)

        # Explicit termination of the infinite (cycled) training loader.
        if iteration == C.train_n_iteration:
            break
# Exemplo n.º 4
# 0
def run(ckpt_fpath):
    """Evaluate a saved checkpoint on the test split and save the captions.

    Loads the checkpoint (weights + config), rebuilds the model, generates
    captions for the test set with beam search, prints the scores, and
    writes a ``<corpus>_test.csv`` result file.

    Args:
        ckpt_fpath: path to a checkpoint produced by training (contains
            'config', 'decoder', and optionally 'reconstructor').

    Raises:
        NotImplementedError: if the checkpoint's corpus is unknown.
    """
    checkpoint = torch.load(ckpt_fpath)
    """ Load Config """
    config = dict_to_cls(checkpoint['config'])
    """ Build Data Loader """
    if config.corpus == "MSVD":
        corpus = MSVD(config)
    elif config.corpus == "MSR-VTT":
        corpus = MSRVTT(config)
    else:
        # BUGFIX: previously an unknown corpus fell through and raised a
        # confusing NameError on `corpus` below; fail fast instead.
        raise NotImplementedError("Unknown corpus: {}".format(config.corpus))
    train_iter, val_iter, test_iter, vocab = \
        corpus.train_data_loader, corpus.val_data_loader, corpus.test_data_loader, corpus.vocab
    print(
        '#vocabs: {} ({}), #words: {} ({}). Trim words which appear less than {} times.'
        .format(vocab.n_vocabs, vocab.n_vocabs_untrimmed, vocab.n_words,
                vocab.n_words_untrimmed, config.loader.min_count))
    """ Build Models """
    decoder = Decoder(rnn_type=config.decoder.rnn_type,
                      num_layers=config.decoder.rnn_num_layers,
                      num_directions=config.decoder.rnn_num_directions,
                      feat_size=config.feat.size,
                      feat_len=config.loader.frame_sample_len,
                      embedding_size=config.vocab.embedding_size,
                      hidden_size=config.decoder.rnn_hidden_size,
                      attn_size=config.decoder.rnn_attn_size,
                      output_size=vocab.n_vocabs,
                      rnn_dropout=config.decoder.rnn_dropout)
    decoder.load_state_dict(checkpoint['decoder'])

    if config.reconstructor is None:
        reconstructor = None
    else:
        # Non-'global' types fall through to the local reconstructor.
        if config.reconstructor.type == 'global':
            reconstructor = GlobalReconstructor(
                rnn_type=config.reconstructor.rnn_type,
                num_layers=config.reconstructor.rnn_num_layers,
                num_directions=config.reconstructor.rnn_num_directions,
                decoder_size=config.decoder.rnn_hidden_size,
                hidden_size=config.reconstructor.rnn_hidden_size,
                rnn_dropout=config.reconstructor.rnn_dropout)
        else:
            reconstructor = LocalReconstructor(
                rnn_type=config.reconstructor.rnn_type,
                num_layers=config.reconstructor.rnn_num_layers,
                num_directions=config.reconstructor.rnn_num_directions,
                decoder_size=config.decoder.rnn_hidden_size,
                hidden_size=config.reconstructor.rnn_hidden_size,
                attn_size=config.reconstructor.rnn_attn_size,
                rnn_dropout=config.reconstructor.rnn_dropout)
        reconstructor.load_state_dict(checkpoint['reconstructor'])

    model = CaptionGenerator(decoder, reconstructor,
                             config.loader.max_caption_len, vocab)
    model = model.cuda()
    '''
    """ Train Set """
    train_vid2pred = get_predicted_captions(train_iter, model, model.vocab, beam_width=5, beam_alpha=0.)
    train_vid2GTs = get_groundtruth_captions(train_iter, model.vocab)
    train_scores = score(train_vid2pred, train_vid2GTs)
    print("[TRAIN] {}".format(train_scores))

    """ Validation Set """
    val_vid2pred = get_predicted_captions(val_iter, model, model.vocab, beam_width=5, beam_alpha=0.)
    val_vid2GTs = get_groundtruth_captions(val_iter, model.vocab)
    val_scores = score(val_vid2pred, val_vid2GTs)
    print("[VAL] scores: {}".format(val_scores))
    '''
    """ Test Set """
    test_vid2pred = get_predicted_captions(test_iter,
                                           model,
                                           model.vocab,
                                           beam_width=5,
                                           beam_alpha=0.)
    test_vid2GTs = get_groundtruth_captions(test_iter, model.vocab)
    test_scores = score(test_vid2pred, test_vid2GTs)
    print("[TEST] {}".format(test_scores))

    # NOTE(review): result_dpath is read from the global C, not from the
    # checkpoint's `config` like everything else — confirm intended.
    test_save_fpath = os.path.join(C.result_dpath,
                                   "{}_{}.csv".format(config.corpus, 'test'))
    save_result(test_vid2pred, test_vid2GTs, test_save_fpath)