def base_test_init(self, dl: MultiTurnDialog):
    # constructing the loader with these `pretrained` values should raise ValueError
    with pytest.raises(ValueError):
        MultiTurnDialog("./tests/dataloader/dummy_switchboardcorpus#SwitchboardCorpus", pretrained='none')
    with pytest.raises(ValueError):
        MultiTurnDialog("./tests/dataloader/dummy_switchboardcorpus#SwitchboardCorpus", pretrained='gpt2')
    with pytest.raises(ValueError):
        MultiTurnDialog("./tests/dataloader/dummy_switchboardcorpus#SwitchboardCorpus", pretrained='bert')

    assert isinstance(dl, MultiTurnDialog)
    super().base_test_init(dl)

    assert dl.default_field_name is not None and dl.default_field_set_name is not None
    default_field = dl.get_default_field()
    assert isinstance(default_field, Session)

    # vocabulary sizes must be consistent with the vocabulary lists
    assert isinstance(dl.frequent_vocab_list, list)
    assert dl.frequent_vocab_size == len(dl.frequent_vocab_list)
    assert isinstance(dl.all_vocab_list, list)
    assert dl.all_vocab_size == len(dl.all_vocab_list)
    assert dl.all_vocab_size > 4
    assert dl.all_vocab_size >= dl.frequent_vocab_size
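# A minimal sketch (not part of the original tests) of how a concrete test class
# might drive the shared check above. The base-class name `TestMultiTurnDialogBase`
# and the pytest fixture `load_switchboardcorpus` are hypothetical stand-ins.
class TestSwitchboardCorpus(TestMultiTurnDialogBase):
    def test_init(self, load_switchboardcorpus):
        # the fixture is assumed to yield a constructed SwitchboardCorpus dataloader
        self.base_test_init(load_switchboardcorpus)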
def main(args):
    if args.debug:
        debug()

    if args.cuda:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    else:
        # run on CPU only
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = MultiTurnDialog.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove

    if args.cache:
        # load the dataloader and word vectors from cache if available
        data = try_cache(data_class, (args.datapath,), args.cache_dir)
        vocab = data.frequent_vocab_list
        embed = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
            (args.wvpath, args.word_embedding_size, vocab),
            args.cache_dir, wordvec_class.__name__)
        word2vec = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load_dict(vl),
            (args.wvpath, args.word_embedding_size, vocab),
            args.cache_dir, wordvec_class.__name__)
    else:
        data = data_class(
            args.datapath,
            min_frequent_vocab_times=args.min_frequent_vocab_times,
            max_sent_length=args.max_sent_length,
            max_turn_length=args.max_turn_length)
        wv = wordvec_class(args.wvpath)
        vocab = data.frequent_vocab_list  # dim: 9508
        embed = wv.load_matrix(args.word_embedding_size, vocab)
        word2vec = wv.load_dict(vocab)

    embed = np.array(embed, dtype=np.float32)

    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            model.train_process(sess, data, args)
        else:
            multi_ref_res = model.test_multi_ref(sess, data, word2vec, args)
            test_res = model.test_process(sess, data, args)
            test_res.update(multi_ref_res)
            # json cannot serialize bytes; convert them to str before dumping
            for key, val in test_res.items():
                if isinstance(val, bytes):
                    test_res[key] = str(val)
            json.dump(test_res, open("./result.json", "w"))
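# A minimal sketch of the command-line arguments the entry point above reads;
# the flag names mirror the `args.*` attributes used there, but the defaults
# and choices are illustrative assumptions, not the project's actual defaults.
import argparse

def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', default='train', choices=['train', 'test'])
    parser.add_argument('--debug', action='store_true')
    parser.add_argument('--cuda', action='store_true')
    parser.add_argument('--cache', action='store_true')
    parser.add_argument('--cache_dir', default='./cache')
    parser.add_argument('--dataset', default='SwitchboardCorpus')
    parser.add_argument('--datapath', default='./data')
    parser.add_argument('--wvclass', default='Glove')
    parser.add_argument('--wvpath', default='./wordvec')
    parser.add_argument('--word_embedding_size', type=int, default=300)
    parser.add_argument('--min_frequent_vocab_times', type=int, default=10)
    parser.add_argument('--max_sent_length', type=int, default=50)
    parser.add_argument('--max_turn_length', type=int, default=8)
    return parser.parse_args()

if __name__ == '__main__':
    main(parse_args())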
def main(args):
    if args.debug:
        debug()

    if args.cuda:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    else:
        # run on CPU only
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = MultiTurnDialog.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove

    if args.cache:
        # load the dataloader and word vectors from cache if available
        data = try_cache(data_class, (args.datapath,), args.cache_dir)
        vocab = data.vocab_list
        embed = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load(ez, vl),
            (args.wvpath, args.word_embedding_size, vocab),
            args.cache_dir, wordvec_class.__name__)
    else:
        data = data_class(
            args.datapath,
            min_vocab_times=args.min_vocab_times,
            max_sen_length=args.max_sen_length,
            max_turn_length=args.max_turn_length)
        wv = wordvec_class(args.wvpath)
        vocab = data.vocab_list
        embed = wv.load(args.word_embedding_size, vocab)

    embed = np.array(embed, dtype=np.float32)

    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            model.train_process(sess, data, args)
        else:
            model.test_multi_ref(sess, data, embed, args)
            model.test_process(sess, data, args)
def base_test_multi_turn_convert(self, dl: MultiTurnDialog):
    # this test assumes the special tokens are indexed <pad>=0, <unk>=1, <go>=2, <eos>=3
    sent_id = [[0, 1, 2], [2, 1, 1]]
    sent = [["<pad>", "<unk>", "<go>"], ["<go>", "<unk>", "<unk>"]]
    assert sent == dl.convert_multi_turn_ids_to_tokens(sent_id, remove_special=False)
    assert sent_id == dl.convert_multi_turn_tokens_to_ids(sent)

    # tokens outside the vocabulary (e.g. "<unkownword>") are mapped to the id of <unk>
    sent = [["<unk>", "<go>", "<pad>", "<unkownword>", "<pad>", "<go>"], ["<go>", "<eos>"]]
    sent_id = [[1, 2, 0, 1, 0, 2], [2, 3]]
    assert sent_id == dl.convert_multi_turn_tokens_to_ids(sent)

    sent_id = [[0, 1, 2, 2, 0, 3, 1, 0, 0], [0, 3, 2], [1, 2, 2, 0], [1, 2, 2, 3]]
    sent = [["<pad>", "<unk>", "<go>", "<go>", "<pad>", "<eos>", "<unk>", "<pad>", "<pad>"],
            ["<pad>", "<eos>", "<go>"],
            ["<unk>", "<go>", "<go>", "<pad>"],
            ["<unk>", "<go>", "<go>", "<eos>"]]
    assert sent == dl.convert_multi_turn_ids_to_tokens(sent_id, remove_special=False, trim=False)

    # conversion with default arguments
    sent = [["<pad>", "<unk>", "<go>", "<go>"]]
    sent_id = [[0, 1, 2, 2]]
    assert sent == dl.convert_multi_turn_ids_to_tokens(sent_id)

    sent = [[dl.all_vocab_list[dl.unk_id]]]
    assert [[dl.unk_id]] == dl.convert_multi_turn_tokens_to_ids(sent)
    assert [[dl.unk_id]] == dl.convert_multi_turn_tokens_to_ids(sent, only_frequent_word=True)