Ejemplo n.º 1
0
def main():
    """Train the attention-based ``Seq2Seq`` word-level model on Clotho.

    Command-line arguments (read from ``sys.argv``):
        argv[1]   teacher-forcing ratio (float, e.g. 0.98)
        argv[2:]  per-layer time reductions of the pyramidal BLSTM encoder
                  (ints, e.g. ``2 2``)

    The run mode is chosen by the boolean flags at the top of the function:
    resume from a full checkpoint, pre-train the decoder alone as a language
    model, initialise the decoder from a decoder-only checkpoint, or train
    from scratch.  Per-epoch losses are appended to ``train.log``; a
    checkpoint is written through ``save_checkpoint`` every ``save_every``
    epochs and the loss curves are saved as ``loss.png`` in the directory
    returned by the last ``save_checkpoint`` call (``save_dir`` if none).
    """
    # --- Run mode ---------------------------------------------------------
    # Train only the decoder, as a language model?
    do_pretrain_decoder_as_an_lm = False
    weight_decay = 1e-6
    do_load_pretrained_embeddings = False
    freeze_embeddings = False

    # Load pre-trained weights?
    do_load_checkpoint = False
    do_load_decoder = False

    checkpoint_pathname = 'checkpoints/40_1.6245147620307074_2.6875626488429742_checkpoint.tar'
    save_dir = 'checkpoints'

    # --- Model hyper-parameters --------------------------------------------
    input_dim = 64
    vocab_size = len(WORD_LIST)  # seq2seq at WORD LEVEL

    if do_load_pretrained_embeddings:
        emb_fpath = '/tmpdir/pellegri/corpus/clotho-dataset/lm/word2vec_dev_128.pth'
    else:
        emb_fpath = None

    use_spec_augment = False
    use_gumbel_noise = False

    encoder_hidden_dim = 128
    embedding_dim = 128
    value_size, key_size, query_size = [64] * 3  # these could be different from embedding_dim

    teacher_forcing_ratio = float(sys.argv[1])  # e.g. 0.98

    # e.g. "2 2" on the command line -> [2, 2]
    config_pBLSTM_str = sys.argv[2:]
    pBLSTM_time_reductions = [int(s) for s in config_pBLSTM_str]
    print("config pBLSTM", pBLSTM_time_reductions)
    # [2,2] 0 --> 2887375 params
    # [2,2] 8 --> 2904015 params

    decoder_hidden_size_1 = 128
    decoder_hidden_size_2 = 64

    print("use Gumbel noise", use_gumbel_noise)
    print("use teacher forcing", teacher_forcing_ratio)
    print("use SpecAugment", use_spec_augment)

    model = Seq2Seq(input_dim=input_dim, vocab_size=vocab_size,
                    encoder_hidden_dim=encoder_hidden_dim,
                    use_spec_augment=use_spec_augment,
                    embedding_dim=embedding_dim,
                    decoder_hidden_size_1=decoder_hidden_size_1,
                    decoder_hidden_size_2=decoder_hidden_size_2,
                    query_size=query_size,
                    value_size=value_size, key_size=key_size, isAttended=True,
                    pBLSTM_time_reductions=pBLSTM_time_reductions,
                    emb_fpath=emb_fpath, freeze_embeddings=freeze_embeddings,
                    teacher_forcing_ratio=teacher_forcing_ratio,
                    word2index=word2index, return_attention_masks=False,
                    device=DEVICE)

    print(model)

    num_params = count_parameters(model)
    print("num trainable params:", num_params)

    # --- Weights and optimizer ---------------------------------------------
    if do_load_checkpoint:
        # Resume a full run: model weights, optimizer state and epoch counter.
        print("Loading checkpoint: ", checkpoint_pathname)
        model_checkpoint = torch.load(checkpoint_pathname, map_location=DEVICE)
        model.load_state_dict(model_checkpoint["model"])
        model = model.to(DEVICE)

        lr = 0.001
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        optimizer.load_state_dict(model_checkpoint["model_optim"])
        start_train_epoch = model_checkpoint["iteration"]

    elif do_pretrain_decoder_as_an_lm:
        # Only the decoder weights are optimised.
        lr = 0.001
        optimizer = optim.Adam(model.decoder.parameters(), lr=lr, weight_decay=weight_decay)
        start_train_epoch = 0

    elif do_load_decoder:
        print("Loading decoder checkpoint: ", checkpoint_pathname)
        decoder_checkpoint = torch.load(checkpoint_pathname, map_location=DEVICE)
        model.decoder.load_state_dict(decoder_checkpoint["model"])
        model = model.to(DEVICE)

        lr = 0.001
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        # The checkpoint holds decoder-only weights (they are loaded into
        # model.decoder above), so its optimizer state was presumably saved
        # by the LM pre-training branch, whose Adam covered only
        # model.decoder.parameters().  Optimizer.load_state_dict raises
        # ValueError when parameter-group sizes differ; in that case keep the
        # freshly initialised full-model optimizer instead of crashing.
        try:
            optimizer.load_state_dict(decoder_checkpoint["model_optim"])
        except ValueError as err:
            print("Decoder optimizer state not loaded (%s); using a fresh optimizer." % err)
        start_train_epoch = 0  # decoder pre-training epochs are not counted
    else:
        lr = 0.0005
        optimizer = optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
        start_train_epoch = 0

    model = model.to(DEVICE)

    # Multiply the learning rate by 0.1 after 100 scheduler steps; the
    # scheduler is stepped once per epoch of this run (see the loop below).
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[100], gamma=.1)

    criterion = masked_ce_loss
    nepochs = start_train_epoch + 60
    save_every = 5

    batch_size = 64
    val_batch_size = 100

    print("nepochs", nepochs)
    print("batch_size", batch_size)
    print("val_batch_size", val_batch_size)
    print("learning rate", lr)

    model_name = 'seq2seq'
    corpus_name = 'clotho'
    params_dict = get_params_dict(model_name, corpus_name, input_dim, vocab_size, embedding_dim, value_size,
                                  pBLSTM_time_reductions, teacher_forcing_ratio, use_gumbel_noise, use_spec_augment,
                                  lr, weight_decay, emb_fpath, freeze_embeddings)

    # --- Data loaders ------------------------------------------------------
    input_field_name = 'features'
    output_field_name = 'words_ind'
    fileid_field_name = 'file_name'

    # Do not load the whole subset in memory.
    load_into_memory = False

    nb_t_steps_pad = 'max'
    has_gt_text = True
    input_pad_at = 'end'
    output_pad_at = 'end'
    num_workers = 0

    # training data loader: shuffled, incomplete last batch dropped
    split = 'clotho_dataset_dev'
    print(" training subset:", split)
    train_loader = get_clotho_loader(data_dir=data_dir_path,
                                     split=split,
                                     input_field_name=input_field_name,
                                     output_field_name=output_field_name,
                                     fileid_field_name=fileid_field_name,
                                     load_into_memory=load_into_memory,
                                     batch_size=batch_size,
                                     nb_t_steps_pad=nb_t_steps_pad,  #: Union[AnyStr, Tuple[int, int]],
                                     has_gt_text=has_gt_text,
                                     shuffle=True,
                                     drop_last=True,
                                     input_pad_at=input_pad_at,
                                     output_pad_at=output_pad_at,
                                     mapping_index_dict=mapping_index_dict,
                                     num_workers=num_workers)

    # validation data loader: fixed order, every example kept
    split = 'clotho_dataset_eva'
    val_loader = get_clotho_loader(data_dir=data_dir_path,
                                   split=split,
                                   input_field_name=input_field_name,
                                   output_field_name=output_field_name,
                                   fileid_field_name=fileid_field_name,
                                   load_into_memory=load_into_memory,
                                   batch_size=val_batch_size,
                                   nb_t_steps_pad=nb_t_steps_pad,  #: Union[AnyStr, Tuple[int, int]],
                                   has_gt_text=has_gt_text,
                                   shuffle=False,
                                   drop_last=False,
                                   input_pad_at=input_pad_at,
                                   output_pad_at=output_pad_at,
                                   mapping_index_dict=mapping_index_dict,
                                   num_workers=num_workers)

    # --- Training loop -----------------------------------------------------
    train_losses, val_losses = [], []
    # Directory for the final loss plot.  It is replaced by the directory
    # returned by save_checkpoint once a checkpoint has been written; the
    # default guarantees the name is bound even if no checkpoint is saved.
    model_dir = save_dir
    print("Begin training...")
    for epoch in range(start_train_epoch, nepochs):

        print("epoch:", epoch)
        print("train subset...")
        train_loss = train(model,
                           train_loader,
                           criterion,
                           optimizer,
                           epoch,
                           pretrain_decoder=do_pretrain_decoder_as_an_lm,
                           use_gumbel_noise=use_gumbel_noise,
                           device=DEVICE)
        train_losses.append(train_loss)

        print("val subset...")
        val_loss = val(model,
                       val_loader,
                       criterion,
                       epoch,
                       pretrain_decoder=do_pretrain_decoder_as_an_lm,
                       use_gumbel_noise=False,
                       print_captions=True,
                       index2word=index2word,
                       device=DEVICE)
        val_losses.append(val_loss)

        # Append-mode log, closed even if a write fails.
        with open('train.log', mode='at') as log_fh:
            log_fh.write("  epoch %d:\ttrain_loss: %.5f\tval_loss: %.5f\n" % (epoch, train_loss, val_loss))
            if epoch == nepochs - 1:
                log_fh.write("val_loss: %.5f\n" % val_loss)

        # Never checkpoint the first epoch of this run; afterwards save every
        # `save_every` epochs (1-based epoch count).
        if epoch > start_train_epoch and (epoch + 1) % save_every == 0:
            model_dir = save_checkpoint(save_dir,
                                        model,
                                        optimizer,
                                        epoch + 1,
                                        train_loss,
                                        val_loss,
                                        params_dict,
                                        do_pretrain_decoder_as_an_lm)
        scheduler.step()

    plt.plot(train_losses, 'k', label='train')
    plt.plot(val_losses, 'b', label='val')
    plt.legend()
    plt.savefig(model_dir + "/loss.png")
