示例#1
0
def main():
    """Parse the command line, build the models for one dataset and train them.

    Recognised ``--dataset`` values:

    * ``'oracle'`` -- synthetic data: an ``OracleLstm`` plus two
      ``OracleDataLoader`` instances are handed to ``oracle_train``.
    * ``'image_coco'`` / ``'emnlp_news'`` -- real text read from
      ``<data_dir>/<dataset>.txt``; the sequence length and vocabulary size
      returned by ``text_precess`` override the command-line values, and an
      extra classifier (scope ``"f_classifier"``) is passed to ``real_train``.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    args = parser.parse_args()
    # vars() returns the namespace's own __dict__ (assuming ``parser`` is an
    # argparse parser), so every write to ``config`` below is also visible as
    # an attribute of ``args``.
    config = vars(args)
    pp.pprint(config)

    dataset = args.dataset
    if dataset == 'oracle':
        vocab_size, seq_len = args.vocab_size, args.seq_len
        oracle_model = OracleLstm(num_vocabulary=vocab_size,
                                  batch_size=args.batch_size,
                                  emb_dim=args.gen_emb_dim,
                                  hidden_dim=args.hidden_dim,
                                  sequence_length=seq_len,
                                  start_token=args.start_token)
        oracle_loader = OracleDataLoader(args.batch_size, seq_len)
        gen_loader = OracleDataLoader(args.batch_size, seq_len)
    elif dataset in ('image_coco', 'emnlp_news'):
        data_file = os.path.join(args.data_dir, '{}.txt'.format(dataset))
        seq_len, vocab_size = text_precess(data_file)
        # Corpus-derived values replace the CLI ones; because of the aliasing
        # noted above, args.seq_len now holds the corpus value as well.
        config['seq_len'] = seq_len
        config['vocab_size'] = vocab_size
        print('seq_len: %d, vocab_size: %d' % (seq_len, vocab_size))
        oracle_loader = RealDataLoader(args.batch_size, args.seq_len)
    else:
        raise NotImplementedError('{}: unknown dataset!'.format(dataset))

    # Both dataset families build the generator/discriminator pair the same
    # way; only the vocabulary size and sequence length differ.
    generator = models.get_generator(args.g_architecture,
                                     vocab_size=vocab_size,
                                     batch_size=args.batch_size,
                                     seq_len=seq_len,
                                     gen_emb_dim=args.gen_emb_dim,
                                     mem_slots=args.mem_slots,
                                     head_size=args.head_size,
                                     num_heads=args.num_heads,
                                     hidden_dim=args.hidden_dim,
                                     start_token=args.start_token)
    discriminator = models.get_discriminator(args.d_architecture,
                                             batch_size=args.batch_size,
                                             seq_len=seq_len,
                                             vocab_size=vocab_size,
                                             dis_emb_dim=args.dis_emb_dim,
                                             num_rep=args.num_rep,
                                             sn=args.sn)

    if dataset == 'oracle':
        oracle_train(generator, discriminator, oracle_model, oracle_loader,
                     gen_loader, config)
        return

    # Real-text runs additionally build a classifier under the
    # "f_classifier" scope and pass it to real_train.
    f_classifier = models.get_classifier(args.f_architecture,
                                         scope="f_classifier",
                                         batch_size=args.batch_size,
                                         seq_len=seq_len,
                                         vocab_size=vocab_size,
                                         dis_emb_dim=args.f_emb_dim,
                                         num_rep=args.num_rep,
                                         sn=args.sn)
    real_train(generator, discriminator, f_classifier, oracle_loader, config)
示例#2
0
def _build_text_generator(args, vocab_size, seq_len, **extra_kwargs):
    """Call ``models.get_generator`` with the hyper-parameters shared by runs.

    ``extra_kwargs`` are forwarded unchanged (used for the topic flags).
    """
    return models.get_generator(args.g_architecture,
                                vocab_size=vocab_size,
                                batch_size=args.batch_size,
                                seq_len=seq_len,
                                gen_emb_dim=args.gen_emb_dim,
                                mem_slots=args.mem_slots,
                                head_size=args.head_size,
                                num_heads=args.num_heads,
                                hidden_dim=args.hidden_dim,
                                start_token=args.start_token,
                                **extra_kwargs)


def _build_topic_generator(args, vocab_size, seq_len):
    """Generator variant that also receives ``TopicInMemory``/``NoTopic``."""
    return _build_text_generator(args, vocab_size, seq_len,
                                 TopicInMemory=args.topic_in_memory,
                                 NoTopic=args.no_topic)


def _build_text_discriminator(args, vocab_size, seq_len):
    """Call ``models.get_discriminator`` with the shared hyper-parameters."""
    return models.get_discriminator(args.d_architecture,
                                    batch_size=args.batch_size,
                                    seq_len=seq_len,
                                    vocab_size=vocab_size,
                                    dis_emb_dim=args.dis_emb_dim,
                                    num_rep=args.num_rep,
                                    sn=args.sn)


def _build_topic_loader(args, config, data_file, lda_file, word_index_dict,
                        index_word_dict):
    """Create a ``RealDataTopicLoader`` wired to files, topics and vocab."""
    topic_number = config['topic_number']
    loader = RealDataTopicLoader(args.batch_size, args.seq_len)
    loader.set_dataset(args.dataset)
    loader.set_files(data_file, lda_file)
    loader.topic_num = topic_number
    loader.set_dictionaries(word_index_dict, index_word_dict)
    return loader


def _run_oracle_experiment(args, config):
    """Train against sequences produced by an ``OracleLstm``."""
    oracle_model = OracleLstm(num_vocabulary=args.vocab_size,
                              batch_size=args.batch_size,
                              emb_dim=args.gen_emb_dim,
                              hidden_dim=args.hidden_dim,
                              sequence_length=args.seq_len,
                              start_token=args.start_token)
    oracle_loader = OracleDataLoader(args.batch_size, args.seq_len)
    gen_loader = OracleDataLoader(args.batch_size, args.seq_len)

    generator = _build_text_generator(args, args.vocab_size, args.seq_len)
    discriminator = _build_text_discriminator(args, args.vocab_size,
                                              args.seq_len)
    oracle_train(generator, discriminator, oracle_model, oracle_loader,
                 gen_loader, config)


def _run_text_experiment(args, config):
    """Train on ``image_coco`` / ``emnlp_news`` with one of four trainers.

    ``config['LSTM']`` selects the trainers that build no discriminator, and
    ``config['topic']`` selects the topic-aware loader and generator.
    """
    dataset = args.dataset
    data_file = resources_path(args.data_dir, '{}.txt'.format(dataset))
    sample_dir = resources_path(config['sample_dir'])
    oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset))

    data_dir = resources_path(config['data_dir'])
    test_names = {
        'image_coco': 'testdata/test_coco.txt',
        'emnlp_news': 'testdata/test_emnlp.txt',
    }
    if dataset not in test_names:
        raise NotImplementedError('Unknown dataset!')
    test_file = os.path.join(data_dir, test_names[dataset])

    if dataset == 'emnlp_news':
        # create_subsample_data_file returns both the training file and the
        # file used for LDA (presumably a subsample -- confirm in the helper).
        data_file, lda_file = create_subsample_data_file(data_file)
    else:
        lda_file = data_file

    seq_len, vocab_size, word_index_dict, index_word_dict = text_precess(
        data_file, test_file, oracle_file=oracle_file)
    # config is vars(args), so these writes also update args.seq_len etc.
    config['seq_len'] = seq_len
    config['vocab_size'] = vocab_size
    print('seq_len: %d, vocab_size: %d' % (seq_len, vocab_size))

    config['topic_loss_weight'] = args.topic_loss_weight

    use_lstm = config['LSTM']
    use_topic = config['topic']

    if use_lstm and use_topic:
        oracle_loader = _build_topic_loader(args, config, data_file, lda_file,
                                            word_index_dict, index_word_dict)
        generator = _build_topic_generator(args, vocab_size, seq_len)

        from real.real_gan.real_topic_train_NoDiscr import real_topic_train_NoDiscr
        real_topic_train_NoDiscr(generator, oracle_loader, config, args)
    elif use_lstm:
        generator = _build_text_generator(args, vocab_size, seq_len)

        oracle_loader = RealDataLoader(args.batch_size, args.seq_len)
        oracle_loader.set_dictionaries(word_index_dict, index_word_dict)
        oracle_loader.set_dataset(dataset)
        oracle_loader.set_files(data_file, lda_file)
        oracle_loader.topic_num = config['topic_number']

        from real.real_gan.real_train_NoDiscr import real_train_NoDiscr
        real_train_NoDiscr(generator, oracle_loader, config, args)
    elif use_topic:
        oracle_loader = _build_topic_loader(args, config, data_file, lda_file,
                                            word_index_dict, index_word_dict)
        generator = _build_topic_generator(args, vocab_size, seq_len)
        discriminator = _build_text_discriminator(args, vocab_size, seq_len)

        if args.no_topic:
            topic_discriminator = None
        else:
            topic_discriminator = models.get_topic_discriminator(
                args.topic_architecture,
                batch_size=args.batch_size,
                seq_len=seq_len,
                vocab_size=vocab_size,
                dis_emb_dim=args.dis_emb_dim,
                num_rep=args.num_rep,
                sn=args.sn,
                discriminator=discriminator)

        from real.real_gan.real_topic_train import real_topic_train
        real_topic_train(generator, discriminator, topic_discriminator,
                         oracle_loader, config, args)
    else:
        generator = _build_text_generator(args, vocab_size, seq_len)
        discriminator = _build_text_discriminator(args, vocab_size, seq_len)
        oracle_loader = RealDataLoader(args.batch_size, args.seq_len)

        from real.real_gan.real_train import real_train
        real_train(generator, discriminator, oracle_loader, config, args)


def _run_amazon_attribute_experiment(args, config):
    """Train the attribute-conditioned model on the Amazon_Attribute CSVs."""
    data_dir = resources_path(config['data_dir'], "Amazon_Attribute")
    sample_dir = resources_path(config['sample_dir'])
    # Computed for parity with the other branches; not used below.
    oracle_file = os.path.join(sample_dir,
                               'oracle_{}.txt'.format(args.dataset))
    train_file, dev_file, test_file = (
        os.path.join(data_dir, csv_name)
        for csv_name in ('train.csv', 'dev.csv', 'test.csv'))

    # Token files are not regenerated here (create_tokens_files is disabled).
    dataset_config = load_json(os.path.join(data_dir, 'config.json'))
    # Values from config.json win over CLI values.  The merge builds a NEW
    # dict, so from here on config no longer aliases args.
    config = {**config, **dataset_config}

    from real.real_gan.loaders.amazon_loader import RealDataAmazonLoader
    # NOTE(review): the loader gets the CLI args.seq_len while the models
    # below get config['seq_len'] from config.json -- confirm they agree.
    oracle_loader = RealDataAmazonLoader(args.batch_size, args.seq_len)
    oracle_loader.create_batches(data_file=[train_file, dev_file, test_file])
    oracle_loader.model_index_word_dict = load_json(
        join(data_dir, 'index_word_dict.json'))
    oracle_loader.model_word_index_dict = load_json(
        join(data_dir, 'word_index_dict.json'))

    vocab_size = config['vocabulary_size']
    seq_len = config['seq_len']
    generator = models.get_generator("amazon_attribute",
                                     vocab_size=vocab_size,
                                     batch_size=args.batch_size,
                                     seq_len=seq_len,
                                     gen_emb_dim=args.gen_emb_dim,
                                     mem_slots=args.mem_slots,
                                     head_size=args.head_size,
                                     num_heads=args.num_heads,
                                     hidden_dim=args.hidden_dim,
                                     start_token=args.start_token,
                                     user_num=config['user_num'],
                                     product_num=config['product_num'],
                                     rating_num=5)
    discriminator = models.get_discriminator("amazon_attribute",
                                             batch_size=args.batch_size,
                                             seq_len=seq_len,
                                             vocab_size=vocab_size,
                                             dis_emb_dim=args.dis_emb_dim,
                                             num_rep=args.num_rep,
                                             sn=args.sn)

    from real.real_gan.amazon_attribute_train import amazon_attribute_train
    amazon_attribute_train(generator, discriminator, oracle_loader, config,
                           args)


def _run_customer_reviews_experiment(args, config):
    """Train with one discriminator per sentiment on CR or IMDB (SSTB) data."""
    from real.real_gan.loaders.custom_reviews_loader import RealDataCustomerReviewsLoader
    from real.real_gan.customer_reviews_train import customer_reviews_train

    if args.dataset == 'CustomerReviews':
        data_dir = resources_path(config['data_dir'], "MovieReviews", "cr")
    elif args.dataset == 'imdb':
        data_dir = resources_path(config['data_dir'], "MovieReviews",
                                  'movie', 'sstb')
    else:
        raise ValueError
    sample_dir = resources_path(config['sample_dir'])
    # Computed for parity with the other branches; not used below.
    oracle_file = os.path.join(sample_dir,
                               'oracle_{}.txt'.format(args.dataset))
    train_file = os.path.join(data_dir, 'train.csv')

    # Token files are not regenerated here (create_tokens_files is disabled).
    dataset_config = load_json(os.path.join(data_dir, 'config.json'))
    # New dict: config.json values win, and config stops aliasing args.
    config = {**config, **dataset_config}

    # NOTE(review): loader uses CLI args.seq_len, models use config['seq_len'].
    oracle_loader = RealDataCustomerReviewsLoader(args.batch_size,
                                                  args.seq_len)
    oracle_loader.create_batches(data_file=[train_file])
    oracle_loader.model_index_word_dict = load_json(
        join(data_dir, 'index_word_dict.json'))
    oracle_loader.model_word_index_dict = load_json(
        join(data_dir, 'word_index_dict.json'))

    vocab_size = config['vocabulary_size']
    seq_len = config['seq_len']
    generator = models.get_generator("CustomerReviews",
                                     vocab_size=vocab_size,
                                     batch_size=args.batch_size,
                                     start_token=args.start_token,
                                     seq_len=seq_len,
                                     gen_emb_dim=args.gen_emb_dim,
                                     mem_slots=args.mem_slots,
                                     head_size=args.head_size,
                                     num_heads=args.num_heads,
                                     hidden_dim=args.hidden_dim,
                                     sentiment_num=config['sentiment_num'])

    # Built in this order: positive first, then negative.
    discriminator_positive, discriminator_negative = [
        models.get_discriminator("CustomerReviews",
                                 scope=scope_name,
                                 batch_size=args.batch_size,
                                 seq_len=seq_len,
                                 vocab_size=vocab_size,
                                 dis_emb_dim=args.dis_emb_dim,
                                 num_rep=args.num_rep,
                                 sn=args.sn)
        for scope_name in ("discriminator_positive", "discriminator_negative")
    ]

    customer_reviews_train(generator, discriminator_positive,
                           discriminator_negative, oracle_loader, config,
                           args)


def main():
    """Parse the command line and dispatch to the per-dataset training run.

    Supported ``--dataset`` values: ``'oracle'``, ``'image_coco'``,
    ``'emnlp_news'``, ``'Amazon_Attribute'``, ``'CustomerReviews'`` and
    ``'imdb'``.  Prints ``RUN FINISHED`` once the selected run returns.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    args = parser.parse_args()
    # vars() hands back the namespace's own __dict__ (assuming ``parser`` is
    # an argparse parser), so helpers that write to config also update args.
    config = vars(args)
    pp.pprint(config)

    dataset = args.dataset
    if dataset == 'oracle':
        _run_oracle_experiment(args, config)
    elif dataset in ('image_coco', 'emnlp_news'):
        _run_text_experiment(args, config)
    elif dataset == 'Amazon_Attribute':
        _run_amazon_attribute_experiment(args, config)
    elif dataset in ('CustomerReviews', 'imdb'):
        _run_customer_reviews_experiment(args, config)
    else:
        raise NotImplementedError('{}: unknown dataset!'.format(dataset))

    print("RUN FINISHED")
