コード例 #1
0
def main():
    """Rebuild a trained Set2Seq2Seq model and iterate it over the data set.

    Builds a fresh vocabulary and a batch-size-1 ``FruitSeqDataset`` from
    ``DATA_FILE``, restores the model weights, training args and vocabulary
    from the checkpoint at ``args.param_file``, switches the model to eval
    mode and hands everything to ``iterate_dataset`` together with a handle
    to ``OUT_FILE`` opened in append mode.
    """
    print('building vocabulary...')
    voc = Voc()
    print('done')

    print('loading data and building batches...')
    data_set = FruitSeqDataset(voc, dataset_file_path=DATA_FILE, batch_size=1)
    str_set = data_set.load_stringset(DATA_FILE)
    print('done')

    print('rebuilding model from saved parameters in ' + args.param_file +
          '...')
    # NOTE(review): the model is sized with the freshly built vocabulary, but
    # the vocabulary used afterwards is the one stored in the checkpoint; this
    # assumes both have the same ``num_words`` — confirm.
    model = Set2Seq2Seq(voc.num_words).to(args.device)
    checkpoint = torch.load(args.param_file, map_location=args.device)
    train_args = checkpoint['args']
    model.load_state_dict(checkpoint['model'])
    voc = checkpoint['voc']
    print('done')

    model.eval()

    print('iterating data set...')
    # Context manager guarantees the output file is flushed and closed, even
    # if iterate_dataset raises part-way through.
    with open(OUT_FILE, mode='a') as out_file:
        iterate_dataset(model, voc, str_set, data_set, out_file, train_args)
コード例 #2
0
def get_batches4sim_check(voc, dataset_file_path=None):
    """Load a data file both as raw strings and as a batched ChooseDataset.

    Args:
        voc: vocabulary passed through to ``ChooseDataset``.
        dataset_file_path: path of the data file. When omitted (``None``),
            ``args.data_file`` is used; it is read at call time rather than
            once at import time, so later changes to ``args`` are honoured.

    Returns:
        ``(in_set, batch_set)``: the result of
        ``FruitSeqDataset.load_stringset`` on the file, and a
        ``ChooseDataset`` with batch size 1 built from the same file.
    """
    if dataset_file_path is None:
        dataset_file_path = args.data_file
    in_set = FruitSeqDataset.load_stringset(dataset_file_path)
    batch_set = ChooseDataset(voc,
                              batch_size=1,
                              dataset_file_path=dataset_file_path)
    return in_set, batch_set
コード例 #3
0
def main(
    model_name='Img2Seq2Choice',
    dataset_name='ImgChooseDataset',
    out_file_path='data/tmp.txt',
):
    """Rebuild a trained model on CPU and build a listener training file.

    Loads the checkpoint at ``args.param_file`` onto the CPU and rebuilds the
    model selected by ``model_name`` from the training arguments stored in
    that checkpoint. It then builds the batch-size-1 dataset selected by
    ``dataset_name`` from ``args.data_file``. Finally it passes the model,
    the input strings and the batches to ``build_listener_training_file``
    together with ``out_file_path``.

    Args:
        model_name: one of ``'Img2Seq2Choice'``, ``'Set2Seq2Seq'`` or
            ``'Set2Seq2Choice'``.
        dataset_name: one of ``'ImgChooseDataset'``, ``'FruitSeqDataset'``
            or ``'ChooseDataset'``.
        out_file_path: destination path handed to
            ``build_listener_training_file``.

    Raises:
        ValueError: if ``args.param_file`` is not set.
        NotImplementedError: if ``model_name`` or ``dataset_name`` is not
            one of the supported values.
    """
    if args.param_file is None:
        raise ValueError('args.param_file must point to a saved checkpoint')

    device = torch.device('cpu')
    checkpoint = torch.load(args.param_file, map_location=device)

    print('rebuilding vocabulary and model...')
    # The vocabulary is taken from the checkpoint only for the Set2Seq*
    # models; for the image model it stays None.
    voc = (checkpoint['voc']
           if model_name in ('Set2Seq2Seq', 'Set2Seq2Choice') else None)
    train_args = checkpoint['args']
    print(train_args)

    if model_name not in ('Img2Seq2Choice', 'Set2Seq2Seq', 'Set2Seq2Choice'):
        raise NotImplementedError(
            'unsupported model_name: {!r}'.format(model_name))

    # Hyper-parameters shared by every supported model, read from the
    # training arguments saved with the checkpoint.
    model_kwargs = dict(msg_length=train_args.max_msg_len,
                        msg_vocsize=train_args.msg_vocsize,
                        hidden_size=train_args.hidden_size,
                        dropout=train_args.dropout_ratio,
                        msg_mode=train_args.msg_mode)
    if model_name == 'Img2Seq2Choice':
        model = Img2Seq2Choice(**model_kwargs)
    elif model_name == 'Set2Seq2Seq':
        model = Set2Seq2Seq(voc.num_words, **model_kwargs)
    else:  # 'Set2Seq2Choice'
        model = Set2Seq2Choice(voc.num_words, **model_kwargs)
    model = model.to(device)

    model.load_state_dict(checkpoint['model'])
    model.eval()
    print('done')

    print('loading and building batch dataset...')
    if dataset_name == 'ImgChooseDataset':
        batch_set = ImgChooseDataset(dataset_dir_path=args.data_file,
                                     batch_size=1,
                                     device=device)
        # With batch size 1, take the single label of each batch's
        # 'correct' entry as its input string.
        in_set = [batch['correct']['label'][0] for batch in batch_set]
    elif dataset_name == 'FruitSeqDataset':
        batch_set = FruitSeqDataset(voc,
                                    dataset_file_path=args.data_file,
                                    batch_size=1,
                                    device=device)
        in_set = FruitSeqDataset.load_stringset(args.data_file)
    elif dataset_name == 'ChooseDataset':
        batch_set = ChooseDataset(voc,
                                  dataset_file_path=args.data_file,
                                  batch_size=1,
                                  device=device)
        in_set = FruitSeqDataset.load_stringset(args.data_file)
    else:
        # Without this branch an unknown name would leave batch_set/in_set
        # unbound and fail later with an UnboundLocalError.
        raise NotImplementedError(
            'unsupported dataset_name: {!r}'.format(dataset_name))
    print('done')

    build_listener_training_file(model, in_set, batch_set, out_file_path)