コード例 #1
0
def test():
    """Evaluate a saved Set2Seq2Seq checkpoint on ``args.test_file``.

    Reads ``args.param_file`` (checkpoint path), ``args.device`` and
    ``args.test_file`` from the module-level ``args``. Prints the training
    arguments stored in the checkpoint, then the test loss and the
    sequence-/token-level accuracies (as percentages).

    Calls ``sys.exit(-1)`` when no checkpoint file was specified.
    """
    import sys

    if args.param_file is None:
        print('please specify the saved param file.')
        sys.exit(-1)

    print('loading saved parameters from ' + args.param_file + '...')
    checkpoint = torch.load(args.param_file, map_location=args.device)
    train_args = checkpoint['args']
    voc = checkpoint['voc']
    print('done')

    print('arguments for train:')
    print(train_args)

    print('rebuilding model...')
    model = Set2Seq2Seq(voc.num_words).to(args.device)
    model.load_state_dict(checkpoint['model'])
    # Testing only needs the weights: optimiser states are deliberately not
    # rebuilt (they were unused and required extra checkpoint keys).
    model.eval()
    print('done')

    print('loading test data...')
    test_set = FruitSeqDataset(voc, dataset_file_path=args.test_file)
    print('done')

    test_seq_acc, test_tok_acc, test_loss = eval_model(model, test_set)
    print(
        "[TEST]Loss: {:.4f}; Seq-level Accuracy: {:.4f}; Tok-level Accuracy: {:.4f}"
        .format(test_loss, test_seq_acc * 100, test_tok_acc * 100))
コード例 #2
0
def test():
    """Evaluate a saved Seq2Seq checkpoint on ``args.test_file``.

    The checkpoint at ``args.param_file`` is loaded onto ``args.device`` and
    the model is sized from the vocabulary stored in that checkpoint, so the
    embedding/output sizes match the saved weights. Prints the test loss and
    the sequence-/token-level accuracies (as percentages).

    Calls ``sys.exit(-1)`` when no checkpoint file was specified.
    """
    import sys

    if args.param_file is None:
        print('please specify the saved param file.')
        sys.exit(-1)

    print('loading saved parameters from ' + args.param_file + '...')
    # map_location lets a checkpoint written on another device (e.g. a GPU)
    # be restored on args.device.
    checkpoint = torch.load(args.param_file, map_location=args.device)
    voc = checkpoint['voc']
    print('done')

    print('building model...')
    seq2seq = Seq2Seq(voc.num_words).to(args.device)
    seq2seq.load_state_dict(checkpoint['model'])
    # Evaluation only; optimiser states are not needed for testing.
    seq2seq.eval()
    print('done')

    print('loading test data...')
    test_set = FruitSeqDataset(voc, dataset_file_path=args.test_file)
    print('done')

    test_seq_acc, test_tok_acc, test_loss = eval_model(seq2seq, test_set)
    print(
        "[TEST]Loss: {:.4f}; Seq-level Accuracy: {:.4f}; Tok-level Accuracy: {:.4f}"
        .format(test_loss, test_seq_acc * 100, test_tok_acc * 100))
コード例 #3
0
def main():
    """Rebuild a saved Set2Seq2Seq model and run ``iterate_dataset`` on DATA_FILE.

    Builds single-example batches and the raw strings from ``DATA_FILE``,
    restores the model weights from ``args.param_file`` onto ``args.device``,
    switches the model to eval mode, and hands everything to
    ``iterate_dataset`` together with ``OUT_FILE`` opened in append mode.
    """
    print('building vocabulary...')
    voc = Voc()
    print('done')

    print('loading data and building batches...')
    data_set = FruitSeqDataset(voc, dataset_file_path=DATA_FILE, batch_size=1)
    str_set = data_set.load_stringset(DATA_FILE)
    print('done')

    print('rebuilding model from saved parameters in ' + args.param_file +
          '...')
    model = Set2Seq2Seq(voc.num_words).to(args.device)
    checkpoint = torch.load(args.param_file, map_location=args.device)
    train_args = checkpoint['args']
    model.load_state_dict(checkpoint['model'])
    # NOTE(review): data_set above was encoded with a fresh Voc(); from here
    # on the checkpoint's vocabulary is used. Assumes both encode the same
    # way — confirm.
    voc = checkpoint['voc']
    print('done')

    model.eval()

    print('iterating data set...')
    # The context manager guarantees the output file is flushed and closed
    # even if iterate_dataset raises.
    with open(OUT_FILE, mode='a') as out_file:
        iterate_dataset(model, voc, str_set, data_set, out_file, train_args)