示例#3
0
文件: run.py 项目: mahossam/OptiGAN
def _write_json_file(obj, path):
    """Serialise *obj* to *path* as JSON and close the file before returning.

    Replaces ``json.dump(obj, open(path, 'w'))``, which never closed the
    handle explicitly and left flushing it to the garbage collector.
    """
    with open(path, 'w') as json_file:
        json.dump(obj, json_file)


def _resume_or_init_log_dir(config):
    """Prepare ``config['log_dir']`` for a fresh run or a resumed one.

    When ``config['load_saved_model']`` is a non-empty path, ``log_dir`` is
    pointed at that path's directory, ``sample_dir`` at the ``samples``
    directory next to it, and the index->word vocabulary saved in ``log_dir``
    is loaded.  Otherwise ``log_dir`` is created if it does not exist yet.
    ``config`` is updated in place.

    Returns:
        ``(load_model, index_word_dict, word_index_dict)``; the two dicts are
        ``None`` unless a saved model is being resumed.
    """
    saved_model = config['load_saved_model']
    if saved_model == "":
        if not os.path.exists(config['log_dir']):
            os.makedirs(config['log_dir'])
        return False, None, None

    log_dir_path = os.path.dirname(saved_model)
    config['log_dir'] = log_dir_path
    config['sample_dir'] = os.path.join(
        os.path.split(log_dir_path)[0], 'samples')
    index_word_dict = load_index_to_word_dict(
        os.path.join(log_dir_path, "index_to_word_dict.json"))
    word_index_dict = {v: k for k, v in index_word_dict.items()}
    return True, index_word_dict, word_index_dict