Ejemplo n.º 2
0
def main():
    """Decode (and optionally score) Clotho captions with a trained model.

    Command-line arguments (read from ``sys.argv``):
        argv[1]   teacher-forcing ratio (float; 0.98, or 1 when scoring)
        argv[2:]  per-layer time reductions of the pyramidal BLSTM encoder

    The boolean flags at the top select what runs:
    - greedy or beam-search decoding of the evaluation subset, with metrics
      computed against the ground truth;
    - greedy or beam-search decoding of the test subset;
    - plotting of attention masks;
    - per-utterance scoring of captions read from an existing CSV file.

    Prediction CSVs are written into ``model_dir``.  SPIDEr results are
    appended to ``results_decode_val_*.txt``.

    Raises:
        ValueError: if none of the decoding flags is enabled (no model
            would be built).
    """
    weight_decay = 1e-6
    do_load_pretrained_embeddings = False
    freeze_embeddings = False

    # Load pre-trained?
    do_load_checkpoint = True

    do_decode_val = True
    do_decode_val_beamsearch = False

    do_plot_attention_masks_on_val = False
    decode_first_batch_only = False

    do_decode_test = False
    do_decode_test_beamsearch = False

    score_captions = False

    beam_size = 10
    use_lm_bigram = False
    use_lm_trigram = False

    lm_weight = 0.5 if (use_lm_bigram or use_lm_trigram) else 0.

    if do_decode_val_beamsearch or do_decode_test_beamsearch:
        if use_lm_bigram:
            print("beam_size:", beam_size, "LM order: 2", "lm_w:", lm_weight)
        elif use_lm_trigram:
            print("beam_size:", beam_size, "LM order: 3", "lm_w:", lm_weight)
        else:
            print("beam_size:", beam_size, "no LM")

    model_dir = "checkpoints/4367_red_2_2__128_64_0.98_False_False_0.0005_1e-06/"
    checkpoint_pathname = model_dir + '40_1.6245147620307074_2.6875626488429742_checkpoint.tar'
    save_dir = 'checkpoints'
    print("save_dir", save_dir)

    input_dim = 64
    vocab_size = len(WORD_LIST)  # seq2seq at WORD LEVEL
    if do_load_pretrained_embeddings:
        emb_fpath = '/tmpdir/pellegri/corpus/clotho-dataset/lm/word2vec_dev_128.pth'
    else:
        emb_fpath = None

    use_spec_augment = False
    use_gumbel_noise = False

    encoder_hidden_dim = 128
    embedding_dim = 128
    value_size, key_size, query_size = [64] * 3  # these could be different from embedding_dim

    teacher_forcing_ratio = float(sys.argv[1])  # 0.98 or 1 when scoring predictions
    config_pBLSTM_str = sys.argv[2:]
    pBLSTM_time_reductions = [int(s) for s in config_pBLSTM_str]
    print("config pBLSTM", pBLSTM_time_reductions)

    decoder_hidden_size_1 = 128
    decoder_hidden_size_2 = 64
    # [2,2] 0 --> 2887375 params
    # [2,2] 8 --> 2904015 params

    print("use Gumbel noise", use_gumbel_noise)
    print("use teacher forcing", teacher_forcing_ratio)
    print("use SpecAugment", use_spec_augment)

    # --- Model -------------------------------------------------------------
    if do_decode_val or do_decode_test:

        model = Seq2Seq(
            input_dim=input_dim,
            vocab_size=vocab_size,
            encoder_hidden_dim=encoder_hidden_dim,
            use_spec_augment=use_spec_augment,
            embedding_dim=embedding_dim,
            decoder_hidden_size_1=decoder_hidden_size_1,
            decoder_hidden_size_2=decoder_hidden_size_2,
            query_size=query_size,
            value_size=value_size,
            key_size=key_size,
            isAttended=True,
            pBLSTM_time_reductions=pBLSTM_time_reductions,
            emb_fpath=emb_fpath,
            freeze_embeddings=freeze_embeddings,
            teacher_forcing_ratio=teacher_forcing_ratio,
            word2index=word2index,
            return_attention_masks=False,
            device=DEVICE)

    elif do_decode_val_beamsearch or do_decode_test_beamsearch:
        print("Beam decoding w")
        if use_lm_bigram:
            print(" using 2g LM with lm_w=%.3f" % (lm_weight))
        elif use_lm_trigram:
            print(" using 3g LM with lm_w=%.3f" % (lm_weight))
        else:
            print(" not using LM")
        print(" bs=", beam_size)

        model = BeamSeq2Seq(input_dim=input_dim,
                            vocab_size=vocab_size,
                            encoder_hidden_dim=encoder_hidden_dim,
                            use_spec_augment=use_spec_augment,
                            embedding_dim=embedding_dim,
                            decoder_hidden_size_1=decoder_hidden_size_1,
                            decoder_hidden_size_2=decoder_hidden_size_2,
                            query_size=query_size,
                            value_size=value_size,
                            key_size=key_size,
                            isAttended=True,
                            pBLSTM_time_reductions=pBLSTM_time_reductions,
                            teacher_forcing_ratio=teacher_forcing_ratio,
                            beam_size=beam_size,
                            use_lm_bigram=use_lm_bigram,
                            use_lm_trigram=use_lm_trigram,
                            lm_weight=lm_weight,
                            word2index=word2index,
                            index2word=index2word,
                            vocab=WORD_LIST,
                            return_attention_masks=False,
                            device=DEVICE)
    else:
        # Without this, `model` would be unbound and print(model) below
        # would fail with a NameError.
        raise ValueError("no decoding mode enabled: set one of do_decode_val, "
                         "do_decode_test, do_decode_val_beamsearch, "
                         "do_decode_test_beamsearch")

    print(model)

    num_params = count_parameters(model)
    print("num trainable params:", num_params)

    # Epoch counter of the loaded checkpoint; 0 when no checkpoint is loaded
    # (previously left undefined in that case, breaking `nepochs` below).
    start_train_epoch = 0
    if do_load_checkpoint:
        print("Loading checkpoint: ", checkpoint_pathname)
        model_checkpoint = torch.load(checkpoint_pathname, map_location=DEVICE)
        model.load_state_dict(model_checkpoint["model"])
        start_train_epoch = model_checkpoint["iteration"]

    model = model.to(DEVICE)

    criterion = masked_ce_loss
    nepochs = start_train_epoch

    # Beam-search decoding processes one utterance per batch.
    if do_decode_val_beamsearch or do_decode_test_beamsearch:
        val_batch_size = 1
    else:
        val_batch_size = 100

    print("nepochs", nepochs)
    print("batch_size", val_batch_size)

    model_name = 'seq2seq'
    corpus_name = 'clotho'
    lr = 0
    params_dict = get_params_dict(model_name, corpus_name, input_dim,
                                  vocab_size, embedding_dim, value_size,
                                  pBLSTM_time_reductions,
                                  teacher_forcing_ratio, use_gumbel_noise,
                                  use_spec_augment, lr, weight_decay,
                                  emb_fpath, freeze_embeddings)

    # --- Data loaders ------------------------------------------------------
    input_field_name = 'features'
    output_field_name = 'words_ind'
    fileid_field_name = 'file_name'

    load_into_memory = True

    nb_t_steps_pad = 'max'
    has_gt_text = True
    shuffle = False
    drop_last = False
    input_pad_at = 'end'
    output_pad_at = 'end'
    num_workers = 0

    if do_decode_val or do_decode_val_beamsearch:
        split = 'clotho_dataset_eva'
        if score_captions:
            has_gt_text = False
        val_loader = get_clotho_loader(
            data_dir=data_dir_path,
            split=split,
            input_field_name=input_field_name,
            output_field_name=output_field_name,
            fileid_field_name=fileid_field_name,
            load_into_memory=load_into_memory,
            batch_size=val_batch_size,
            nb_t_steps_pad=nb_t_steps_pad,  #: Union[AnyStr, Tuple[int, int]],
            has_gt_text=has_gt_text,
            shuffle=shuffle,
            drop_last=drop_last,
            input_pad_at=input_pad_at,
            output_pad_at=output_pad_at,
            mapping_index_dict=mapping_index_dict,
            num_workers=num_workers)

    if do_decode_test or do_decode_test_beamsearch:
        split = 'clotho_dataset_test'
        has_gt_text = False

        test_loader = get_clotho_loader(
            data_dir=data_dir_path,
            split=split,
            input_field_name=input_field_name,
            output_field_name=output_field_name,
            fileid_field_name=fileid_field_name,
            load_into_memory=load_into_memory,
            batch_size=val_batch_size,
            nb_t_steps_pad=nb_t_steps_pad,  #: Union[AnyStr, Tuple[int, int]],
            has_gt_text=has_gt_text,
            shuffle=shuffle,
            drop_last=drop_last,
            input_pad_at=input_pad_at,
            output_pad_at=output_pad_at,
            mapping_index_dict=mapping_index_dict,
            num_workers=num_workers)

    # --- Decoding ----------------------------------------------------------
    if do_decode_val:
        print("decoding val subset GREEEDY SEARCH...")
        result_fpath = 'results_decode_val_greedy.txt'

        if do_plot_attention_masks_on_val:
            att_masks, first_batch_text, first_batch_preds_char = decode_val(
                model,
                val_loader,
                criterion,
                index2word,
                word2index,
                decode_first_batch_only=decode_first_batch_only,
                use_gumbel_noise=False,
                plot_att=True,
                device=DEVICE)
            is_already_text = True
            plot_att_masks_to_png_files(att_masks, first_batch_preds_char,
                                        is_already_text, index2word,
                                        word2index, save_dir, model_dir,
                                        params_dict)
        elif not score_captions:
            captions_pred, captions_gt_indices, all_ids_str = decode_val(
                model,
                val_loader,
                criterion,
                index2word,
                word2index,
                decode_first_batch_only=decode_first_batch_only,
                use_gumbel_noise=False,
                plot_att=False,
                device=DEVICE)

            captions_gt = index2words(captions_gt_indices, index2word)
            # Keep one prediction per audio file (the slicing assumes the
            # loader yields five consecutive entries per file).
            captions_pred_every_five = captions_pred[::5]
            all_ids_str_every_five = all_ids_str[::5]

            print("captions_gt_indices", len(captions_gt_indices))
            print("captions_pred", len(captions_pred))

            print("captions_gt", len(captions_gt))
            print("captions_pred_every_five", len(captions_pred_every_five))
            print("file ids every_five", len(all_ids_str_every_five))

            out_csv_fpath = model_dir + "/val_predicted_captions_greedy_NEW.csv"
            write_csv_prediction_file(captions_pred_every_five,
                                      all_ids_str_every_five,
                                      out_csv_fpath)

            metrics = evaluate_metrics_from_lists(captions_pred_every_five,
                                                  captions_gt)

            average_metrics = metrics[0]
            print("\n")
            for m in average_metrics.keys():
                print("%s\t%.3f" % (m, average_metrics[m]))

            # Opened only here (append mode) so the handle is always closed.
            with open(result_fpath, "at") as result_fh:
                result_fh.write("%s\t%.3f\n" %
                                ('SPIDEr', average_metrics['SPIDEr']))
                result_fh.write(
                    "%s,%.3f,%s,%s\n" %
                    ("_".join(config_pBLSTM_str), average_metrics['SPIDEr'],
                     checkpoint_pathname, emb_fpath))
        else:
            pred_fpath = 'checkpoints/seq2seq/clotho/best_model/4367_red_2_2__128_64_0.98_False_False_0.0005_1e-06//val_predicted_captions_beamsearch_nolm_bs25_alpha_12.csv'

            wav_id_list, captions_dict_pred = read_csv_prediction_file(pred_fpath)
            print(wav_id_list[0], captions_dict_pred[wav_id_list[0]])

            criterion = masked_ce_loss_per_utt
            test_losses, all_ids_str = score_test_captions(
                model,
                criterion,
                val_loader,
                captions_dict_pred,
                index2word,
                word2index,
                use_gumbel_noise=False,
                device=DEVICE)
            csv_out_fpath = 'checkpoints/seq2seq/clotho/best_model/4367_red_2_2__128_64_0.98_False_False_0.0005_1e-06//val_predicted_captions_beamsearch_nolm_bs25_alpha_12_scores.csv'

            with open(csv_out_fpath, "wt") as fh:
                for wav_id, loss in zip(all_ids_str, test_losses):
                    fh.write("%s,%f\n" % (wav_id, loss))

    elif do_decode_val_beamsearch:

        print("decoding val subset BEAM SEARCH...")

        result_fpath = 'results_decode_val_beamsearch.txt'
        captions_pred, captions_gt_indices, all_ids_str = bs_decode_val(
            model,
            val_loader,
            index2word,
            use_gumbel_noise=use_gumbel_noise,
            device=DEVICE)

        captions_pred_every_five = captions_pred[::5]
        all_ids_str_every_five = all_ids_str[::5]

        captions_gt = index2words(captions_gt_indices, index2word)

        print("captions_gt_indices", len(captions_gt_indices))
        print("captions_gt", len(captions_gt))
        print("captions_pred", len(captions_pred))

        print("captions_pred_every_five", len(captions_pred_every_five))
        print("file ids every_five", len(all_ids_str_every_five))

        print("\n")
        if use_lm_bigram:
            out_csv_fpath = model_dir + "/val_predicted_captions_beamsearch_lm_%.2f_2g.csv" % lm_weight
        elif use_lm_trigram:
            out_csv_fpath = model_dir + "/val_predicted_captions_beamsearch_lm_%.2f_3g.csv" % lm_weight
        else:
            out_csv_fpath = model_dir + "/val_predicted_captions_beamsearch_nolm_bs%d_alpha_12.csv" % (
                beam_size)

        write_csv_prediction_file(captions_pred_every_five,
                                  all_ids_str_every_five, out_csv_fpath)

        if not decode_first_batch_only:

            metrics = evaluate_metrics_from_lists(captions_pred_every_five,
                                                  captions_gt)
            print("\n")
            average_metrics = metrics[0]
            for m in average_metrics.keys():
                print("%.3f" % (average_metrics[m]))

            with open(result_fpath, "at") as result_fh:
                # NOTE(review): n_attn_heads is not defined in this function;
                # it must come from module scope, otherwise this raises
                # NameError — confirm.
                result_fh.write(
                    "%s,%d,%.3f,%s,%.2f\n" %
                    ("_".join(config_pBLSTM_str), n_attn_heads,
                     average_metrics['SPIDEr'], checkpoint_pathname, lm_weight))

    elif do_decode_test:
        if not score_captions:
            print("decoding test subset (greedy)...")
            captions_pred, all_ids_str = decode_test(model,
                                                     test_loader,
                                                     index2word,
                                                     use_gumbel_noise=False,
                                                     device=DEVICE)

            print("captions_pred", len(captions_pred))
            out_csv_fpath = model_dir + "/test_predicted_captions_greedy.csv"
            write_csv_prediction_file(captions_pred, all_ids_str,
                                      out_csv_fpath)
        else:
            pred_fpath = '../dcase2020_challenge_submission_task6_thomas_pellegrini/Pellegrini_IRIT_task6_3/test_predicted_captions_beamsearch_lm_0.50_2g.csv'

            wav_id_list, captions_dict_pred = read_csv_prediction_file(pred_fpath)
            print(wav_id_list[0], captions_dict_pred[wav_id_list[0]])

            criterion = masked_ce_loss_per_utt
            test_losses, all_ids_str = score_test_captions(
                model,
                criterion,
                test_loader,
                captions_dict_pred,
                index2word,
                word2index,
                use_gumbel_noise=False,
                device=DEVICE)
            csv_out_fpath = '../dcase2020_challenge_submission_task6_thomas_pellegrini/Pellegrini_IRIT_task6_3/scores_per_utt_sub3.csv'

            with open(csv_out_fpath, "wt") as fh:
                for wav_id, loss in zip(all_ids_str, test_losses):
                    fh.write("%s,%f\n" % (wav_id, loss))

    elif do_decode_test_beamsearch:

        captions_pred, all_ids_str = bs_decode_test(model,
                                                    test_loader,
                                                    index2word,
                                                    use_gumbel_noise=False,
                                                    device=DEVICE)
        print("test captions_pred", len(captions_pred))

        print("\n")
        if use_lm_bigram:
            out_csv_fpath = model_dir + "/test_predicted_captions_beamsearch_lm_%.2f_2g.csv" % lm_weight
        elif use_lm_trigram:
            out_csv_fpath = model_dir + "/test_predicted_captions_beamsearch_lm_%.2f_3g.csv" % lm_weight
        else:
            out_csv_fpath = model_dir + "/test_predicted_captions_beamsearch_nolm_bs25_alpha12.csv"

        write_csv_prediction_file(captions_pred, all_ids_str, out_csv_fpath)