コード例 #4
0
def reproduce_input_hidden_pairs(model, voc, dataset_file_path):
    """Pair each input string's symbol-count vector with its encoder state.

    Args:
        model: model exposing ``embedding`` and ``encoder``; the encoder is
            called as ``encoder(embedded, input_mask)`` and its first output
            is used as the hidden state.
        voc: vocabulary; ``voc.num_words - 3`` characters, starting at 'A',
            are counted in each input string.
        dataset_file_path: data file read both as batches (batch size 1) and
            as raw strings.

    Returns:
        A list of ``[in_vec, hidden]`` pairs, one per batch in dataset order,
        where ``in_vec`` is a numpy array of per-symbol counts and ``hidden``
        is the squeezed encoder output as a numpy array.
    """
    repro_dataset = FruitSeqDataset(voc,
                                    dataset_file_path=dataset_file_path,
                                    batch_size=1)
    repro_strset = load_stringset(dataset_file_path)

    # Number of symbol characters counted per string; presumably the other
    # 3 vocabulary entries are special tokens — TODO confirm.
    num_symbols = voc.num_words - 3

    def _instr2coordinate_(in_str):
        # Count occurrences of 'A', 'B', ... in the input string.
        return np.asarray(
            [in_str.count(chr(ord('A') + i)) for i in range(num_symbols)])

    pair_set = []
    # Only forward activations are collected (and detached), so skip
    # building the autograd graph.
    with torch.no_grad():
        for idx, data_batch in enumerate(repro_dataset):
            input_var = data_batch['input']
            input_mask = data_batch['input_mask']
            hidden, _ = model.encoder(model.embedding(input_var.t()),
                                      input_mask)
            hidden = hidden.squeeze().detach().cpu().numpy()
            in_vec = _instr2coordinate_(repro_strset[idx])
            pair_set.append([in_vec, hidden])

    return pair_set
コード例 #5
0
def get_batches4sim_check(voc, dataset_file_path=None):
    """Load the raw strings and single-example batches for similarity checks.

    Args:
        voc: vocabulary passed to ``ChooseDataset``.
        dataset_file_path: data file to read. When omitted (or ``None``),
            ``args.data_file`` is used, looked up at call time.

    Returns:
        ``(in_set, batch_set)``: the strings loaded by
        ``FruitSeqDataset.load_stringset`` and a ``ChooseDataset`` with
        ``batch_size=1`` built from the same file.
    """
    if dataset_file_path is None:
        # Resolved per call rather than once at definition time, so the
        # current value of args.data_file is used and `args` need not exist
        # when this function is defined.
        dataset_file_path = args.data_file
    in_set = FruitSeqDataset.load_stringset(dataset_file_path)
    batch_set = ChooseDataset(voc,
                              batch_size=1,
                              dataset_file_path=dataset_file_path)
    return in_set, batch_set
