Example #1

# Assumed preamble for this listing: the snippet starts inside the
# __main__ block of a data-inspection tool, so the imports and command
# line options below are reconstructed from the fields used in the code
# (the libbots package name follows the companion code and may differ).
import argparse
import collections

from libbots import cornell, data

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("-g", "--genre", default='',
                        help="Genre to filter dialogues by. Empty string disables the filter")
    parser.add_argument("--show-dials", action="store_true", help="Dump raw dialogues")
    parser.add_argument("--show-train", action="store_true", help="Dump training pairs")
    parser.add_argument("--show-dict-freq", action="store_true", help="Dump token frequency stats")
    args = parser.parse_args()
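    # Dump raw dialogues (optionally filtered by genre), one block per
    # dialogue, with phrases shown as whitespace-joined token lists.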
    if args.show_dials:
        dials = cornell.load_dialogues(genre_filter=args.genre)
        for d_idx, dial in enumerate(dials):
            print("Dialog %d with %d phrases:" % (d_idx, len(dial)))
            for p in dial:
                print(" ".join(p))
            print()

    if args.show_train or args.show_dict_freq:
        phrase_pairs, emb_dict = data.load_data(genre_filter=args.genre)

    if args.show_train:
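        # Invert the word->id mapping so encoded token ids can be
        # decoded back into words for display.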
        rev_emb_dict = {idx: word for word, idx in emb_dict.items()}
        train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
        train_data = data.group_train_data(train_data)
        unk_token = emb_dict[data.UNKNOWN_TOKEN]

        print("Training pairs (%d total)" % len(train_data))
        train_data.sort(key=lambda p: len(p[1]), reverse=True)
        for idx, (p1, p2_group) in enumerate(train_data):
            w1 = data.decode_words(p1, rev_emb_dict)
            w2_group = [data.decode_words(p2, rev_emb_dict) for p2 in p2_group]
            print("%d:" % idx, " ".join(w1))
            for w2 in w2_group:
                print("%s:" % (" " * len(str(idx))), " ".join(w2))

    if args.show_dict_freq:
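        # Count how often every token occurs in the first phrase of
        # each training pair.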
        words_stat = collections.Counter()
        for p1, p2 in phrase_pairs:
            words_stat.update(p1)
        # Reconstructed tail of the snippet: report the accumulated
        # counts, most frequent tokens first.
        print("Frequency stats for %d tokens in the dict" % len(emb_dict))
        for token, count in words_stat.most_common():
            print("%s: %d" % (token, count))
Example #2

# Assumed preamble for this listing (a BLEU evaluation tool for a
# trained model); the imports are reconstructed from the names used
# below.
import argparse
import logging

from libbots import data, model, utils

import torch

log = logging.getLogger("data_test")

if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True,
                        help="Category to use for training. Empty string to train on full dataset")
    parser.add_argument("-m", "--model", required=True, help="Model name to load")
    args = parser.parse_args()

    phrase_pairs, emb_dict = data.load_data(args.data)
    log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
    train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
    train_data = data.group_train_data(train_data)
    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict), hid_size=model.HIDDEN_STATE_SIZE)
    net.load_state_dict(torch.load(args.model))

    end_token = emb_dict[data.END_TOKEN]

    seq_count = 0
    sum_bleu = 0.0

    for seq_1, targets in train_data:
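        # Pack the input phrase, encode it, then greedily decode the
        # model's reply up to the END token or the length limit.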
        input_seq = model.pack_input(seq_1, net.emb)
        enc = net.encode(input_seq)
        _, tokens = net.decode_chain_argmax(enc, input_seq.data[0:1],
                                            seq_len=data.MAX_TOKENS, stop_at_token=end_token)
        # Reconstructed tail of the loop: score the decoded tokens
        # against every reference reply (references drop the leading
        # BEGIN token; the calc_bleu_many helper is assumed from the
        # companion utils module).
        references = [seq[1:] for seq in targets]
        bleu = utils.calc_bleu_many(tokens, references)
        sum_bleu += bleu
        seq_count += 1

    log.info("Processed %d phrases, mean BLEU = %.4f", seq_count, sum_bleu / seq_count)
Example #3

# Assumed preamble for this listing (RL fine-tuning setup): the imports,
# saves directory constant and argument parser are reconstructed from
# the fields used below and may differ from the original script.
import os
import argparse
import logging

import numpy as np
import torch
from tensorboardX import SummaryWriter

from libbots import data, model

log = logging.getLogger("train")

SAVES_DIR = "saves"

if __name__ == "__main__":
    logging.basicConfig(format="%(asctime)-15s %(levelname)s %(message)s", level=logging.INFO)
    parser = argparse.ArgumentParser()
    parser.add_argument("--data", required=True,
                        help="Genre to use for training. Empty string to train on full dataset")
    parser.add_argument("--cuda", action="store_true", default=False, help="Enable CUDA")
    parser.add_argument("-n", "--name", required=True, help="Name of the run, used in the saves dir name")
    parser.add_argument("-l", "--load", required=True,
                        help="Pre-trained model to load and continue training in RL mode")
    args = parser.parse_args()
    device = torch.device("cuda" if args.cuda else "cpu")

    saves_path = os.path.join(SAVES_DIR, args.name)
    os.makedirs(saves_path, exist_ok=True)

    phrase_pairs, emb_dict = data.load_data(genre_filter=args.data)
    log.info("Obtained %d phrase pairs with %d uniq words", len(phrase_pairs), len(emb_dict))
    data.save_emb_dict(saves_path, emb_dict)
    end_token = emb_dict[data.END_TOKEN]
    train_data = data.encode_phrase_pairs(phrase_pairs, emb_dict)
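    # Deterministic shuffle keeps the train/test split reproducible
    # between runs.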
    rand = np.random.RandomState(data.SHUFFLE_SEED)
    rand.shuffle(train_data)
    train_data, test_data = data.split_train_test(train_data)
    log.info("Training data converted, got %d samples", len(train_data))
    train_data = data.group_train_data(train_data)
    test_data = data.group_train_data(test_data)
    log.info("Train set has %d phrases, test %d", len(train_data), len(test_data))

    rev_emb_dict = {idx: word for word, idx in emb_dict.items()}

    net = model.PhraseModel(emb_size=model.EMBEDDING_DIM, dict_size=len(emb_dict),
                            hid_size=model.HIDDEN_STATE_SIZE).to(device)
    log.info("Model: %s", net)

    writer = SummaryWriter(comment="-" + args.name)
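    # Warm-start from the cross-entropy pre-trained model before
    # switching to RL fine-tuning.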
    net.load_state_dict(torch.load(args.load))
    log.info("Model loaded from %s, continue training in RL mode...", args.load)

    # BEGIN token id wrapped into a tensor, used to seed the decoder
    beg_token = torch.LongTensor([emb_dict[data.BEGIN_TOKEN]]).to(device)