Ejemplo n.º 3
0
def _ddp_mixture_test_metrics(model, n_components, test_loader,
                              num_test_sample):
    """Score a (DDP-wrapped) 'Sto' mixture model on this rank's test data.

    Every test input is evaluated ``num_test_sample`` times for each of the
    ``n_components`` component indices. The predictive NLL of a sample is the
    negative log-mean-exp of its per-draw log-likelihoods. The prediction is
    the argmax of the mean of ``exp(output)`` over all draws.

    Args:
        model: model called as ``model(x, num_test_sample, indices=...)``;
            its output is used as Categorical logits.
        n_components: number of component indices to iterate over.
        test_loader: this rank's test loader.
        num_test_sample: forward samples per component index.

    Returns:
        ``(acc, nll, nll_miss)`` averaged over the samples this rank actually
        evaluated. This is not ``len(test_loader.dataset)``, which is the full
        dataset size and over-counts when the loader only covers a shard.
        ``nll_miss`` (mean NLL over misclassified samples) is NaN when there
        are none. All three values are NaN for an empty loader.
    """
    n_draws = num_test_sample * n_components
    tnll = 0.
    nll_miss = 0.
    n_correct = 0
    n_seen = 0
    with torch.no_grad():
        for bx, by in test_loader:
            bx = bx.cuda(non_blocking=True)
            by = by.cuda(non_blocking=True)
            # One buffer, refilled with each component index in turn.
            indices = torch.empty(bx.size(0) * num_test_sample,
                                  dtype=torch.long,
                                  device=bx.device)
            prob = torch.cat([
                model(bx,
                      num_test_sample,
                      indices=torch.full(
                          (bx.size(0) * num_test_sample, ),
                          idx,
                          out=indices,
                          device=bx.device,
                          dtype=torch.long))
                for idx in range(n_components)
            ],
                             dim=1)
            y_target = by.unsqueeze(1).expand(-1, n_draws)
            bnll = D.Categorical(logits=prob).log_prob(y_target)
            # log-mean-exp over all draws = log of the averaged likelihood.
            bnll = torch.logsumexp(bnll, dim=1) - torch.log(
                torch.tensor(n_draws,
                             dtype=torch.float32,
                             device=bnll.device))
            tnll -= bnll.sum().item()
            pred = prob.exp().mean(dim=1).argmax(dim=1)
            y_miss = pred != by
            if y_miss.sum().item() > 0:
                nll_miss -= bnll[y_miss].sum().item()
            n_correct += (pred == by).sum().item()
            n_seen += by.size(0)
    nan = float('nan')
    if n_seen == 0:
        return nan, nan, nan
    n_wrong = n_seen - n_correct
    return (n_correct / n_seen, tnll / n_seen,
            nll_miss / n_wrong if n_wrong else nan)