コード例 #6
0
def train():
    """Train a Seq2Seq model, periodically logging, evaluating and saving.

    Reads its configuration from the module-level ``args``. When
    ``args.param_file`` is set, the model and optimiser states (and the
    vocabulary) are restored from that checkpoint before training; the
    iteration counter still starts at 1.

    Every ``args.print_freq`` iterations the average training loss and
    accuracies are printed, every ``args.eval_freq`` iterations the model is
    evaluated on ``args.dev_file``, and every ``args.save_freq`` iterations a
    checkpoint is written to
    ``<args.save_dir>/seq2seq/<args.seed>_<iteration>_checkpoint.tar``.
    """
    print('building vocabulary...')
    voc = Voc()
    print('done')

    print('loading data and building batches...')
    train_set = FruitSeqDataset(voc, dataset_file_path=args.train_file)
    dev_set = FruitSeqDataset(voc, dataset_file_path=args.dev_file)
    print('done')

    print('building model...')
    seq2seq = Seq2Seq(voc.num_words).to(args.device)
    param_optimizer = args.optimiser(seq2seq.parameters(),
                                     lr=args.learning_rate)
    decoder_optimizer = args.optimiser(seq2seq.decoder.parameters(),
                                       lr=args.learning_rate *
                                       args.speaker_ratio)
    if args.param_file is not None:
        print('\tloading saved parameters from ' + args.param_file + '...')
        # map_location lets a checkpoint written on another device (e.g. a
        # GPU) be restored on args.device.
        checkpoint = torch.load(args.param_file, map_location=args.device)
        seq2seq.load_state_dict(checkpoint['model'])
        param_optimizer.load_state_dict(checkpoint['opt'])
        decoder_optimizer.load_state_dict(checkpoint['de_opt'])
        voc = checkpoint['voc']
        print('\tdone')
    print('done')

    print('initialising...')
    start_iteration = 1
    print_loss = 0.
    print_seq_acc = 0.
    print_tok_acc = 0.
    max_dev_seq_acc = 0.
    training_losses = []
    training_tok_acc = []
    training_seq_acc = []
    training_sim = []
    eval_tok_acc = []
    eval_seq_acc = []
    print('done')

    print('training...')
    # `iteration` rather than `iter`, so the builtin is not shadowed.
    for iteration in range(start_iteration, args.iter_num + 1):
        for data_batch in train_set:
            seq_acc, tok_acc, loss = train_epoch(seq2seq, data_batch,
                                                 param_optimizer,
                                                 decoder_optimizer)
            print_loss += loss
            print_seq_acc += seq_acc
            print_tok_acc += tok_acc

        if iteration % args.print_freq == 0:
            # The running sums cover print_freq full passes over train_set.
            num_batches = args.print_freq * len(train_set)
            print_loss_avg = print_loss / num_batches
            print_seq_acc_avg = print_seq_acc / num_batches
            print_tok_acc_avg = print_tok_acc / num_batches
            print(
                "Iteration: {}; Percent complete: {:.1f}%; Avg loss: {:.4f}; Avg seq acc: {:.4f}; Avg tok acc: {:.4f}"
                .format(iteration, iteration / args.iter_num * 100,
                        print_loss_avg, print_seq_acc_avg,
                        print_tok_acc_avg))
            training_seq_acc.append(print_seq_acc_avg)
            training_tok_acc.append(print_tok_acc_avg)
            training_losses.append(print_loss_avg)
            print_seq_acc = 0.
            print_tok_acc = 0.
            print_loss = 0.

        if iteration % args.eval_freq == 0:
            dev_seq_acc, dev_tok_acc, dev_loss = eval_model(seq2seq, dev_set)
            if dev_seq_acc > max_dev_seq_acc:
                max_dev_seq_acc = dev_seq_acc
            eval_seq_acc.append(dev_seq_acc)
            eval_tok_acc.append(dev_tok_acc)

            print(
                "[EVAL]Iteration: {}; Loss: {:.4f}; Avg Seq Acc: {:.4f}; Avg Tok Acc: {:.4f}; Best Seq Acc: {:.4f}"
                .format(iteration, dev_loss, dev_seq_acc, dev_tok_acc,
                        max_dev_seq_acc))

        if iteration % args.save_freq == 0:
            directory = os.path.join(args.save_dir, 'seq2seq')
            os.makedirs(directory, exist_ok=True)
            torch.save(
                {
                    'iteration': iteration,
                    'model': seq2seq.state_dict(),
                    'opt': param_optimizer.state_dict(),
                    'de_opt': decoder_optimizer.state_dict(),
                    'loss': loss,
                    'voc': voc,
                    'args': args,
                    'records': {
                        'training_loss': training_losses,
                        'training_tok_acc': training_tok_acc,
                        'training_seq_acc': training_seq_acc,
                        'training_sim': training_sim,
                        'eval_tok_acc': eval_tok_acc,
                        'eval_seq_acc': eval_seq_acc
                    }
                },
                os.path.join(
                    directory, '{}_{}_{}.tar'.format(args.seed, iteration,
                                                     'checkpoint')))