def main():
    """Parse the command line, build the models for one dataset and train.

    Supported ``--dataset`` values:

    * ``'oracle'`` -- synthetic data from an ``OracleLstm``.
    * ``'image_coco'``, ``'emnlp_news'``, ``'emnlp_news_small'`` -- real text
      from ``<data_dir>/<dataset>.txt``; after training, a BLEU script is run
      in a subprocess on the ``samples`` directory.
    * ``'ace0_small'`` -- shapes taken from the parsed arguments; trained with
      ``real_train_traj`` without a data loader or vocabulary.

    Setting ``--load_saved_model`` resumes from that checkpoint's directory.

    Raises:
        NotImplementedError: for any other dataset name.
    """
    args = parser.parse_args()
    # vars() returns the namespace's own __dict__ (assuming ``parser`` is an
    # argparse parser): writes to ``config`` are visible through ``args`` too,
    # which is why ``args.seq_len`` below already holds the corpus value.
    config = vars(args)

    if args.dataset == 'oracle':
        oracle_model = OracleLstm(num_vocabulary=args.vocab_size,
                                  batch_size=args.batch_size,
                                  emb_dim=args.gen_emb_dim,
                                  hidden_dim=args.hidden_dim,
                                  sequence_length=args.seq_len,
                                  start_token=args.start_token)
        oracle_loader = OracleDataLoader(args.batch_size, args.seq_len)
        gen_loader = OracleDataLoader(args.batch_size, args.seq_len)

        generator = models.get_generator(args.g_architecture,
                                         vocab_size=args.vocab_size,
                                         batch_size=args.batch_size,
                                         seq_len=args.seq_len,
                                         gen_emb_dim=args.gen_emb_dim,
                                         mem_slots=args.mem_slots,
                                         head_size=args.head_size,
                                         num_heads=args.num_heads,
                                         hidden_dim=args.hidden_dim,
                                         start_token=args.start_token)
        discriminator = models.get_discriminator(args.d_architecture,
                                                 batch_size=args.batch_size,
                                                 seq_len=args.seq_len,
                                                 vocab_size=args.vocab_size,
                                                 dis_emb_dim=args.dis_emb_dim,
                                                 num_rep=args.num_rep,
                                                 sn=args.sn)
        oracle_train(generator, discriminator, oracle_model, oracle_loader,
                     gen_loader, config)

    elif args.dataset in ['image_coco', 'emnlp_news', 'emnlp_news_small']:
        data_file = os.path.join(args.data_dir, '{}.txt'.format(args.dataset))
        seq_len, vocab_size, word_index_dict, index_word_dict = text_precess(
            data_file)
        config['seq_len'] = seq_len
        config['vocab_size'] = vocab_size

        oracle_loader = RealDataLoader(args.batch_size, args.seq_len)

        generator = models.get_generator(args.g_architecture,
                                         vocab_size=vocab_size,
                                         batch_size=args.batch_size,
                                         seq_len=seq_len,
                                         gen_emb_dim=args.gen_emb_dim,
                                         mem_slots=args.mem_slots,
                                         head_size=args.head_size,
                                         num_heads=args.num_heads,
                                         hidden_dim=args.hidden_dim,
                                         start_token=args.start_token)
        discriminator = models.get_discriminator(args.d_architecture,
                                                 batch_size=args.batch_size,
                                                 seq_len=seq_len,
                                                 vocab_size=vocab_size,
                                                 dis_emb_dim=args.dis_emb_dim,
                                                 num_rep=args.num_rep,
                                                 sn=args.sn)

        load_model, saved_index_word, saved_word_index = (
            _resume_or_init_log_dir(config))
        if load_model:
            # Resume with the vocabulary saved next to the checkpoint rather
            # than the one just rebuilt from the corpus.
            index_word_dict = saved_index_word
            word_index_dict = saved_word_index
        else:
            _write_json_file(
                index_word_dict,
                os.path.join(config['log_dir'], "index_to_word_dict.json"))
            _write_json_file(
                word_index_dict,
                os.path.join(config['log_dir'], "word_to_index_dict.json"))

        pp.pprint(config)
        print('seq_len: %d, vocab_size: %d' % (seq_len, vocab_size))
        sys.stdout.flush()
        real_train(generator,
                   discriminator,
                   oracle_loader,
                   config,
                   word_index_dict,
                   index_word_dict,
                   load_model=load_model)

        # Score the generated samples with the dataset-specific BLEU script.
        if args.dataset == 'image_coco':
            bleu_script = 'bleu_post_training.py'
        else:  # 'emnlp_news' or 'emnlp_news_small'
            bleu_script = 'bleu_post_training_emnlp.py'
        samples_dir = os.path.join(
            os.path.split(config['log_dir'])[0], 'samples')
        call(["python", bleu_script, samples_dir, 'na'], cwd=".")

    elif args.dataset in ['ace0_small']:
        # No corpus is read: the shapes come from config (the parsed args).
        seq_len = config['seq_len']
        vocab_size = config['vocab_size']

        generator = models.get_generator(args.g_architecture,
                                         vocab_size=vocab_size,
                                         batch_size=args.batch_size,
                                         seq_len=seq_len,
                                         gen_emb_dim=args.gen_emb_dim,
                                         mem_slots=args.mem_slots,
                                         head_size=args.head_size,
                                         num_heads=args.num_heads,
                                         hidden_dim=args.hidden_dim,
                                         start_token=args.start_token)
        discriminator = models.get_discriminator(
            args.d_architecture,
            batch_size=args.batch_size,
            seq_len=seq_len,
            vocab_size=vocab_size,
            dis_emb_dim=args.dis_emb_dim,
            num_rep=args.num_rep,
            sn=args.sn)

        # Vocabulary dicts are neither written nor passed on for this
        # dataset; only the resume flag is used.
        load_model, _, _ = _resume_or_init_log_dir(config)

        pp.pprint(config)
        print('seq_len: %d, vocab_size: %d' % (seq_len, vocab_size))
        sys.stdout.flush()
        real_train_traj(generator,
                        discriminator,
                        None,
                        config,
                        None,
                        None,
                        load_model=load_model)
    else:
        raise NotImplementedError('{}: unknown dataset!'.format(args.dataset))