def train(gpu, args, queue):
    """Distributed worker: train and evaluate a stochastic ('Sto*') model.

    The ``(gpu, args, queue)`` signature suggests it is launched once per
    local GPU, e.g. via ``torch.multiprocessing.spawn``. Confirm with the
    caller.

    The global rank is ``args.nr * args.gpus + gpu``. The NCCL process group
    is initialised from environment variables. For models whose name starts
    with 'Sto', the variational loss
    ``loglike + klw * kl / (n_batch * args.total_batch_size)`` is minimised
    for ``args.num_epochs`` epochs. Rank 0 saves
    ``model.module.state_dict()`` to ``<args.root>/checkpoint.pt`` every
    ``args.test_freq`` epochs and once more after training. Each rank then
    logs accuracy, NLL and NLL over misclassified samples for the test data
    it evaluated. Other model names skip training and only log ``END_MSG``.

    Args:
        gpu: local GPU index on this node.
        args: run configuration. Reads ``nr``, ``gpus``, ``world_size``,
            ``seed``, ``model``, ``root``, ``num_epochs``, ``num_sample``
            (dict with 'train'/'test'), ``total_batch_size``,
            ``logging_freq`` and ``test_freq``.
        queue: handed to ``worker_configurer``. Presumably the logging queue
            of the parent process (verify).
    """
    worker_configurer(queue)
    rank = args.nr * args.gpus + gpu
    dist.init_process_group(backend='nccl',
                            init_method='env://',
                            world_size=args.world_size,
                            rank=rank)
    logger = logging.getLogger(f'worker-{rank}')
    train_loader, test_loader, train_sampler, test_sampler = get_dataloader(
        args, rank)
    if rank == 0:
        logger.info(
            f"Train size: {len(train_loader.dataset)}, test size: {len(test_loader.dataset)}"
        )
    n_batch = len(train_loader)
    # Per-rank batch counts; the rank-0 line above reports dataset sizes.
    logger.info(
        f"Train batches: {len(train_loader)}, test batches: {len(test_loader)}"
    )
    torch.manual_seed(args.seed + rank * 123)  # distinct seed per rank
    torch.cuda.set_device(gpu)
    model, optimizer, scheduler = get_model(args, gpu)
    if rank == 0:
        count_parameters(model, logger)
        logger.info(str(model))
    checkpoint_path = os.path.join(args.root, 'checkpoint.pt')
    if args.model.startswith('Sto'):
        model.train()
        for i in range(args.num_epochs):
            # Presumably a DistributedSampler; set_epoch varies its shuffle.
            train_sampler.set_epoch(i)
            for bx, by in train_loader:
                bx = bx.cuda(non_blocking=True)
                by = by.cuda(non_blocking=True)
                loglike, kl = vb_loss(model, bx, by, args.num_sample['train'])
                klw = get_kl_weight(i, args)
                loss = loglike + klw * kl / (n_batch * args.total_batch_size)
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            scheduler.step()
            if (i + 1) % args.logging_freq == 0:
                logger.info(
                    "VB Epoch %d: loglike: %.4f, kl: %.4f, kl weight: %.4f, lr1: %.4f, lr2: %.4f",
                    i, loglike.item(), kl.item(), klw,
                    optimizer.param_groups[0]['lr'],
                    optimizer.param_groups[1]['lr'])
            if (i + 1) % args.test_freq == 0:
                if rank == 0:
                    torch.save(model.module.state_dict(), checkpoint_path)
                    logger.info('Save checkpoint')
                with torch.no_grad():
                    nll, acc = test_nll(model, test_loader,
                                        args.num_sample['test'])
                logger.info("VB Epoch %d: test NLL %.4f, acc %.4f", i, nll,
                            acc)
                model.train()
        if rank == 0:
            torch.save(model.module.state_dict(), checkpoint_path)
            logger.info('Save checkpoint')
        model.eval()
        acc, tnll, nll_miss = _ddp_mixture_test_metrics(
            model, model.module.n_components, test_loader,
            args.num_sample['test'])
        logger.info("Test data: acc %.4f, nll %.4f, nll miss %.4f" %
                    (acc, tnll, nll_miss))
    logger.info(END_MSG)