コード例 #7
0
def train():
    """Train a Set2Seq2Seq model over ``args.num_generation`` generations.

    Optionally restores the model weights and vocabulary from
    ``args.param_file``. Before training and inside every generation
    (``train_generation``) the input/hidden/message similarities are
    measured on ``args.data_file``. The records returned by each generation
    are accumulated, and every ``args.save_freq`` generations a checkpoint
    is written under ``args.save_dir``, named after the generation and the
    latest dev sequence accuracy.
    """
    print('building vocabulary...')
    voc = Voc()
    print('done')

    print('loading data and building batches...')
    train_set = FruitSeqDataset(voc, dataset_file_path=args.train_file)
    dev_set = FruitSeqDataset(voc, dataset_file_path=args.dev_file)
    learn_set = FruitSeqDataset(voc, dataset_file_path=args.train_file, batch_size=1)
    print('done')

    if args.param_file is not None:
        print('loading saved parameters from ' + args.param_file + '...')
        checkpoint = torch.load(args.param_file, map_location=args.device)
        train_args = checkpoint['args']
        voc = checkpoint['voc']
        print('done')

        print('arguments for training:')
        print(train_args)

        print('rebuilding model...')

        model = Set2Seq2Seq(voc.num_words).to(args.device)
        model.load_state_dict(checkpoint['model'])
        print('\tdone')
    else:
        print('building model...')
        model = Set2Seq2Seq(voc.num_words).to(args.device)
        print('done')

    print('preparing data for testing topological similarity...')
    sim_chk_inset, sim_chk_batchset = get_batches4sim_check(voc, args.data_file)
    print('done')

    print('initialising...')
    start_iteration = 1
    training_losses = []
    training_tok_acc = []
    training_seq_acc = []
    training_in_spkh_sim = []
    training_in_msg_sim = []
    training_in_lish_sim = []
    eval_tok_acc = []
    eval_seq_acc = []
    print('done')

    in_spk_sim, in_msg_sim, in_lis_sim = sim_check(
        model, sim_chk_inset, sim_chk_batchset
    )
    print('[SIM]Iteration: {}; In-SpkHidden Sim: {:.4f}; In-Msg Sim: {:.4f}; In-LisHidden Sim: {:.4f}'.format(
                0, in_spk_sim, in_msg_sim, in_lis_sim))

    print('training...')
    # `generation` rather than `iter`, so the builtin is not shadowed.
    for generation in range(start_iteration, args.num_generation + 1):
        training_records = train_generation(
            model, train_set, dev_set, learn_set, sim_chk_inset, sim_chk_batchset,
            generation_idx=generation
        )

        # train_generation returns per-generation record lists in this order.
        training_losses += training_records[0]
        training_tok_acc += training_records[1]
        training_seq_acc += training_records[2]
        training_in_spkh_sim += training_records[3]
        training_in_msg_sim += training_records[4]
        training_in_lish_sim += training_records[5]
        eval_tok_acc += training_records[6]
        eval_seq_acc += training_records[7]

        if generation % args.save_freq == 0:
            path_join = 'set2seq2seq_3phases_' + str(args.num_words) + '_' + args.msg_mode
            path_join += '_hard' if not args.soft else '_soft'
            directory = os.path.join(args.save_dir, path_join)
            os.makedirs(directory, exist_ok=True)
            # Tag the file with the latest dev sequence accuracy; fall back
            # to 0 when no evaluation has been recorded yet instead of
            # raising IndexError.
            last_eval_seq_acc = eval_seq_acc[-1] if eval_seq_acc else 0.
            torch.save({
                'generation': generation,
                'model': model.state_dict(),
                'voc': voc,
                'args': args,
                'records': {
                    'training_loss': training_losses,
                    'training_tok_acc': training_tok_acc,
                    'training_seq_acc': training_seq_acc,
                    'training_in_spkh_sim': training_in_spkh_sim,
                    'training_in_msg_sim': training_in_msg_sim,
                    'training_in_lish_sim': training_in_lish_sim,
                    'eval_tok_acc': eval_tok_acc,
                    'eval_seq_acc': eval_seq_acc
                }
            }, os.path.join(directory, '{}_{:.4f}_{}.tar'.format(generation, last_eval_seq_acc, 'checkpoint')))
