def main(hps):
    
    #pre_train
    if args.pre_train:
        word2idx, idx2word, verb2idx, idx2verb = make_vocab(hps)
#         word2idx, idx2word, verb2idx, idx2verb = load_vocab(hps)
#         mapping, vectors = load_glove(hps)
#         weights_matrix = make_pre_trained_word_embedding(mapping, vectors, word2idx.keys(), hps)
        hps = hps._replace(vocab_size=len(word2idx))
        hps = hps._replace(verb_vocab_size=len(verb2idx))
        hps = hps._replace(pre_train=True)
        
        print('parameters:')
        print(hps)
        
        train_loader, valid_loader, char_weights, action_weights = load_train_data(word2idx, verb2idx, hps)
        model = Model(char_weights, action_weights, hps)
#         model.load_state_dict(torch.load(hps.test_path + 'models/pre_train.model'))
        model = model.cuda()
        optimizer = optim.Adam(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay)
        print('pre_training', flush=True)
        pre_train(model, optimizer, train_loader, valid_loader, idx2word, hps)
        
    #train
    elif args.train:
#         word2idx, idx2word = make_vocab(hps)
        word2idx, idx2word = load_vocab(hps)
        hps = hps._replace(vocab_size=len(word2idx))
        
        print('parameters:')
        print(hps)
        
        train_loader, valid_loader, char_weights, action_weights = load_train_data(word2idx, hps)
        model = Model(char_weights, action_weights, hps)
#         model.load_state_dict(torch.load(hps.test_path + 'models/best.model'))
        if args.reload:
            model.load_state_dict(torch.load(hps.test_path + hps.save_path.format(args.reload_epoch)))

        model.cuda()
#         model = nn.DataParallel(model,device_ids=[0])
        optimizer = optim.Adam(model.parameters(), lr=hps.lr, weight_decay=hps.weight_decay)
        print('training', flush=True)
        train(model, optimizer, train_loader, valid_loader, idx2word, hps)
    
    #test
    if args.test:
        print('testing', flush=True)
        word2idx, idx2word = load_vocab(hps)
        hps = hps._replace(vocab_size=len(word2idx))
        hps = hps._replace(test=True)
        # dummy weight vectors used only to construct the model in test mode
        model = Model([0] * hps.max_num_char, [0] * hps.vocab_size, hps)
        model.load_state_dict(torch.load(hps.test_path + hps.save_path))
        model.cuda()
        test_loader, anony2names = load_test_data(word2idx, hps)
        test(model, test_loader, idx2word, anony2names, hps)
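
The `hps._replace(...)` calls above imply that `hps` is a `collections.namedtuple` of hyperparameters: namedtuples are immutable, and `_replace` returns a copy with the given fields updated. A minimal sketch of the pattern (the field names are assumptions, taken from the code above):

from collections import namedtuple

# Hypothetical hyperparameter container; only fields referenced above are shown.
HParams = namedtuple('HParams', [
    'vocab_size', 'verb_vocab_size', 'pre_train', 'test',
    'lr', 'weight_decay', 'test_path', 'save_path', 'max_num_char',
])

hps = HParams(vocab_size=0, verb_vocab_size=0, pre_train=False, test=False,
              lr=1e-3, weight_decay=1e-5, test_path='out/',
              save_path='models/epoch_{}.model', max_num_char=10)

# _replace returns a new namedtuple; the original is left untouched.
hps = hps._replace(vocab_size=50000, pre_train=True)
print(hps.vocab_size, hps.pre_train)  # 50000 True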
Example #2
    def test_make_vocab(self):
        tokens = [
            'Rock', 'n', 'Roll', 'is', 'a', 'risk', '.', 'You', 'risk',
            'being', 'ridiculed', '.'
        ]
        token_to_index, index_to_token = make_vocab(tokens, 1, 10)

        self.assertEqual(token_to_index['<pad>'], 0)
        self.assertEqual(token_to_index['<unk>'], 1)
        self.assertEqual(token_to_index['<s>'], 2)
        self.assertEqual(token_to_index['</s>'], 3)
        self.assertEqual(len(token_to_index), 10)
        self.assertEqual(len(index_to_token), 10)
        self.assertEqual(index_to_token[0], '<pad>')
        self.assertEqual(index_to_token[1], '<unk>')
        self.assertEqual(index_to_token[2], '<s>')
        self.assertEqual(index_to_token[3], '</s>')
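
The test pins down the `make_vocab(tokens, min_freq, max_size)` contract: four reserved special tokens at indices 0-3, then the remaining slots filled by corpus tokens until the size cap is hit. A minimal sketch that satisfies these assertions (the real implementation may differ, e.g. in tie-breaking):

from collections import Counter

def make_vocab(tokens, min_freq, max_size):
    # Reserved tokens always occupy the first four indices.
    token_to_index = {'<pad>': 0, '<unk>': 1, '<s>': 2, '</s>': 3}
    for token, freq in Counter(tokens).most_common():
        if len(token_to_index) >= max_size:
            break
        if freq >= min_freq and token not in token_to_index:
            token_to_index[token] = len(token_to_index)
    index_to_token = {i: t for t, i in token_to_index.items()}
    return token_to_index, index_to_token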
Example #3
def main(args):
    # exist_ok avoids a crash when re-running with the same --dir
    os.makedirs(join(__SAVE_PATH, args.dir), exist_ok=True)
    os.makedirs(join(__SAVE_PATH, '{}/ckpt'.format(args.dir)), exist_ok=True)
    word2id, id2word = make_vocab(args.vsize)
    with open(join(__SAVE_PATH, join(args.dir, 'vocab.pkl')), 'wb') as f:
        pkl.dump((word2id, id2word), f, pkl.HIGHEST_PROTOCOL)
    word2id = defaultdict(lambda: UNK, word2id)
    train_loader = get_coco_train_loader(word2id,
                                         args.max_len,
                                         args.batch_size,
                                         cuda=args.cuda)
    val_loader = get_coco_val_loader(word2id,
                                     args.max_len,
                                     args.batch_size,
                                     cuda=args.cuda)

    model = AttnImCap(len(id2word), args.emb_dim, args.n_cell, args.n_layer)
    if args.emb:
        emb, oovs = load_embedding_from_bin(args.emb, id2word)
        model.set_embedding(emb, oovs=oovs)
    if args.cuda:
        model.cuda()

    if args.opt == 'adam':
        opt_cls = optim.Adam
    else:
        raise ValueError('unsupported optimizer: {}'.format(args.opt))
    opt_kwargs = {'lr': args.lr}  # TODO
    optimizer = opt_cls(model.parameters(), **opt_kwargs)
    scheduler = ReduceLROnPlateau(optimizer,
                                  mode='min',
                                  patience=0,
                                  factor=0.5,
                                  verbose=True)

    meta = vars(args)
    with open(join(__SAVE_PATH, '{}/meta.json'.format(args.dir)), 'w') as f:
        json.dump(meta, f)
    configure(join(__SAVE_PATH, args.dir))
    step = 0
    running = None
    best_val = None
    patience = 0
    for img, input_, target in train_loader:
        loss, grad_norm = train_step(model, img, input_, target, optimizer,
                                     args.clip_grad)
        step += 1
        running = loss if running is None else 0.99 * running + 0.01 * loss
        log_value('loss', loss, step)
        log_value('grad', grad_norm, step)
        print('step: {}, running loss: {:.4f}\r'.format(step, running), end='')
        sys.stdout.flush()
        if step % args.ckpt_freq == 0:
            print('\nstart validation...')
            val_loss = validate(model, val_loader)
            log_value('val_loss', val_loss, step)
            save_ckpt(model, val_loss, step, args.dir)
            scheduler.step(val_loss)
            if best_val is None or val_loss < best_val:
                best_val = val_loss
                patience = 0
            else:
                print('val loss does not decrease')
                patience += 1
            if patience > args.patience:
                break

    print('training finished, run test set')
    test_loader = get_coco_test_iter(args.max_len,
                                     args.batch_size,
                                     cuda=args.cuda)
    model.load_state_dict(torch.load(get_best_ckpt(args.dir)))
    result = test(model, test_loader, id2word, args.max_len)
    with open(join(__SAVE_PATH, '{}/result.json'.format(args.dir)), 'w') as f:
        json.dump(result, f)
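
Three standard ingredients drive the loop above: an exponential moving average of the training loss, `ReduceLROnPlateau` halving the learning rate whenever validation loss stalls (`patience=0` reacts after a single bad validation), and a separate patience counter for early stopping. A self-contained sketch of the scheduler-plus-patience interplay (the validation losses are fabricated for illustration):

import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau

model = nn.Linear(4, 1)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='min', patience=0, factor=0.5)

best_val, patience, max_patience = None, 0, 2
for val_loss in [1.0, 0.8, 0.9, 0.95, 0.99]:  # fabricated validation losses
    scheduler.step(val_loss)  # halves the LR as soon as val loss stops improving
    if best_val is None or val_loss < best_val:
        best_val, patience = val_loss, 0
    else:
        patience += 1  # consecutive validations without improvement
    if patience > max_patience:
        break  # early stop
print(best_val, optimizer.param_groups[0]['lr'])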
Example #4
def train(opt):
    '''
    Data could be loaded into a dictionary with "train"/"val"/"test" keys
    (the loading code below needs improvement).
    '''

    print_every = opt.print_every
    showatt_every = opt.print_every
    plot_every = opt.print_every

    full_model_name = opt.model_name \
    + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r) \
    + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
    + '_dp' + str(opt.dp) \
    + '_gc' + str(opt.gcth) \
    + '_wtinit' + str(opt.wtinit_meth) \
    + '_lw' + str(int(opt.load_wts)) \
    + '_ef' + str(int(opt.embedding_flag)) \
    + '_rf' + str(int(opt.residual_flag))

    print(full_model_name)

    logging.basicConfig(filename=opt.log_folder + full_model_name + '.log',
                        filemode='w',
                        level=logging.DEBUG,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    r = opt.r

    fid = open(opt.feats_dir + 'train_list.txt')
    train_list = fid.read().splitlines()
    fid.close()

    fid = open(opt.feats_dir + 'val_list.txt')
    val_list = fid.read().splitlines()
    fid.close()

    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       train_list + val_list)
    vocab = make_vocab(all_prompts)
    print(vocab)

    # Load training data
    train_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                         train_list)
    phn2id, id2phn = phn2id2phn(vocab)
    file_list = train_prompts.keys()
    print(len(file_list), len(train_list))

    # Load stats of mfcc
    mo1 = np.load(opt.stats_dir + 'mo.npy')
    so1 = np.load(opt.stats_dir + 'so.npy')
    mo1 = mo1.astype('float32')
    so1 = so1.astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Load stats of spectrum
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy')
    mo2 = mo2.astype('float32')
    so2 = so2.astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    # Load validation data
    val_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       val_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderCBL(vocab_size, opt.hs2, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)

    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet
    criterion = torch.nn.L1Loss(size_average=False)

    if opt.load_wts:
        load_model_name_pfx = '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_r' + str(
            opt.r
        ) + '_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_lw0_ef1_rf1_'
        load_model_name_sfx = '.pth'

        # load model
        enc_state_dict = torch.load(load_model_name_pfx + 'enc' +
                                    load_model_name_sfx,
                                    map_location=lambda storage, loc: storage)
        encoder.load_state_dict(enc_state_dict)

        dec_state_dict = torch.load(load_model_name_pfx + 'dec' +
                                    load_model_name_sfx,
                                    map_location=lambda storage, loc: storage)
        decoder.load_state_dict(dec_state_dict)

        pfnet_state_dict = torch.load(
            load_model_name_pfx + 'pfnet' + load_model_name_sfx,
            map_location=lambda storage, loc: storage)
        pfnet.load_state_dict(pfnet_state_dict)

    encoder_optimizer = optim.Adam(encoder.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    pfnet_optimizer = optim.Adam(pfnet.parameters(),
                                 lr=opt.lr,
                                 betas=(opt.b1, opt.b2))

    start = time.time()
    print_loss_total = 0  # Reset every print_every
    best_val_loss = sys.maxsize

    for iter in range(1, opt.niter + 1):

        if iter == 3:
            # one-time learning-rate decay at the start of epoch 3
            opt.lr = opt.lr / 10

            encoder_optimizer = optim.Adam(encoder.parameters(),
                                           lr=opt.lr,
                                           betas=(opt.b1, opt.b2))
            decoder_optimizer = optim.Adam(decoder.parameters(),
                                           lr=opt.lr,
                                           betas=(opt.b1, opt.b2))
            pfnet_optimizer = optim.Adam(pfnet.parameters(),
                                         lr=opt.lr,
                                         betas=(opt.b1, opt.b2))

        for j, k in enumerate(train_prompts):

            [input_variable, input_length] = get_x(train_prompts, k, phn2id,
                                                   use_cuda)

            train_targets, train_seq_len = load_targets(
                opt.feats_dir + '/fb/', [k], '.npy', dtype, mo1, so1, nml_vec1)

            [target_variable, target_variable2,
             target_length] = get_y(train_seq_len, 0, train_targets, use_cuda,
                                    r)

            loss = 0
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            encoder_h0 = encoder.initHidden()
            encoder_c0 = encoder.initHidden()
            encoder_outputs = Variable(
                torch.zeros(input_length, encoder.hidden_size2))
            encoder_outputs = encoder_outputs.cuda(
            ) if use_cuda else encoder_outputs

            encoder_output, (encoder_hn,
                             encoder_cn) = encoder(input_variable,
                                                   (encoder_h0, encoder_c0))
            encoder_outputs = encoder_output.squeeze(1)

            decoder_input = Variable(torch.zeros(1,
                                                 op_dim))  # all - zero frame
            decoder_input = decoder_input.cuda() if use_cuda else decoder_input
            decoder_h1 = decoder.initHidden()
            decoder_c1 = decoder.initHidden()
            decoder_h2 = decoder.initHidden()
            decoder_c2 = decoder.initHidden()
            decoder_h3 = decoder.initHidden()
            decoder_c3 = decoder.initHidden()
            decoder_output_half = Variable(
                torch.zeros(target_length, r *
                            op_dim)).cuda() if use_cuda else Variable(
                                torch.zeros(target_length, r * op_dim))
            decoder_output_full = Variable(
                torch.zeros(r * target_length,
                            op_dim)).cuda() if use_cuda else Variable(
                                torch.zeros(r * target_length, op_dim))

            # Teacher forcing: Feed the target as the next input
            for di in range(target_length):
                decoder_output1, decoder_output2, decoder_h1, decoder_c1, decoder_h2, decoder_c2, decoder_h3, decoder_c3, decoder_attention = decoder(
                    decoder_input, decoder_h1, decoder_c1, decoder_h2,
                    decoder_c2, decoder_h3, decoder_c3, encoder_outputs)
                loss += criterion(decoder_output1, target_variable[di])
                decoder_input = target_variable2[di].unsqueeze(
                    0)  # Teacher forcing
                decoder_output_half[di] = decoder_output1

            loss.backward(retain_graph=True)
            encoder_optimizer.step()
            decoder_optimizer.step()

            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            # Start Post-Filtering Net
            pfnet_optimizer.zero_grad()

            for ix in range(r):
                decoder_output_full[
                    ix::r, :] = decoder_output_half[:, ix * op_dim:(ix + 1) *
                                                    op_dim]

            s1 = r * target_length

            train_targets_pfnet, train_seq_len_pfnet = load_targets(
                opt.feats_dir + '/sp/', [k], '.npy', dtype, mo2, so2, nml_vec2)
            targets_pfnet = Variable(train_targets_pfnet).cuda(
            ) if use_cuda else Variable(train_targets_pfnet)
            s2 = targets_pfnet.size()[0]
            if (s2 % r) > 0:
                targets_pfnet = targets_pfnet[:-(s2 % r), :]

            pfnet_h0 = pfnet.initHidden()
            pfnet_c0 = pfnet.initHidden()
            pfnet_outputs = Variable(
                torch.zeros(targets_pfnet.size()[0], pfnet.output_size))
            pfnet_outputs = pfnet_outputs.cuda() if use_cuda else pfnet_outputs

            pfnet_output = pfnet(decoder_output_full, (pfnet_h0, pfnet_c0))
            pfnet_outputs = pfnet_output
            loss_pfnet = criterion(pfnet_outputs, targets_pfnet)

            loss_pfnet.backward()
            pfnet_optimizer.step()
            loss_total = loss + loss_pfnet

            print_loss_total += (loss_total.data[0] / (r * target_length))

            if (j + 1) % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0

                print('%s (%d %d%%) %.4f' % (timeSince(
                    start,
                    (iter * len(train_prompts) - len(train_prompts) + j) /
                    ((opt.niter + 1) * len(train_prompts))), iter, iter /
                                             opt.niter * 100, print_loss_avg))

                tf = True  # teacher forcing
                avg_total_val_loss_tf, avg_dec_val_loss_tf, decoder_attentions_tf = evaluate(
                    encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
                    phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim,
                    tf, opt, mo1, so1, nml_vec1, mo2, so2, nml_vec2)
                print('%d %0.4f %0.4f' %
                      (iter, avg_total_val_loss_tf, avg_dec_val_loss_tf))

                tf = False  # always sampling
                avg_total_val_loss_as, avg_dec_val_loss_as, decoder_attentions_pf = evaluate(
                    encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts,
                    phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim,
                    tf, opt, mo1, so1, nml_vec1, mo2, so2, nml_vec2)
                print('%d %0.4f %0.4f' %
                      (iter, avg_total_val_loss_as, avg_dec_val_loss_as))
                logging.debug(
                    'Epoch: ' + str(iter) + ' Update: ' +
                    str(iter * len(train_prompts) - len(train_prompts) + j) +
                    ' Avg Total Val Loss TF: ' + str(avg_total_val_loss_tf) +
                    ' Avg Total Val Loss AS: ' + str(avg_total_val_loss_as) +
                    ' Avg Dec Val Loss TF: ' + str(avg_dec_val_loss_tf) +
                    ' Avg Dec Val Loss AS: ' + str(avg_dec_val_loss_as))

                if avg_total_val_loss_tf < best_val_loss:
                    best_val_loss = avg_total_val_loss_tf
                    torch.save(
                        encoder.state_dict(),
                        '%s/%s_enc.pth' % (opt.model_folder, full_model_name))
                    torch.save(
                        decoder.state_dict(),
                        '%s/%s_dec.pth' % (opt.model_folder, full_model_name))
                    torch.save(
                        pfnet.state_dict(), '%s/%s_pfnet.pth' %
                        (opt.model_folder, full_model_name))

                encoder.train()
                decoder.train()
                pfnet.train()

            # if (j+1) % showatt_every == 0:

            #    plt.figure(1, figsize=(12, 12))
            #    plt.imshow(decoder_attentions_tf.numpy())
            #    plt.colorbar()
            #    pylab.savefig(opt.plot_folder + full_model_name + '_' + str(j) + '_' + str(iter) + '.png', bbox_inches='tight')
            #    plt.close()

            #    plt.figure(1, figsize=(12, 12))
            #    plt.imshow(decoder_attentions_pf.numpy())
            #    plt.colorbar()
            #    pylab.savefig(opt.plot_folder + full_model_name + '_' + str(j) + '_' + str(iter) + '.png', bbox_inches='tight')
            #    plt.close()

            # if (j+1) % plot_every == 0:
            #    plot_loss_avg = plot_loss_total / plot_every
            #    plot_losses.append(plot_loss_avg)
            #    plot_loss_total = 0

        gc.collect()
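
One subtle step above is the de-interleaving loop: with reduction factor `r`, the decoder emits `r` stacked frames of width `op_dim` per step, and the strided assignment `decoder_output_full[ix::r, :] = decoder_output_half[:, ix * op_dim:(ix + 1) * op_dim]` restores the frame-by-frame sequence before it is fed to the post-filtering net. A small numpy illustration with toy sizes:

import numpy as np

r, op_dim, steps = 2, 3, 4  # toy sizes: 2 frames per decoder step, 3-dim frames
half = np.arange(steps * r * op_dim).reshape(steps, r * op_dim).astype(float)

full = np.zeros((r * steps, op_dim))
for ix in range(r):
    # frame ix of every decoder step lands on rows ix, ix + r, ix + 2r, ...
    full[ix::r, :] = half[:, ix * op_dim:(ix + 1) * op_dim]

# Equivalent view: each row of `half` splits into r consecutive output frames.
assert np.array_equal(full, half.reshape(r * steps, op_dim))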
Example #5
def test(opt):

    full_model_name = opt.model_name \
    + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r) \
    + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
    + '_dp' + str(opt.dp) \
    + '_gc' + str(opt.gcth) \
    + '_wtinit' + str(opt.wtinit_meth) \
    + '_lw' + str(int(opt.load_wts)) \
    + '_ef' + str(int(opt.embedding_flag)) \
    + '_rf' + str(int(opt.residual_flag))

    print(full_model_name)
    opt.full_model_name = full_model_name

    try:
        os.makedirs(opt.synth_folder + opt.full_model_name)
        os.makedirs(opt.plot_folder + opt.full_model_name)
    except OSError:
        pass

    fid = open(opt.feats_dir + 'train_list.txt')
    train_list = fid.read().splitlines()
    fid.close()

    fid = open(opt.feats_dir + 'val_list.txt')
    val_list = fid.read().splitlines()
    fid.close()

    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       train_list + val_list)
    vocab = make_vocab(all_prompts)
    phn2id, id2phn = phn2id2phn(vocab)
    print(vocab)

    # NOTE: the test utterances are loaded into `val_list` so the validation
    # machinery below can be reused; only the first 10 are kept.
    fid = open(opt.feats_dir + 'test_list.txt')
    val_list = fid.read().splitlines()
    val_list = val_list[:10]
    fid.close()

    # Load stats of mfcc
    mo1 = np.load(opt.stats_dir + 'mo.npy')
    so1 = np.load(opt.stats_dir + 'so.npy')
    mo1 = mo1.astype('float32')
    so1 = so1.astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Load stats of spectrum
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy')
    mo2 = mo2.astype('float32')
    so2 = so2.astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    # Load validation data
    val_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data',
                                       val_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderCBL(vocab_size, opt.hs2, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(
                op_dim, opt.hs2, op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)

    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet
    criterion = torch.nn.L1Loss(size_average=False)

    load_model_name_pfx = '../../wt/' + opt.full_model_name + '_'
    load_model_name_sfx = '.pth'

    # load model
    enc_state_dict = torch.load(load_model_name_pfx + 'enc' +
                                load_model_name_sfx,
                                map_location=lambda storage, loc: storage)
    encoder.load_state_dict(enc_state_dict)

    dec_state_dict = torch.load(load_model_name_pfx + 'dec' +
                                load_model_name_sfx,
                                map_location=lambda storage, loc: storage)
    decoder.load_state_dict(dec_state_dict)

    pfnet_state_dict = torch.load(load_model_name_pfx + 'pfnet' +
                                  load_model_name_sfx,
                                  map_location=lambda storage, loc: storage)
    pfnet.load_state_dict(pfnet_state_dict)

    tf = True  # teacher forcing
    avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts, phn2id,
        id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt, mo1, so1,
        nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_tf1, avg_val_loss_tf2))

    tf = False  # always sampling (no teacher forcing)
    avg_val_loss_pf1, avg_val_loss_pf2, decoder_attentions_pf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts, phn2id,
        id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt, mo1, so1,
        nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_pf1, avg_val_loss_pf2))
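
The `map_location=lambda storage, loc: storage` idiom used throughout these loaders deserializes every tensor onto the CPU regardless of which GPU the checkpoint was saved from; in current PyTorch it is equivalent to `map_location='cpu'`. A short illustration (the path is a placeholder):

import torch
import torch.nn as nn

model = nn.Linear(8, 2)
torch.save(model.state_dict(), '/tmp/demo.pth')  # placeholder path

# Keep each storage where the lambda puts it: here, on the CPU.
state = torch.load('/tmp/demo.pth',
                   map_location=lambda storage, loc: storage)
model.load_state_dict(state)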
Example #6
def test(opt):
    
    print_every = opt.print_every
    showatt_every = opt.print_every
    plot_every = opt.print_every

    full_model_name = opt.model_name \
    + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) + '_pfnet_hs1' + str(opt.pfnet_hs1) + '_r' + str(opt.r) \
    + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
    + '_dp' + str(opt.dp) \
    + '_gc' + str(opt.gcth) \
    + '_wtinit' + str(opt.wtinit_meth) \
    + '_lw' + str(int(opt.load_wts)) \
    + '_ef' + str(int(opt.embedding_flag)) \
    + '_rf' + str(int(opt.residual_flag))

    print(full_model_name)
    opt.full_model_name = full_model_name
    
    r = opt.r
 
    fid = open(opt.feats_dir + 'train_list.txt')
    train_list = fid.read().splitlines()
    fid.close()

    fid = open(opt.feats_dir + 'val_list.txt')
    val_list = fid.read().splitlines()
    fid.close()

    all_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data', train_list + val_list)
    vocab = make_vocab(all_prompts)
    print(vocab)

    # NOTE: the test utterances are loaded into `val_list` so the validation
    # machinery below can be reused; only the first 10 are kept.
    fid = open(opt.feats_dir + 'test_list.txt')
    val_list = fid.read().splitlines()
    val_list = val_list[:10]
    fid.close()

    # Load training data
    train_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data', train_list)
    #vocab = make_vocab(train_prompts)
    #print(vocab)
    phn2id, id2phn = phn2id2phn(vocab)
    file_list = train_prompts.keys()
    print(len(file_list), len(train_list))
    #save_stats_suffstats(opt.feats_dir + '/fb/', file_list, '.npy', dtype, opt.stats_dir)
    #save_stats_suffstats(opt.feats_dir + '/sp/', file_list, '.npy', dtype, opt.pfnet_stats_dir)

    # save_stats(opt.feats_dir + '/fb/', file_list, '.npy', dtype, opt.stats_dir)
    # save_stats(opt.feats_dir + phase + '/log_mag_spec/',
    # file_list, opt.pfnet_audio_feats_ext, dtype, opt.pfnet_stats_dir)
    # exit()

    # Load stats of mfcc
    mo1 = np.load(opt.stats_dir + 'mo.npy')
    so1 = np.load(opt.stats_dir + 'so.npy')
    mo1 = mo1.astype('float32')
    so1 = so1.astype('float32')
    nml_vec1 = np.arange(0, mo1.shape[1])

    # Load stats of spectrum
    mo2 = np.load(opt.pfnet_stats_dir + 'mo.npy')
    so2 = np.load(opt.pfnet_stats_dir + 'so.npy')
    mo2 = mo2.astype('float32')
    so2 = so2.astype('float32')
    nml_vec2 = np.arange(0, mo2.shape[1])

    #train_targets, train_seq_len = load_targets(opt.feats_dir + phase
    #                                            + '/audio_feats/', file_list,
    #                                            opt.audio_feats_ext,
    #                                            dtype, mo1, so1, nml_vec1)

    # Load validation data
    val_prompts = make_prompts_dict_v2(opt.feats_dir + 'txt.done.data', val_list)
    #file_list = val_prompts.keys()
    #val_targets, val_seq_len = load_targets(opt.feats_dir + '/fb/', val_list, '.npy', dtype, mo1, so1, nml_vec1)
    #print(val_seq_len)
    #print(val_targets.shape)
    #print(val_list)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderBLSTM_WOE(vocab_size, opt.hs1)
    if opt.residual_flag:
        if opt.r == 2:
            decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 3:
            decoder = decoders.AttnDecoderLSTM3L_R3_Rescon(op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 4:
            decoder = decoders.AttnDecoderLSTM3L_R4_Rescon(op_dim, opt.hs2, op_dim, 1, opt.dp)
        if opt.r == 5:
            decoder = decoders.AttnDecoderLSTM3L_R5_Rescon(op_dim, opt.hs2, op_dim, 1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1, opt.dp)

    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet
    criterion = torch.nn.L1Loss(size_average=False)

    if opt.load_wts:
        load_model_name_pfx = '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_r3_lr3e-05_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_lw1_ef0_rf1_'
        load_model_name_sfx = '_epoch_2999_5.pth'

        # load model
        enc_state_dict = torch.load(load_model_name_pfx + 'enc' + load_model_name_sfx, map_location=lambda storage, loc: storage)
        encoder.load_state_dict(enc_state_dict)

        dec_state_dict = torch.load(load_model_name_pfx + 'dec' + load_model_name_sfx, map_location=lambda storage, loc: storage)
        decoder.load_state_dict(dec_state_dict)

        pfnet_state_dict = torch.load(load_model_name_pfx + 'pfnet' + load_model_name_sfx, map_location=lambda storage, loc: storage)
        pfnet.load_state_dict(pfnet_state_dict)
    
    
    tf = True  # teacher forcing
    avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts, phn2id,
        id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt, mo1, so1,
        nml_vec1, mo2, so2, nml_vec2)
    print('%0.4f %0.4f' % (avg_val_loss_tf1, avg_val_loss_tf2))
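
Passing `encoder.eval()` straight into `evaluate` works because `.eval()` switches the module to evaluation mode (disabling dropout, freezing batch-norm statistics) and returns the module itself; the training loops above restore training mode with `.train()` afterwards. A minimal illustration:

import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Dropout(0.5))
x = torch.ones(1, 4)

y = net.eval()(x)              # .eval() returns net; dropout is a no-op now
assert torch.equal(y, net(x))  # eval-mode forward passes are deterministic
net.train()                    # dropout active again for further training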
Example #7
File: main.py  Project: ganlubbq/nntoolbox
def train(opt):
    '''
    Data could be loaded into a dictionary with "train"/"val"/"test" keys
    (the loading code below needs improvement).
    '''

    print_every = opt.print_every
    showatt_every = opt.print_every
    plot_every = opt.print_every

    full_model_name = opt.model_name \
    + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) \
    + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
    + '_dp' + str(opt.dp) \
    + '_gc' + str(opt.gcth) \
    + '_wtinit' + str(opt.wtinit_meth) \
    + '_ef' + str(int(opt.embedding_flag)) \
    + '_rf' + str(int(opt.residual_flag))

    print(full_model_name)

    logging.basicConfig(filename=opt.log_folder + full_model_name + '.log',
                        filemode='w',
                        level=logging.DEBUG,
                        format='%(asctime)s - %(levelname)s - %(message)s')

    # Load training data
    phase = 'train'
    train_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    vocab = make_vocab(train_prompts)
    phn2id, id2phn = phn2id2phn(vocab)
    file_list = train_prompts.keys()
    save_stats(opt.feats_dir + phase + '/audio_feats/', file_list,
               opt.audio_feats_ext, dtype, '../stats/')
    train_targets, train_seq_len = load_targets(
        opt.feats_dir + phase + '/audio_feats/', file_list,
        opt.audio_feats_ext, dtype, opt.stats_dir)

    # Load validation data
    phase = 'val'
    val_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    file_list = val_prompts.keys()
    val_targets, val_seq_len = load_targets(
        opt.feats_dir + phase + '/audio_feats/', file_list,
        opt.audio_feats_ext, dtype, opt.stats_dir)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderBLSTM_WOE(vocab_size, opt.hs1)
    if opt.residual_flag:
        decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2, op_dim,
                                                       1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    criterion = torch.nn.L1Loss(size_average=False)

    encoder_optimizer = optim.Adam(encoder.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))
    decoder_optimizer = optim.Adam(decoder.parameters(),
                                   lr=opt.lr,
                                   betas=(opt.b1, opt.b2))

    start = time.time()
    plot_losses = []
    print_loss_total = 0  # Reset every print_every
    plot_loss_total = 0  # Reset every plot_every
    best_val_loss = float('inf')

    for iter in range(1, opt.niter + 1):

        for j, k in enumerate(train_prompts):

            [input_variable,
             input_length] = get_x_1hot(train_prompts, k, phn2id, vocab_size,
                                        use_cuda)
            [target_variable, target_variable2,
             target_length] = get_y(train_seq_len, j, train_targets, use_cuda)

            loss = 0
            encoder_optimizer.zero_grad()
            decoder_optimizer.zero_grad()

            encoder_h0 = encoder.initHidden()
            encoder_c0 = encoder.initHidden()
            encoder_outputs = Variable(
                torch.zeros(input_length, encoder.hidden_size))
            encoder_outputs = encoder_outputs.cuda(
            ) if use_cuda else encoder_outputs

            encoder_output, (encoder_hn,
                             encoder_cn) = encoder(input_variable,
                                                   (encoder_h0, encoder_c0))
            encoder_outputs = encoder_output.squeeze(1)

            decoder_input = Variable(torch.zeros(1,
                                                 op_dim))  # all - zero frame
            decoder_input = decoder_input.cuda() if use_cuda else decoder_input
            decoder_h1 = decoder.initHidden()
            decoder_c1 = decoder.initHidden()
            decoder_h2 = decoder.initHidden()
            decoder_c2 = decoder.initHidden()
            decoder_h3 = decoder.initHidden()
            decoder_c3 = decoder.initHidden()

            # Teacher forcing: Feed the target as the next input
            for di in range(target_length):
                decoder_output1, decoder_output2, decoder_h1, decoder_c1, decoder_h2, decoder_c2, decoder_h3, decoder_c3, decoder_attention = decoder(
                    decoder_input, decoder_h1, decoder_c1, decoder_h2,
                    decoder_c2, decoder_h3, decoder_c3, encoder_outputs)
                loss += criterion(decoder_output1, target_variable[di])
                decoder_input = target_variable2[di].unsqueeze(
                    0)  # Teacher forcing

            loss.backward()
            #torch.nn.utils.clip_grad_norm(encoder.parameters(), 1)
            #torch.nn.utils.clip_grad_norm(decoder.parameters(), 1)
            encoder_optimizer.step()
            decoder_optimizer.step()

            print_loss_total += (loss.data[0] / target_length)
            plot_loss_total += (loss.data[0] / target_length)

            if (j + 1) % print_every == 0:
                print_loss_avg = print_loss_total / print_every
                print_loss_total = 0
                print('%s (%d %d%%) %.4f' % (timeSince(
                    start,
                    (iter * len(train_prompts) - len(train_prompts) + j) /
                    ((opt.niter + 1) * len(train_prompts))), iter, iter /
                                             opt.niter * 100, print_loss_avg))

                tf = True  # teacher forcing
                avg_val_loss_tf, decoder_attentions_tf = evaluate(
                    encoder.eval(), decoder.eval(), val_prompts, val_targets,
                    val_seq_len, phn2id, id2phn, vocab_size, use_cuda,
                    criterion, op_dim, tf)
                print('%d %0.4f' % (iter, avg_val_loss_tf))

                tf = False  # always sampling (no teacher forcing)
                avg_val_loss_pf, decoder_attentions_pf = evaluate(
                    encoder.eval(), decoder.eval(), val_prompts, val_targets,
                    val_seq_len, phn2id, id2phn, vocab_size, use_cuda,
                    criterion, op_dim, tf)
                print('%d %0.4f' % (iter, avg_val_loss_pf))
                logging.debug('Epoch: ' + str(iter) + ' Update: ' +
                              str(iter * len(train_prompts) -
                                  len(train_prompts) + j) +
                              ' Avg Val Loss TF: ' + str(avg_val_loss_tf) +
                              ' Avg Val Loss PF: ' + str(avg_val_loss_pf))

                if avg_val_loss_tf < best_val_loss:
                    best_val_loss = avg_val_loss_tf
                    torch.save(
                        encoder.state_dict(), '%s/%s_enc_epoch_%d_%d.pth' %
                        (opt.model_folder, full_model_name, j, iter))
                    torch.save(
                        decoder.state_dict(), '%s/%s_dec_epoch_%d_%d.pth' %
                        (opt.model_folder, full_model_name, j, iter))

                encoder.train()
                decoder.train()

            if (j + 1) % showatt_every == 0:

                plt.figure(1, figsize=(12, 12))
                plt.imshow(decoder_attentions_tf.numpy())
                plt.colorbar()
                pylab.savefig(opt.plot_folder + full_model_name + '_' +
                              str(j) + '_' + str(iter) + '.png',
                              bbox_inches='tight')
                plt.close()

                plt.figure(1, figsize=(12, 12))
                plt.imshow(decoder_attentions_pf.numpy())
                plt.colorbar()
                pylab.savefig(opt.plot_folder + full_model_name + '_' +
                              str(j) + '_' + str(iter) + '.png',
                              bbox_inches='tight')
                plt.close()

            if (j + 1) % plot_every == 0:
                plot_loss_avg = plot_loss_total / plot_every
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0

    showPlot(plot_losses)
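
The inner loop above is classic teacher forcing: at each decoder step the ground-truth frame, not the model's own prediction, is fed back as the next input, which stabilizes early training at the cost of exposure bias at synthesis time. A toy sketch of the same pattern with a GRU cell (all sizes and data are made up):

import torch
import torch.nn as nn

op_dim, hidden, T = 4, 16, 5
cell = nn.GRUCell(op_dim, hidden)
readout = nn.Linear(hidden, op_dim)
targets = torch.randn(T, op_dim)       # fabricated ground-truth frames
criterion = nn.L1Loss(reduction='sum')

h = torch.zeros(1, hidden)
dec_input = torch.zeros(1, op_dim)     # all-zero initial frame, as above
loss = 0
for t in range(T):
    h = cell(dec_input, h)
    loss = loss + criterion(readout(h), targets[t].unsqueeze(0))
    dec_input = targets[t].unsqueeze(0)  # teacher forcing
loss.backward()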
Example #8
def test(opt):

    full_model_name = opt.model_name \
    + '_hs1' + str(opt.hs1) + '_hs2' + str(opt.hs2) + '_pfnet_hs1' + str(opt.pfnet_hs1)\
    + '_lr' + str(opt.lr) + '_b1' + str(opt.b1) + '_b2' + str(opt.b2) \
    + '_dp' + str(opt.dp) \
    + '_gc' + str(opt.gcth) \
    + '_wtinit' + str(opt.wtinit_meth) \
    + '_ef' + str(int(opt.embedding_flag)) \
    + '_rf' + str(int(opt.residual_flag))

    print(full_model_name)
    opt.full_model_name = full_model_name

    # Load training data
    phase = 'train'
    train_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    vocab = make_vocab(train_prompts)
    phn2id, id2phn = phn2id2phn(vocab)

    # Load validation data
    phase = 'test'
    val_prompts = make_prompts_dict(opt.etc_dir + phase + '.done.data')
    file_list = list(val_prompts.keys())[:5]
    val_targets, val_seq_len = load_targets(
        opt.feats_dir + phase + '/audio_feats/', file_list,
        opt.audio_feats_ext, dtype, opt.stats_dir)

    # Initialize model
    vocab_size = len(vocab)
    op_dim = 60
    encoder = encoders.EncoderBLSTM_WOE(vocab_size, opt.hs1)
    if opt.residual_flag:
        decoder = decoders.AttnDecoderLSTM3L_R2_Rescon(op_dim, opt.hs2, op_dim,
                                                       1, opt.dp)
    else:
        decoder = decoders.AttnDecoderLSTM3L_R2(op_dim, opt.hs2, op_dim, 1,
                                                opt.dp)

    op_dim1 = 513
    pfnet = encoders.EncoderBLSTM_WOE_1L(op_dim, opt.pfnet_hs1, op_dim1)

    encoder = encoder.cuda() if use_cuda else encoder
    decoder = decoder.cuda() if use_cuda else decoder
    pfnet = pfnet.cuda() if use_cuda else pfnet
    criterion = torch.nn.L1Loss(size_average=False)

    # load model
    enc_state_dict = torch.load(
        '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_ef0_rf1_enc_epoch_999_18.pth',
        map_location=lambda storage, loc: storage)
    encoder.load_state_dict(enc_state_dict)

    dec_state_dict = torch.load(
        '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_ef0_rf1_dec_epoch_999_18.pth',
        map_location=lambda storage, loc: storage)
    decoder.load_state_dict(dec_state_dict)

    pfnet_state_dict = torch.load(
        '../../wt/s2s_enc_blstm_dec_lstm3l_pfnet_blstm1L_nopfnetloss__hs1250_hs2500_pfnet_hs1250_lr0.0003_b10.9_b20.99_dp0.5_gc0.0_wtinitdefault_init_ef0_rf1_pfnet_epoch_999_18.pth',
        map_location=lambda storage, loc: storage)
    pfnet.load_state_dict(pfnet_state_dict)

    #tf = True # teacher forcing
    #avg_val_loss_tf1, avg_val_loss_tf2, decoder_attentions_tf = evaluate(encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts, val_targets, val_seq_len, phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim, tf, opt)
    #print('%d %0.4f %0.4f' % (iter, avg_val_loss_tf1, avg_val_loss_tf2))

    tf = False  # always sampling (no teacher forcing)
    avg_val_loss_pf1, avg_val_loss_pf2, decoder_attentions_pf = evaluate(
        encoder.eval(), decoder.eval(), pfnet.eval(), val_prompts, val_targets,
        val_seq_len, phn2id, id2phn, vocab_size, use_cuda, criterion, op_dim,
        tf, opt)
    print('%0.4f %0.4f' % (avg_val_loss_pf1, avg_val_loss_pf2))