Ejemplo n.º 4
0
                lossListnbTr.append(pctg_correctTr.item())
                # print("     \t  pctg of wrong train samples  : %.4f"%(pctg_correctTr))

                pctg_correctEv = percentageCorrect(
                    model(test_input, is_vis=True), test_target)
                lossListnbEv.append(pctg_correctEv.item())
                # print("     \t  pctg of wrong test samples   : %.4f"%(pctg_correctEv))

        allLoss_CETr[name].append(lossListCETr)
        allLoss_nbTr[name].append(lossListnbTr)
        allLoss_nbEv[name].append(lossListnbEv)

    print(
        "\nfor {%s} with %d parameters:\n "
        "--average percentage of correct test samples    : %.4f ( std=%.4f )\n\n"
        % (name, count_parameters(model), torch.tensor(
            allLoss_nbEv[name])[:, -1].mean(), torch.tensor(
                allLoss_nbEv[name])[:, -1].std()))

    avgInAllRnd[name] = torch.tensor(allLoss_nbEv[name]).mean(dim=0)

    # plt.figure()
    # plt.plot(torch.tensor(allLoss_nbEv[name]).t())
    # plt.title(name)

    # plt.figure()
    # plt.plot(avgInAllRnd["benchmark_fc"]);
    # plt.plot(avgInAllRnd["benchmark_cnn"]);
    # plt.plot(avgInAllRnd['parallelConv'])
    # plt.plot(avgInAllRnd['siamese'])
    # plt.legend(trainList)