コード例 #8
0
def train():
    """Train a Set2Seq2Seq speaker/listener model on ``args.train_file``.

    The model is either rebuilt from ``args.param_file``, with optimisers
    configured from that checkpoint's stored ``args``, or built fresh from the
    current ``args``. During training it:

    * prints average training loss/accuracies every ``args.print_freq``
      iterations;
    * evaluates on ``args.dev_file`` every ``args.eval_freq`` iterations;
    * runs the similarity and mutual-information checks every
      ``args.sim_chk_freq`` iterations;
    * resets the listener every ``args.l_reset_freq`` iterations (``-1``
      disables resets);
    * saves a checkpoint under ``args.save_dir`` every ``args.save_freq``
      iterations.
    """

    def _build_optimisers(model, opt_args):
        # Build the whole-model, speaker-decoder and listener optimisers from
        # the learning-rate settings in opt_args. Shared by the fresh and
        # resumed paths so both configure the optimisers identically.
        model_optimiser = opt_args.optimiser(model.parameters(),
                                             lr=opt_args.learning_rate)
        speaker_optimiser = opt_args.optimiser(
            model.speaker.decoder.parameters(),
            lr=opt_args.learning_rate * opt_args.speaker_ratio)
        listener_optimiser = opt_args.optimiser(
            model.listener.parameters(),
            lr=opt_args.learning_rate * opt_args.listener_ratio)
        return model_optimiser, speaker_optimiser, listener_optimiser

    print('building vocabulary...')
    voc = Voc()
    print('done')

    print('loading data and building batches...')
    train_set = FruitSeqDataset(voc, dataset_file_path=args.train_file)
    dev_set = FruitSeqDataset(voc, dataset_file_path=args.dev_file)
    print('done')

    if args.param_file is not None:
        print('loading saved parameters from ' + args.param_file + '...')
        checkpoint = torch.load(args.param_file, map_location=args.device)
        train_args = checkpoint['args']
        voc = checkpoint['voc']
        print('done')

        print('arguments for training:')
        print(train_args)

        print('rebuilding model...')

        model = Set2Seq2Seq(voc.num_words).to(args.device)
        model.load_state_dict(checkpoint['model'])
        # Optimiser hyper-parameters come from the checkpoint's own args.
        model_optimiser, speaker_optimiser, listener_optimiser = \
            _build_optimisers(model, train_args)
        print('\tdone')
    else:
        print('building model...')
        model = Set2Seq2Seq(voc.num_words).to(args.device)
        model_optimiser, speaker_optimiser, listener_optimiser = \
            _build_optimisers(model, args)
        print('done')

    print('preparing data for testing topological similarity...')
    sim_chk_inset, sim_chk_batchset = get_batches4sim_check(
        voc, args.data_file)
    print('done')

    print('initialising...')
    start_iteration = 1
    print_loss = 0.
    print_seq_acc = 0.
    print_tok_acc = 0.
    max_dev_seq_acc = 0.
    max_dev_tok_acc = 0.
    training_losses = []
    training_tok_acc = []
    training_seq_acc = []
    training_in_spkh_sim = []
    training_in_msg_sim = []
    training_in_lish_sim = []
    training_spkh_lish_sim = []
    training_mi = []
    eval_tok_acc = []
    eval_seq_acc = []
    print('done')

    in_spk_sim, in_msg_sim, in_lis_sim, spk_lis_sim = sim_check(
        model, sim_chk_inset, sim_chk_batchset)
    mi_sim = mi_check(model, sim_chk_batchset)
    print(
        '[SIM]Iteration: {}; In-SpkH Sim: {:.4f}; In-Msg Sim: {:.4f}; In-LisH Sim: {:.4f}; SpkH-LisH Sim: {:.4f}; In-Msg-MI: {:.4f}'
        .format(0, in_spk_sim, in_msg_sim, in_lis_sim, spk_lis_sim, mi_sim))

    print('training...')
    # `iteration` rather than `iter`, so the builtin is not shadowed.
    for iteration in range(start_iteration, args.iter_num + 1):
        for data_batch in train_set:
            # tau is scheduled from the mean of the last 10 dev sequence
            # accuracies once more than 10 evaluations exist, else from 0.
            if len(eval_seq_acc) > 10:
                tau = tau_scheduler(sum(eval_seq_acc[-10:]) / 10.)
            else:
                tau = tau_scheduler(0.)

            seq_acc, tok_acc, loss = train_epoch(model, data_batch, tau,
                                                 model_optimiser,
                                                 speaker_optimiser,
                                                 listener_optimiser)
            print_loss += loss
            print_seq_acc += seq_acc
            print_tok_acc += tok_acc

        if iteration % args.print_freq == 0:
            # The running sums cover print_freq full passes over train_set.
            num_batches = args.print_freq * len(train_set)
            print_loss_avg = print_loss / num_batches
            print_seq_acc_avg = print_seq_acc / num_batches
            print_tok_acc_avg = print_tok_acc / num_batches

            print(
                "Iteration: {}; Percent complete: {:.1f}%; Avg loss: {:.4f}; Avg seq acc: {:.4f}; Avg tok acc: {:.4f}"
                .format(iteration, iteration / args.iter_num * 100,
                        print_loss_avg, print_seq_acc_avg,
                        print_tok_acc_avg))
            training_seq_acc.append(print_seq_acc_avg)
            training_tok_acc.append(print_tok_acc_avg)
            training_losses.append(print_loss_avg)
            print_seq_acc = 0.
            print_tok_acc = 0.
            print_loss = 0.

        if iteration % args.eval_freq == 0:
            dev_seq_acc, dev_tok_acc, dev_loss = eval_model(model, dev_set)
            if dev_seq_acc > max_dev_seq_acc:
                max_dev_seq_acc = dev_seq_acc
            if dev_tok_acc > max_dev_tok_acc:
                max_dev_tok_acc = dev_tok_acc
            eval_seq_acc.append(dev_seq_acc)
            eval_tok_acc.append(dev_tok_acc)
            print(
                "[EVAL]Iteration: {}; Loss: {:.4f}; Avg Seq Acc: {:.4f}; Avg Tok Acc: {:.4f}; Best Seq Acc: {:.4f}"
                .format(iteration, dev_loss, dev_seq_acc, dev_tok_acc,
                        max_dev_seq_acc))

        if iteration % args.sim_chk_freq == 0:
            in_spk_sim, in_msg_sim, in_lis_sim, spk_lis_sim = sim_check(
                model, sim_chk_inset, sim_chk_batchset)
            mi_sim = mi_check(model, sim_chk_batchset)

            training_in_spkh_sim.append(in_spk_sim)
            training_in_msg_sim.append(in_msg_sim)
            training_in_lish_sim.append(in_lis_sim)
            training_spkh_lish_sim.append(spk_lis_sim)
            training_mi.append(mi_sim)

            # Report the current iteration (the log previously always said 0).
            print(
                '[SIM]Iteration: {}; In-SpkH Sim: {:.4f}; In-Msg Sim: {:.4f}; In-LisH Sim: {:.4f}; SpkH-LisH Sim: {:.4f}; In-Msg-MI: {:.4f}'
                .format(iteration, in_spk_sim, in_msg_sim, in_lis_sim,
                        spk_lis_sim, mi_sim))

        # l_reset_freq == -1 disables listener resets; check it first.
        if args.l_reset_freq != -1 and iteration % args.l_reset_freq == 0:
            model.listener.reset_params()
            print('[RESET] reset listener')

        if iteration % args.save_freq == 0:
            path_join = 'set2seq2seq_' + str(
                args.num_words) + '_' + args.msg_mode
            path_join += '_hard' if not args.soft else '_soft'
            directory = os.path.join(args.save_dir, path_join)
            os.makedirs(directory, exist_ok=True)
            # Name the file after the most recent dev sequence accuracy; 0 if
            # no evaluation has run yet (save_freq need not align with
            # eval_freq).
            last_dev_seq_acc = eval_seq_acc[-1] if eval_seq_acc else 0.
            torch.save(
                {
                    'iteration': iteration,
                    'model': model.state_dict(),
                    'opt': [
                        model_optimiser.state_dict(),
                        speaker_optimiser.state_dict(),
                        listener_optimiser.state_dict()
                    ],
                    'loss': loss,
                    'voc': voc,
                    'args': args,
                    'records': {
                        'training_loss': training_losses,
                        'training_tok_acc': training_tok_acc,
                        'training_seq_acc': training_seq_acc,
                        'training_in_spkh_sim': training_in_spkh_sim,
                        'training_in_msg_sim': training_in_msg_sim,
                        'training_in_lish_sim': training_in_lish_sim,
                        'training_spkh_lish_sim': training_spkh_lish_sim,
                        'training_mi': training_mi,
                        'eval_tok_acc': eval_tok_acc,
                        'eval_seq_acc': eval_seq_acc
                    }
                },
                os.path.join(
                    directory,
                    '{}_{:.4f}_{}.tar'.format(iteration, last_dev_seq_acc,
                                              'checkpoint')))
