Code example #1
import os

import numpy as np
import tensorflow as tf

from cotk.dataloader import LanguageGeneration
from cotk.wordvector import WordVector, Glove

# debug, try_cache and create_model are project-local helpers assumed to be
# in scope; their exact module paths vary between the example repositories.


def main(args):
    if args.debug:
        debug()

    if args.cuda:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    else:
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = LanguageGeneration.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove
    if args.cache:
        data = try_cache(data_class, (args.datapath, ), args.cache_dir)
        vocab = data.vocab_list
        embed = try_cache(lambda wv, ez, vl: wordvec_class(wv).load(ez, vl),
                          (args.wvpath, args.embedding_size, vocab),
                          args.cache_dir, wordvec_class.__name__)
    else:
        data = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        vocab = data.vocab_list
        embed = wv.load(args.embedding_size, vocab)

    embed = np.array(embed, dtype=np.float32)

    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            model.train_process(sess, data, args)
        else:
            model.test_process(sess, data, args)
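
These main() functions read their configuration from an argparse-style namespace. A minimal driver sketch, assuming hypothetical defaults: the option names are inferred from the attribute accesses above, and the dataset and path values are placeholders.

import argparse

# Hypothetical driver for the main() above; option names mirror the
# attribute accesses in the function, and defaults are placeholders.
parser = argparse.ArgumentParser()
parser.add_argument('--debug', action='store_true')
parser.add_argument('--cuda', action='store_true')
parser.add_argument('--dataset', default='MSCOCO')
parser.add_argument('--datapath', default='./data')
parser.add_argument('--wvclass', default=None)
parser.add_argument('--wvpath', default='./wordvec')
parser.add_argument('--embedding_size', type=int, default=300)
parser.add_argument('--cache', action='store_true')
parser.add_argument('--cache_dir', default='./cache')
parser.add_argument('--mode', choices=['train', 'test'], default='train')

if __name__ == '__main__':
    main(parser.parse_args())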
Code example #2
File: main.py  Project: colinsongf/cotk_docs
def main(args):
    if args.debug:
        debug()
    if args.cuda:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    else:
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = LanguageGeneration.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove
    if args.cache:
        data = try_cache(data_class, (args.datapath,), args.cache_dir)
        vocab = data.vocab_list
        embed = try_cache(lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
                          (args.wvpath, args.embedding_size, vocab),
                          args.cache_dir, wordvec_class.__name__)
    else:
        data = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        vocab = data.vocab_list
        embed = wv.load_matrix(args.embedding_size, vocab)

    embed = np.array(embed, dtype=np.float32)
    with tf.Session(config=config) as sess:
        generator, discriminator, rollout_gen = create_model(sess, data, args, embed)
        if args.mode == "train":
            if args.pre_train:
                # Start pre-training
                print('Start pre-training generator...')
                generator.pre_train_process(sess, data)

                print('Start pre-training discriminator...')
                discriminator.train_process(generator, data, sess, args.dis_pre_epoch_num)

                # Test the pre-trained generator
                generator.test_process(sess, data)

            # Start adversarial training
            for batch in range(args.total_adv_batch):
                print("Adversarial  training %d"%batch)
                print('Start adversarial training generator...')
                generator.adv_train_process(sess, data, rollout_gen, discriminator)
                testout = generator.pre_evaluate(sess, data, args.batch_size, "test")
                if (batch % args.test_per_epoch == 0 or batch == args.total_adv_batch - 1) and batch != 0:
                    print('total_batch: ', batch, 'test_loss: ', testout[0])
                    generator.test_process(sess, data)

                print('Start adversarial training discriminator...')
                discriminator.train_process(generator, data, sess, args.dis_adv_epoch_num)
        else:
            print("Start testing...")
            test_res = generator.test_process(sess, data)
            for key, val in test_res.items():
                if isinstance(val, bytes):
                    test_res[key] = str(val)
            with open("./result.json", "w") as resfile:
                json.dump(test_res, resfile)
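
Example #2 is a SeqGAN-style pipeline: the generator is pre-trained with maximum likelihood, the discriminator is pre-trained on generator samples, and the two then alternate during adversarial training, with the rollout generator used for reward estimation. It reads a few extra options beyond the shared flags sketched after example #1; the names below are inferred from the attribute accesses, and the defaults are placeholders.

import argparse

# GAN-specific flags read by the main() above (hypothetical defaults).
parser = argparse.ArgumentParser()
parser.add_argument('--pre_train', action='store_true')
parser.add_argument('--dis_pre_epoch_num', type=int, default=5)
parser.add_argument('--dis_adv_epoch_num', type=int, default=3)
parser.add_argument('--total_adv_batch', type=int, default=100)
parser.add_argument('--test_per_epoch', type=int, default=10)
parser.add_argument('--batch_size', type=int, default=32)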
Code example #3
def main(args):
    if args.debug:
        debug()

    if args.cuda:
        config = tf.ConfigProto()
        config.gpu_options.allow_growth = True
    else:
        config = tf.ConfigProto(device_count={'GPU': 0})
        os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

    data_class = LanguageGeneration.load_class(args.dataset)
    wordvec_class = WordVector.load_class(args.wvclass)
    if wordvec_class is None:
        wordvec_class = Glove
    if args.cache:
        data = try_cache(data_class, (args.datapath, ), args.cache_dir)
        vocab = data.frequent_vocab_list
        embed = try_cache(
            lambda wv, ez, vl: wordvec_class(wv).load_matrix(ez, vl),
            (args.wvpath, args.embedding_size, vocab), args.cache_dir,
            wordvec_class.__name__)
    else:
        data = data_class(args.datapath)
        wv = wordvec_class(args.wvpath)
        vocab = data.frequent_vocab_list
        embed = wv.load_matrix(args.embedding_size, vocab)

    embed = np.array(embed, dtype=np.float32)

    with tf.Session(config=config) as sess:
        model = create_model(sess, data, args, embed)
        if args.mode == "train":
            model.train_process(sess, data, args)
        else:
            test_res = model.test_process(sess, data, args)

            for key, val in test_res.items():
                if isinstance(val, bytes):
                    test_res[key] = str(val)
            with open("./result.json", "w") as resfile:
                json.dump(test_res, resfile)
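
In test mode, examples #2 and #3 coerce any bytes values to str and dump the metric dictionary to ./result.json. A quick read-back, for reference:

import json

# Inspect the metrics written by the test branch above.
with open("./result.json") as resfile:
    print(json.load(resfile))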
Code example #4
    def base_test_init(self, dl):
        # Constructing with these pretrained settings must raise ValueError.
        with pytest.raises(ValueError):
            LanguageGeneration("./tests/dataloader/dummy_mscoco#MSCOCO",
                               pretrained='none')
        with pytest.raises(ValueError):
            LanguageGeneration("./tests/dataloader/dummy_mscoco#MSCOCO",
                               pretrained='gpt2')
        with pytest.raises(ValueError):
            LanguageGeneration("./tests/dataloader/dummy_mscoco#MSCOCO",
                               pretrained='bert')

        assert isinstance(dl, LanguageGeneration)
        assert isinstance(dl.file_id, str)
        assert isinstance(dl.file_path, str)
        for set_name, fields in dl.fields.items():
            assert isinstance(set_name, str)
            assert isinstance(fields, dict)
            for field_name, field in fields.items():
                assert isinstance(field_name, str)
                assert isinstance(field, Field)

        assert isinstance(dl.vocabs, list)
        for vocab in dl.vocabs:
            assert isinstance(vocab, Vocab)
        assert isinstance(dl.tokenizers, list)
        for toker in dl.tokenizers:
            assert isinstance(toker, Tokenizer)

        # Per-field contents must stay aligned with the index list.
        for (_, data), (_, index) in zip(dl.data.items(), dl.index.items()):
            assert isinstance(data, dict)
            assert isinstance(index, list)
            for field_name, content in data.items():
                assert isinstance(content, dict)
                for _, each_content in content.items():
                    assert isinstance(each_content, list)
                    assert len(index) == len(each_content)
        for _, batch_id in dl.batch_id.items():
            assert batch_id == 0
        for _, batch_size in dl.batch_size.items():
            assert batch_size is None

        # Vocabulary lists and their reported sizes must agree.
        assert isinstance(dl.frequent_vocab_list, list)
        assert dl.frequent_vocab_size == len(dl.frequent_vocab_list)
        assert isinstance(dl.all_vocab_list, list)
        assert dl.all_vocab_size == len(dl.all_vocab_list)
        assert dl.all_vocab_size >= dl.frequent_vocab_size

        # Every sentence must be framed by the expected special tokens.
        for _, data in dl.data.items():
            sent = data['sent']
            ids = sent['id']
            assert isinstance(ids, list)
            assert isinstance(ids[0], list)
            if dl._pretrained is None or dl._pretrained == "gpt2":
                assert ids[0][0] == dl.go_id
                assert ids[0][-1] == dl.eos_id
            else:
                assert ids[0][0] == dl.get_special_tokens_id("cls")
                assert ids[0][-1] == dl.get_special_tokens_id("sep")
            strs = sent['str']
            assert isinstance(strs, list)
            assert isinstance(strs[0], str)

        with pytest.raises(TypeError):
            LanguageGeneration()
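
base_test_init is written as a method, so the enclosing test class (elided here) supplies it with a ready-made dataloader. A hypothetical driver: MSCOCO is a concrete LanguageGeneration subclass in cotk.dataloader, the dummy-data path is the one referenced in the assertions above, and TestSuite is a stand-in name for whatever class actually defines the method.

from cotk.dataloader import MSCOCO

def test_init():
    # Build a dataloader over the dummy data and run the shared checks.
    # TestSuite is a hypothetical stand-in for the real test class.
    dl = MSCOCO("./tests/dataloader/dummy_mscoco#MSCOCO")
    TestSuite().base_test_init(dl)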