Ejemplo n.º 5
0
    dataset = 'imdb'
    lang, lines = cachePrepareData(dataset)

    model_filename = ''.join([
        './pretrained/pretrained_lstm_', dataset, '_',
        str(hidden_size), '_',
        str(train_iters), '.pt'
    ])

    # w2v_model = Word2Vec.load(''.join(["word2vec_", str(hidden_size), ".model"]))
    w2v_vectors = get_embeddings(lang, lines, hidden_size)
    w2v_vectors = torch.from_numpy(w2v_vectors).float()
    lstm = pretrainLSTM(lang.n_words, hidden_size).to(device)
    print('lstm initialized')
    print("Total number of trainable parameters without word2vec:",
          count_parameters(lstm))

    def copy_embedding(layer, vectors):
        return nn.Embedding(layer.num_embeddings,
                            layer.embedding_dim).from_pretrained(vectors)

    lstm.embedding = copy_embedding(lstm.embedding, w2v_vectors)
    print("Total number of trainable parameters with word2vec:",
          count_parameters(lstm))

    print('using hidden_size=' + str(hidden_size), ' train_iters = ',
          train_iters)
    trainIters(lstm,
               lang,
               lines,
               train_iters,
Ejemplo n.º 6
0
                                       (train_target.narrow(0, batch, batchSize),
                                       train_class.narrow(0, batch, batchSize)[:,0],
                                       train_class.narrow(0, batch, batchSize)[:,1])
                                       )
                     
                     optimizer.zero_grad()
                     loss.backward()
                     optimizer.step()
                     
                     lossListCETr.append(loss.detach())
                     # print("------ batch %6d, CrossEntropy loss : %.4f"%(batch, loss.detach()))
                 
                     pctg_correctTr = percentageCorrect( model(train_input, is_vis=True), train_target )
                     lossListnbTr.append(pctg_correctTr.item())
                     # print("     \t  pctg of wrong train samples  : %.4f"%(pctg_correctTr))
                     
                     pctg_correctEv = percentageCorrect( model(test_input, is_vis=True), test_target )
                     lossListnbEv.append(pctg_correctEv.item())    
                     # print("     \t  pctg of wrong test samples   : %.4f"%(pctg_correctEv))
                     
             allLoss_CETr[name][str(lr)+'_'+str(batchSize)].append(lossListCETr)
             allLoss_nbTr[name][str(lr)+'_'+str(batchSize)].append(lossListnbTr)
             allLoss_nbEv[name][str(lr)+'_'+str(batchSize)].append(lossListnbEv)
             
         avgInAllRnd[name][str(lr)+'_'+str(batchSize)] = torch.tensor(allLoss_nbEv[name][str(lr)+'_'+str(batchSize)]).mean(dim=0)
         
         print("\nfor {%s} with %d parameters, lr = %.e, batchsize = %d:\n " 
               "--average percentage of correct test samples (final epoch)    : %.4f \n\n"
               %(name, count_parameters(model), lr, batchSize,
                 avgInAllRnd[name][str(lr)+'_'+str(batchSize)][-1]))
 
Ejemplo n.º 7
0
def _finalize_test_metrics(tnll, n_correct, nll_miss, n_total):
    """Normalise accumulated test sums into ``(acc, nll, nll_miss)``.

    Args:
        tnll: summed NLL over all test samples.
        n_correct: number of correctly classified samples.
        nll_miss: summed NLL over the misclassified samples.
        n_total: number of test samples.

    A mean over an empty set is returned as NaN instead of raising
    ``ZeroDivisionError``. In particular, ``nll_miss`` is NaN when every
    sample is classified correctly.
    """
    nan = float('nan')
    n_wrong = n_total - n_correct
    acc = n_correct / n_total if n_total else nan
    tnll = tnll / n_total if n_total else nan
    nll_miss = nll_miss / n_wrong if n_wrong else nan
    return acc, tnll, nll_miss