コード例 #9
0
def main(
    model_name='Img2Seq2Choice',
    dataset_name='ImgChooseDataset',
    out_file_path='data/tmp.txt',
):
    """Rebuild a saved model on CPU and write a listener training file.

    Args:
        model_name: one of 'Img2Seq2Choice', 'Set2Seq2Seq', 'Set2Seq2Choice'.
        dataset_name: one of 'ImgChooseDataset', 'FruitSeqDataset',
            'ChooseDataset'; the dataset is read from ``args.data_file`` with
            batch size 1.
        out_file_path: output path passed to ``build_listener_training_file``.

    Raises:
        ValueError: if ``args.param_file`` is not set.
        NotImplementedError: for an unknown ``model_name`` or ``dataset_name``.
    """
    if args.param_file is None:
        raise ValueError('args.param_file must point to a saved checkpoint')
    cpu = torch.device('cpu')
    checkpoint = torch.load(args.param_file, map_location=cpu)

    print('rebuilding vocabulary and model...')
    # The vocabulary is only read from the checkpoint for the Set2Seq* models.
    voc = checkpoint['voc'] if model_name in ('Set2Seq2Seq',
                                              'Set2Seq2Choice') else None
    train_args = checkpoint['args']
    print(train_args)

    if model_name not in ('Img2Seq2Choice', 'Set2Seq2Seq', 'Set2Seq2Choice'):
        raise NotImplementedError('unknown model: {}'.format(model_name))

    # Architecture hyper-parameters shared by every supported model.
    model_kwargs = dict(msg_length=train_args.max_msg_len,
                        msg_vocsize=train_args.msg_vocsize,
                        hidden_size=train_args.hidden_size,
                        dropout=train_args.dropout_ratio,
                        msg_mode=train_args.msg_mode)
    if model_name == 'Img2Seq2Choice':
        model = Img2Seq2Choice(**model_kwargs)
    elif model_name == 'Set2Seq2Seq':
        model = Set2Seq2Seq(voc.num_words, **model_kwargs)
    else:
        model = Set2Seq2Choice(voc.num_words, **model_kwargs)
    model = model.to(cpu)

    model.load_state_dict(checkpoint['model'])
    model.eval()
    print('done')

    print('loading and building batch dataset...')
    if dataset_name == 'ImgChooseDataset':
        batch_set = ImgChooseDataset(dataset_dir_path=args.data_file,
                                     batch_size=1,
                                     device=cpu)
        in_set = [batch['correct']['label'][0] for batch in batch_set]
    elif dataset_name == 'FruitSeqDataset':
        batch_set = FruitSeqDataset(voc,
                                    dataset_file_path=args.data_file,
                                    batch_size=1,
                                    device=cpu)
        in_set = FruitSeqDataset.load_stringset(args.data_file)
    elif dataset_name == 'ChooseDataset':
        batch_set = ChooseDataset(voc,
                                  dataset_file_path=args.data_file,
                                  batch_size=1,
                                  device=cpu)
        in_set = FruitSeqDataset.load_stringset(args.data_file)
    else:
        # Previously fell through and failed later with UnboundLocalError.
        raise NotImplementedError('unknown dataset: {}'.format(dataset_name))
    print('done')

    build_listener_training_file(model, in_set, batch_set, out_file_path)