def _vb_test_metrics(model, test_loader, num_test_sample, device, mixture):
    """Final test pass for the sampled ('Sto' / 'Bayesian') models.

    Each input gets ``num_test_sample`` forward samples. With ``mixture``
    ('Sto'), every component index in ``range(model.n_components)`` is also
    evaluated and the results are concatenated along dim 1. The predictive
    NLL of a sample is the negative log-mean-exp of its per-draw
    log-likelihoods. The prediction is the argmax of the mean of
    ``exp(output)`` over all draws.

    Switches ``model`` to eval mode. Returns ``(acc, nll, nll_miss)`` as
    produced by ``_finalize_test_metrics``.
    """
    tnll = 0
    acc = 0
    nll_miss = 0
    model.eval()
    with torch.no_grad():
        for bx, by in test_loader:
            bx = bx.to(device)
            by = by.to(device)
            if mixture:
                n_draws = num_test_sample * model.n_components
                # One buffer, refilled with each component index in turn.
                indices = torch.empty(bx.size(0) * num_test_sample,
                                      dtype=torch.long,
                                      device=bx.device)
                prob = torch.cat([
                    model.forward(bx,
                                  num_test_sample,
                                  indices=torch.full(
                                      (bx.size(0) * num_test_sample, ),
                                      idx,
                                      out=indices,
                                      device=bx.device,
                                      dtype=torch.long))
                    for idx in range(model.n_components)
                ],
                                 dim=1)
            else:
                n_draws = num_test_sample
                prob = model.forward(bx, num_test_sample)
            y_target = by.unsqueeze(1).expand(-1, n_draws)
            bnll = D.Categorical(logits=prob).log_prob(y_target)
            # log-mean-exp over all draws = log of the averaged likelihood.
            bnll = torch.logsumexp(bnll, dim=1) - torch.log(
                torch.tensor(n_draws,
                             dtype=torch.float32,
                             device=bnll.device))
            tnll -= bnll.sum().item()
            pred = prob.exp().mean(dim=1).argmax(dim=1)
            y_miss = pred != by
            if y_miss.sum().item() > 0:
                nll_miss -= bnll[y_miss].sum().item()
            acc += (pred == by).sum().item()
    return _finalize_test_metrics(tnll, acc, nll_miss,
                                  len(test_loader.dataset))


def _det_loader_nll_acc(model, loader, device):
    """Return ``(mean NLL, accuracy)`` of a deterministic model over ``loader``.

    The model output is scored with ``nll_loss``, so it is presumably a
    log-probability tensor. Means are taken over ``len(loader.dataset)``.
    Does not change train/eval mode; the caller handles that.
    """
    nll = 0
    acc = 0
    with torch.no_grad():
        for bx, by in loader:
            bx = bx.to(device)
            by = by.to(device)
            pred = model(bx)
            nll += torch.nn.functional.nll_loss(pred, by).item() * len(by)
            acc += (pred.argmax(1) == by).sum().item()
    return nll / len(loader.dataset), acc / len(loader.dataset)


def _det_test_metrics(model, test_loader, device):
    """Final test pass for the deterministic ('Det') model.

    Switches ``model`` to eval mode. Returns ``(acc, nll, nll_miss)`` as
    produced by ``_finalize_test_metrics``.
    """
    tnll = 0
    acc = 0
    nll_miss = 0
    model.eval()
    with torch.no_grad():
        for bx, by in test_loader:
            bx = bx.to(device)
            by = by.to(device)
            prob = model(bx)
            pred = prob.argmax(dim=1)
            tnll += torch.nn.functional.nll_loss(prob, by).item() * len(by)
            y_miss = pred != by
            if y_miss.sum().item() > 0:
                by_miss = by[y_miss]
                nll_miss += torch.nn.functional.nll_loss(
                    prob[y_miss], by_miss).item() * len(by_miss)
            acc += (pred == by).sum().item()
    return _finalize_test_metrics(tnll, acc, nll_miss,
                                  len(test_loader.dataset))


def _vb_periodic_eval(model, valid_loader, test_loader, validation,
                      best_nll, checkpoint_path, logger, epoch):
    """Periodic evaluation step shared by the 'Sto' and 'Bayesian' branches.

    With ``validation``, the checkpoint is overwritten whenever the
    validation NLL ties or improves on ``best_nll``. Without it, the
    checkpoint is overwritten unconditionally. Test NLL and accuracy are then
    logged, and the model is put back into train mode.

    NOTE(review): unlike the 'Det' branch, ``model.eval()`` is not called
    here. Presumably ``test_nll`` switches modes itself; confirm.

    Returns the (possibly updated) best validation NLL.
    """
    if validation:
        with torch.no_grad():
            nll, acc = test_nll(model, valid_loader)
        if best_nll >= nll:
            best_nll = nll
            torch.save(model.state_dict(), checkpoint_path)
            logger.info('Save checkpoint')
        ex.log_scalar('nll.valid', nll, epoch)
        ex.log_scalar('acc.valid', acc, epoch)
        logger.info("VB Epoch %d: validation NLL %.4f, acc %.4f", epoch, nll,
                    acc)
    else:
        torch.save(model.state_dict(), checkpoint_path)
        logger.info('Save checkpoint')
    with torch.no_grad():
        nll, acc = test_nll(model, test_loader)
    ex.log_scalar('nll.test', nll, epoch)
    ex.log_scalar('acc.test', acc, epoch)
    logger.info("VB Epoch %d: test NLL %.4f, acc %.4f", epoch, nll, acc)
    model.train()
    return best_nll


def _save_checkpoint_for_test(model, checkpoint_path, validation, logger):
    """Make sure ``checkpoint_path`` holds the weights the final test should use.

    Without validation there is no model selection, so the end-of-training
    weights are saved. Otherwise the last periodic checkpoint would be
    reloaded, silently dropping later epochs, or, if no periodic step ran,
    the reload would fail. With validation, the best checkpoint is kept. If
    none was ever written, the final weights are saved instead.
    """
    if validation and os.path.exists(checkpoint_path):
        return
    torch.save(model.state_dict(), checkpoint_path)
    logger.info('Save checkpoint')


def main(_run, model_name, num_train_sample, num_test_sample, device,
         validation, validate_freq, num_epochs, logging_freq):
    """Train the model selected by ``model_name`` and report test metrics.

    The ``model_name`` prefix picks one of three branches:

    * 'Sto': trained with ``model.vb_loss``. Tested by evaluating every
      component index in ``range(model.n_components)``.
    * 'Bayesian': trained with ``model.vb_loss``. Tested with
      ``num_test_sample`` forward samples.
    * 'Det': trained with ``nll_loss``.

    Every ``validate_freq`` epochs the model is evaluated and checkpointed to
    ``BASE_DIR/<run id>/checkpoint.pt``. With ``validation``, only
    improvements in validation NLL are saved. The checkpoint is reloaded
    before the final test pass. It holds the best-validation weights when
    ``validation`` is on, and the end-of-training weights otherwise. Metrics
    go to ``logger`` and ``ex.log_scalar``. Other model names only build the
    model.
    """
    logger = get_logger()
    if validation:
        train_loader, valid_loader, test_loader = get_dataloader()
        logger.info(
            f"Train size: {len(train_loader.dataset)}, validation size: {len(valid_loader.dataset)}, test size: {len(test_loader.dataset)}"
        )
    else:
        train_loader, test_loader = get_dataloader()
        valid_loader = None  # only read when validation is on
        logger.info(
            f"Train size: {len(train_loader.dataset)}, test size: {len(test_loader.dataset)}"
        )
    n_batch = len(train_loader)
    model, optimizer, scheduler = get_model()
    count_parameters(model, logger)
    logger.info(str(model))
    checkpoint_path = os.path.join(BASE_DIR, _run._id, 'checkpoint.pt')
    if model_name.startswith('Sto'):
        model.train()
        best_nll = float('inf')
        for i in range(num_epochs):
            for bx, by in train_loader:
                bx = bx.to(device)
                by = by.to(device)
                optimizer.zero_grad()
                loglike, kl = model.vb_loss(bx, by, num_train_sample)
                klw = get_kl_weight(epoch=i)
                # KL spread over roughly n_batch * batch_size training samples.
                loss = loglike + klw * kl / (n_batch * bx.size(0))
                loss.backward()
                optimizer.step()
                ex.log_scalar('loglike.train', loglike.item(), i)
                ex.log_scalar('kl.train', kl.item(), i)
            scheduler.step()
            if (i + 1) % logging_freq == 0:
                logger.info(
                    "VB Epoch %d: loglike: %.4f, kl: %.4f, kl weight: %.4f, lr1: %.4f, lr2: %.4f",
                    i, loglike.item(), kl.item(), klw,
                    optimizer.param_groups[0]['lr'],
                    optimizer.param_groups[1]['lr'])
            if (i + 1) % validate_freq == 0:
                best_nll = _vb_periodic_eval(model, valid_loader, test_loader,
                                             validation, best_nll,
                                             checkpoint_path, logger, i)
        _save_checkpoint_for_test(model, checkpoint_path, validation, logger)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        acc, tnll, nll_miss = _vb_test_metrics(model, test_loader,
                                               num_test_sample, device,
                                               mixture=True)
        logger.info("Test data: acc %.4f, nll %.4f, nll miss %.4f", acc, tnll,
                    nll_miss)
    elif model_name.startswith('Bayesian'):
        model.train()
        best_nll = float('inf')
        for i in range(num_epochs):
            for bx, by in train_loader:
                bx = bx.to(device)
                by = by.to(device)
                optimizer.zero_grad()
                loglike, kl = model.vb_loss(bx, by, num_train_sample)
                klw = get_kl_weight(epoch=i)
                # KL spread over roughly n_batch * batch_size training samples.
                loss = loglike + klw * kl / (n_batch * bx.size(0))
                loss.backward()
                optimizer.step()
                ex.log_scalar('loglike.train', loglike.item(), i)
                ex.log_scalar('kl.train', kl.item(), i)
            scheduler.step()
            if (i + 1) % logging_freq == 0:
                logger.info(
                    "VB Epoch %d: loglike: %.4f, kl: %.4f, kl weight: %.4f, lr: %.4f",
                    i, loglike.item(), kl.item(), klw,
                    optimizer.param_groups[0]['lr'])
            if (i + 1) % validate_freq == 0:
                best_nll = _vb_periodic_eval(model, valid_loader, test_loader,
                                             validation, best_nll,
                                             checkpoint_path, logger, i)
        _save_checkpoint_for_test(model, checkpoint_path, validation, logger)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        acc, tnll, nll_miss = _vb_test_metrics(model, test_loader,
                                               num_test_sample, device,
                                               mixture=False)
        logger.info("Test data: acc %.4f, nll %.4f, nll miss %.4f", acc, tnll,
                    nll_miss)
    elif model_name.startswith('Det'):
        model.train()
        best_nll = float('inf')
        for i in range(num_epochs):
            for bx, by in train_loader:
                optimizer.zero_grad()
                bx = bx.to(device)
                by = by.to(device)
                pred = model(bx)
                loss = torch.nn.functional.nll_loss(pred, by)
                loss.backward()
                optimizer.step()
                ex.log_scalar("nll.train", loss.item(), i)
            scheduler.step()
            if (i + 1) % logging_freq == 0:
                logger.info("Epoch %d: train %.4f, lr %.4f", i, loss.item(),
                            optimizer.param_groups[0]['lr'])
            if (i + 1) % validate_freq == 0:
                model.eval()
                if validation:
                    nll, acc = _det_loader_nll_acc(model, valid_loader,
                                                   device)
                    if best_nll >= nll:
                        best_nll = nll
                        torch.save(model.state_dict(), checkpoint_path)
                        logger.info('Save checkpoint')
                    ex.log_scalar('nll.valid', nll, i)
                    ex.log_scalar('acc.valid', acc, i)
                    logger.info("Epoch %d: validation %.4f, %.4f", i, nll, acc)
                else:
                    torch.save(model.state_dict(), checkpoint_path)
                    logger.info('Save checkpoint')
                nll, acc = _det_loader_nll_acc(model, test_loader, device)
                ex.log_scalar('nll.test', nll, i)
                ex.log_scalar('acc.test', acc, i)
                logger.info("Epoch %d: test %.4f, acc %.4f", i, nll, acc)
                model.train()
        _save_checkpoint_for_test(model, checkpoint_path, validation, logger)
        model.load_state_dict(torch.load(checkpoint_path, map_location=device))
        acc, tnll, nll_miss = _det_test_metrics(model, test_loader, device)
        logger.info("Test data: acc %.4f, nll %.4f, nll miss %.4f", acc, tnll,
                    nll_miss)