Exemplo n.º 1
0
def main():
    """Run topic-conditioned inference with a previously trained model.

    Parses the command-line arguments and preprocesses the selected dataset
    with ``text_precess`` to recover sequence length and vocabulary. Only when
    ``config['topic']`` is truthy does it build a ``RealDataTopicLoader``,
    compute its LDA state and hand everything to ``inference_main``;
    otherwise it returns after preprocessing.
    """
    args = parser.parse_args()
    pp.pprint(vars(args))
    # ``vars`` hands back the namespace's own ``__dict__``: every write to
    # ``config`` below is therefore also visible as an attribute of ``args``.
    config = vars(args)

    model_path = resources_path(os.path.join('trained_models', config['model_name']))
    input_path = resources_path(os.path.join('inference_data', config['input_name']))

    dataset = args.dataset
    data_file = resources_path(args.data_dir, '{}.txt'.format(dataset))
    sample_dir = resources_path(config['sample_dir'])
    oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset))

    # For emnlp_news the data file is swapped for the one returned by
    # create_subsample_data_file; its second (LDA) file is not needed here.
    if dataset == 'emnlp_news':
        data_file, _ = create_subsample_data_file(data_file)

    seq_len, vocab_size, word_index_dict, index_word_dict = text_precess(
        data_file, oracle_file=oracle_file)
    print(index_word_dict)
    config.update(seq_len=seq_len, vocab_size=vocab_size)
    print('seq_len: %d, vocab_size: %d' % (seq_len, vocab_size))

    # Inference is only wired up for the topic-conditioned model.
    if not config['topic']:
        return

    topic_number = config['topic_number']
    # args.seq_len already holds the preprocessed value (config aliases args).
    oracle_loader = RealDataTopicLoader(args.batch_size, args.seq_len)
    oracle_loader.set_dataset(dataset)
    oracle_loader.topic_num = topic_number
    oracle_loader.set_dictionaries(word_index_dict, index_word_dict)
    oracle_loader.get_LDA(word_index_dict, index_word_dict, data_file)
    print(oracle_loader.model_index_word_dict)
    inference_main(oracle_loader, config, model_path, input_path)
Exemplo n.º 2
0
def get_corpus(coco=True, datapath=None) -> List[str]:
    """Read a text corpus and return it as a list of lines.

    Args:
        coco: Used only when ``datapath`` is None. True selects
            ``resources_path("data", "image_coco.txt")``; False selects
            ``resources_path("data", "emnlp_news.txt")``.
        datapath: Explicit file path. When given, it takes precedence over
            ``coco``.

    Returns:
        The lines of the file in order, each with one trailing ``'\\n'``
        stripped.
    """
    if datapath is not None:
        fname = datapath
    else:
        corpus_name = "image_coco.txt" if coco else "emnlp_news.txt"
        fname = resources_path("data", corpus_name)

    with open(fname) as handle:
        return [raw.rstrip('\n') for raw in handle]
Exemplo n.º 3
0
def get_perc_sent_topic(lda, topic_num, data_file):
    """Return per-sentence topic proportions for ``data_file``, cached on disk.

    The result is loaded from a pickle under ``topic_models`` when it exists.
    Otherwise it is computed from ``lda`` and written to that pickle.

    Args:
        lda: Wrapper exposing ``lda_model`` (indexable by a bag-of-words
            corpus), ``dictionary`` (providing ``doc2bow``) and ``stops``.
        topic_num: Number of topics; one ``"Topic k"`` column per topic.
        data_file: Path to a text file with one sentence per line.

    Returns:
        A DataFrame with an ``index`` column (from ``reset_index``), the
        columns ``"Topic 0"`` .. ``"Topic {topic_num - 1}"`` holding each
        sentence's topic proportions (0.0 where the model reports none), and
        a final column with the raw sentence text.
    """
    # Cache key: the model's string form, the topic count and the last 10
    # characters of the data path.
    file_path = "{}-{}-{}".format(lda.lda_model, topic_num, data_file[-10:])
    file_path = resources_path("topic_models", file_path)
    try:
        return load_pickle(file_path)
    except FileNotFoundError:
        print("get perc sent topic not found")

    ldamodel = lda.lda_model
    with open(data_file) as f:
        sentences = [line.rstrip('\n') for line in f]

    tokenized = process_texts(sentences, lda.stops)
    corpus_bow = [lda.dictionary.doc2bow(tokens) for tokens in tokenized]

    # Dense (documents x topics) matrix filled in place. Building the frame
    # once replaces the per-row DataFrame.append, which copied the frame on
    # every row and no longer exists in pandas >= 2.0.
    topic_matrix = np.zeros((len(corpus_bow), topic_num))
    for doc_idx, row_list in enumerate(tqdm(ldamodel[corpus_bow])):
        # With per_word_topics enabled, the (topic_id, proportion) list is
        # the first element of each item.
        row = row_list[0] if ldamodel.per_word_topics else row_list
        for topic_n, prop_topic in row:
            topic_matrix[doc_idx, topic_n] = prop_topic

    # Column names come from topic_num, so an empty corpus still produces a
    # correctly shaped (0-row) frame instead of failing.
    sent_topics_df = pd.DataFrame(
        topic_matrix,
        columns=["Topic {}".format(topic_number)
                 for topic_number in range(topic_num)])

    # Add original text to the end of the output
    contents = pd.Series(sentences)
    sent_topics_df = pd.concat([sent_topics_df, contents], axis=1)
    sent_topics_df = sent_topics_df.reset_index()

    write_pickle(file_path, sent_topics_df)
    return sent_topics_df
Exemplo n.º 4
0
def main():
    """Parse CLI arguments and dispatch to the trainer for ``args.dataset``.

    Supported values of ``args.dataset``:

    * ``'oracle'``: builds an ``OracleLstm``, two ``OracleDataLoader``s, a
      generator and a discriminator, then calls ``oracle_train``.
    * ``'image_coco'`` / ``'emnlp_news'``: preprocesses the text data with
      ``text_precess``, then picks one of four trainers depending on
      ``config['LSTM']`` and ``config['topic']``.
    * ``'Amazon_Attribute'``: merges ``config.json`` from the dataset
      directory into the config, builds ``amazon_attribute`` generator and
      discriminator (with ``user_num`` / ``product_num`` from the config and
      ``rating_num=5``) and calls ``amazon_attribute_train``.
    * ``'CustomerReviews'`` / ``'imdb'``: merges ``config.json``, builds a
      generator with ``sentiment_num`` plus two discriminators scoped
      ``discriminator_positive`` / ``discriminator_negative`` and calls
      ``customer_reviews_train``.

    Prints ``"RUN FINISHED"`` once the chosen trainer returns.

    Raises:
        NotImplementedError: for any other dataset name.

    NOTE(review): models are always constructed in a fixed order (generator
    before discriminator(s)); these factories presumably create TensorFlow
    graph variables, so keep this order if refactoring. TODO confirm.
    """
    args = parser.parse_args()
    pp.pprint(vars(args))
    # ``vars`` returns the namespace's own ``__dict__``, so writes to
    # ``config`` are also visible as ``args`` attributes (until ``config`` is
    # rebound by the dict merges in the Amazon/CustomerReviews branches).
    config = vars(args)

    # train with different datasets
    if args.dataset == 'oracle':
        oracle_model = OracleLstm(num_vocabulary=args.vocab_size,
                                  batch_size=args.batch_size,
                                  emb_dim=args.gen_emb_dim,
                                  hidden_dim=args.hidden_dim,
                                  sequence_length=args.seq_len,
                                  start_token=args.start_token)
        oracle_loader = OracleDataLoader(args.batch_size, args.seq_len)
        gen_loader = OracleDataLoader(args.batch_size, args.seq_len)

        generator = models.get_generator(args.g_architecture,
                                         vocab_size=args.vocab_size,
                                         batch_size=args.batch_size,
                                         seq_len=args.seq_len,
                                         gen_emb_dim=args.gen_emb_dim,
                                         mem_slots=args.mem_slots,
                                         head_size=args.head_size,
                                         num_heads=args.num_heads,
                                         hidden_dim=args.hidden_dim,
                                         start_token=args.start_token)
        discriminator = models.get_discriminator(args.d_architecture,
                                                 batch_size=args.batch_size,
                                                 seq_len=args.seq_len,
                                                 vocab_size=args.vocab_size,
                                                 dis_emb_dim=args.dis_emb_dim,
                                                 num_rep=args.num_rep,
                                                 sn=args.sn)
        oracle_train(generator, discriminator, oracle_model, oracle_loader,
                     gen_loader, config)

    elif args.dataset in ['image_coco', 'emnlp_news']:
        # custom dataset selected
        data_file = resources_path(args.data_dir,
                                   '{}.txt'.format(args.dataset))
        sample_dir = resources_path(config['sample_dir'])
        oracle_file = os.path.join(sample_dir,
                                   'oracle_{}.txt'.format(args.dataset))

        data_dir = resources_path(config['data_dir'])
        if args.dataset == 'image_coco':
            test_file = os.path.join(data_dir, 'testdata/test_coco.txt')
        elif args.dataset == 'emnlp_news':
            test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt')
        else:
            # Unreachable: the enclosing elif only admits the two names above.
            raise NotImplementedError('Unknown dataset!')

        # For emnlp_news, create_subsample_data_file returns both the data
        # file to train on and a separate file passed on as ``lda_file``
        # (presumably for the LDA topic model; verify). Other datasets use
        # the same file for both.
        if args.dataset == 'emnlp_news':
            data_file, lda_file = create_subsample_data_file(data_file)
        else:
            lda_file = data_file

        # Sequence length and vocabulary size are derived from the data and
        # overwrite the CLI values in ``config``. Because ``config`` aliases
        # ``args``, ``args.seq_len`` used below now holds this value too.
        seq_len, vocab_size, word_index_dict, index_word_dict = text_precess(
            data_file, test_file, oracle_file=oracle_file)
        config['seq_len'] = seq_len
        config['vocab_size'] = vocab_size
        print('seq_len: %d, vocab_size: %d' % (seq_len, vocab_size))

        config['topic_loss_weight'] = args.topic_loss_weight

        # LSTM mode: generator-only training (no discriminator is built);
        # otherwise adversarial training with one or two discriminators.
        if config['LSTM']:
            if config['topic']:
                topic_number = config['topic_number']
                oracle_loader = RealDataTopicLoader(args.batch_size,
                                                    args.seq_len)
                oracle_loader.set_dataset(args.dataset)
                oracle_loader.set_files(data_file, lda_file)
                oracle_loader.topic_num = topic_number
                oracle_loader.set_dictionaries(word_index_dict,
                                               index_word_dict)

                generator = models.get_generator(
                    args.g_architecture,
                    vocab_size=vocab_size,
                    batch_size=args.batch_size,
                    seq_len=seq_len,
                    gen_emb_dim=args.gen_emb_dim,
                    mem_slots=args.mem_slots,
                    head_size=args.head_size,
                    num_heads=args.num_heads,
                    hidden_dim=args.hidden_dim,
                    start_token=args.start_token,
                    TopicInMemory=args.topic_in_memory,
                    NoTopic=args.no_topic)

                from real.real_gan.real_topic_train_NoDiscr import real_topic_train_NoDiscr
                real_topic_train_NoDiscr(generator, oracle_loader, config,
                                         args)
            else:
                generator = models.get_generator(args.g_architecture,
                                                 vocab_size=vocab_size,
                                                 batch_size=args.batch_size,
                                                 seq_len=seq_len,
                                                 gen_emb_dim=args.gen_emb_dim,
                                                 mem_slots=args.mem_slots,
                                                 head_size=args.head_size,
                                                 num_heads=args.num_heads,
                                                 hidden_dim=args.hidden_dim,
                                                 start_token=args.start_token)

                oracle_loader = RealDataLoader(args.batch_size, args.seq_len)
                oracle_loader.set_dictionaries(word_index_dict,
                                               index_word_dict)
                oracle_loader.set_dataset(args.dataset)
                oracle_loader.set_files(data_file, lda_file)
                oracle_loader.topic_num = config['topic_number']

                from real.real_gan.real_train_NoDiscr import real_train_NoDiscr
                real_train_NoDiscr(generator, oracle_loader, config, args)
        else:
            if config['topic']:
                topic_number = config['topic_number']
                oracle_loader = RealDataTopicLoader(args.batch_size,
                                                    args.seq_len)
                oracle_loader.set_dataset(args.dataset)
                oracle_loader.set_files(data_file, lda_file)
                oracle_loader.topic_num = topic_number
                oracle_loader.set_dictionaries(word_index_dict,
                                               index_word_dict)

                generator = models.get_generator(
                    args.g_architecture,
                    vocab_size=vocab_size,
                    batch_size=args.batch_size,
                    seq_len=seq_len,
                    gen_emb_dim=args.gen_emb_dim,
                    mem_slots=args.mem_slots,
                    head_size=args.head_size,
                    num_heads=args.num_heads,
                    hidden_dim=args.hidden_dim,
                    start_token=args.start_token,
                    TopicInMemory=args.topic_in_memory,
                    NoTopic=args.no_topic)

                discriminator = models.get_discriminator(
                    args.d_architecture,
                    batch_size=args.batch_size,
                    seq_len=seq_len,
                    vocab_size=vocab_size,
                    dis_emb_dim=args.dis_emb_dim,
                    num_rep=args.num_rep,
                    sn=args.sn)

                # The topic discriminator receives the main discriminator and
                # is skipped entirely when topics are disabled (--no_topic).
                if not args.no_topic:
                    topic_discriminator = models.get_topic_discriminator(
                        args.topic_architecture,
                        batch_size=args.batch_size,
                        seq_len=seq_len,
                        vocab_size=vocab_size,
                        dis_emb_dim=args.dis_emb_dim,
                        num_rep=args.num_rep,
                        sn=args.sn,
                        discriminator=discriminator)
                else:
                    topic_discriminator = None
                from real.real_gan.real_topic_train import real_topic_train
                real_topic_train(generator, discriminator, topic_discriminator,
                                 oracle_loader, config, args)
            else:

                generator = models.get_generator(args.g_architecture,
                                                 vocab_size=vocab_size,
                                                 batch_size=args.batch_size,
                                                 seq_len=seq_len,
                                                 gen_emb_dim=args.gen_emb_dim,
                                                 mem_slots=args.mem_slots,
                                                 head_size=args.head_size,
                                                 num_heads=args.num_heads,
                                                 hidden_dim=args.hidden_dim,
                                                 start_token=args.start_token)

                discriminator = models.get_discriminator(
                    args.d_architecture,
                    batch_size=args.batch_size,
                    seq_len=seq_len,
                    vocab_size=vocab_size,
                    dis_emb_dim=args.dis_emb_dim,
                    num_rep=args.num_rep,
                    sn=args.sn)

                # NOTE(review): unlike the LSTM branch above, this loader is
                # not given dictionaries, dataset, files or topic_num here.
                # Presumably real_train configures it; verify.
                oracle_loader = RealDataLoader(args.batch_size, args.seq_len)

                from real.real_gan.real_train import real_train
                real_train(generator, discriminator, oracle_loader, config,
                           args)

    elif args.dataset in ['Amazon_Attribute']:
        # custom dataset selected
        data_dir = resources_path(config['data_dir'], "Amazon_Attribute")
        sample_dir = resources_path(config['sample_dir'])
        oracle_file = os.path.join(sample_dir,
                                   'oracle_{}.txt'.format(args.dataset))
        train_file = os.path.join(data_dir, 'train.csv')
        dev_file = os.path.join(data_dir, 'dev.csv')
        test_file = os.path.join(data_dir, 'test.csv')

        # create_tokens_files(data_files=[train_file, dev_file, test_file])
        # Values from the dataset's config.json win over CLI values on key
        # collisions. This rebinds ``config`` to a new dict, so later writes
        # no longer reach ``args``.
        config_file = load_json(os.path.join(data_dir, 'config.json'))
        config = {**config, **config_file}  # merge dictionaries

        from real.real_gan.loaders.amazon_loader import RealDataAmazonLoader
        oracle_loader = RealDataAmazonLoader(args.batch_size, args.seq_len)
        oracle_loader.create_batches(
            data_file=[train_file, dev_file, test_file])
        # Vocabulary mappings are loaded from precomputed JSON files rather
        # than derived from the data here.
        oracle_loader.model_index_word_dict = load_json(
            join(data_dir, 'index_word_dict.json'))
        oracle_loader.model_word_index_dict = load_json(
            join(data_dir, 'word_index_dict.json'))

        generator = models.get_generator("amazon_attribute",
                                         vocab_size=config['vocabulary_size'],
                                         batch_size=args.batch_size,
                                         seq_len=config['seq_len'],
                                         gen_emb_dim=args.gen_emb_dim,
                                         mem_slots=args.mem_slots,
                                         head_size=args.head_size,
                                         num_heads=args.num_heads,
                                         hidden_dim=args.hidden_dim,
                                         start_token=args.start_token,
                                         user_num=config['user_num'],
                                         product_num=config['product_num'],
                                         rating_num=5)

        discriminator = models.get_discriminator(
            "amazon_attribute",
            batch_size=args.batch_size,
            seq_len=config['seq_len'],
            vocab_size=config['vocabulary_size'],
            dis_emb_dim=args.dis_emb_dim,
            num_rep=args.num_rep,
            sn=args.sn)

        from real.real_gan.amazon_attribute_train import amazon_attribute_train
        amazon_attribute_train(generator, discriminator, oracle_loader, config,
                               args)
    elif args.dataset in ['CustomerReviews', 'imdb']:
        from real.real_gan.loaders.custom_reviews_loader import RealDataCustomerReviewsLoader
        from real.real_gan.customer_reviews_train import customer_reviews_train
        # custom dataset selected
        if args.dataset == 'CustomerReviews':
            data_dir = resources_path(config['data_dir'], "MovieReviews", "cr")
        elif args.dataset == 'imdb':
            data_dir = resources_path(config['data_dir'], "MovieReviews",
                                      'movie', 'sstb')
        else:
            # Unreachable: the enclosing elif only admits the two names above.
            raise ValueError
        sample_dir = resources_path(config['sample_dir'])
        oracle_file = os.path.join(sample_dir,
                                   'oracle_{}.txt'.format(args.dataset))
        train_file = os.path.join(data_dir, 'train.csv')

        # create_tokens_files(data_files=[train_file, dev_file, test_file])
        # config.json values override CLI values; ``config`` is rebound here.
        config_file = load_json(os.path.join(data_dir, 'config.json'))
        config = {**config, **config_file}  # merge dictionaries

        oracle_loader = RealDataCustomerReviewsLoader(args.batch_size,
                                                      args.seq_len)
        oracle_loader.create_batches(data_file=[train_file])
        oracle_loader.model_index_word_dict = load_json(
            join(data_dir, 'index_word_dict.json'))
        oracle_loader.model_word_index_dict = load_json(
            join(data_dir, 'word_index_dict.json'))

        generator = models.get_generator("CustomerReviews",
                                         vocab_size=config['vocabulary_size'],
                                         batch_size=args.batch_size,
                                         start_token=args.start_token,
                                         seq_len=config['seq_len'],
                                         gen_emb_dim=args.gen_emb_dim,
                                         mem_slots=args.mem_slots,
                                         head_size=args.head_size,
                                         num_heads=args.num_heads,
                                         hidden_dim=args.hidden_dim,
                                         sentiment_num=config['sentiment_num'])

        # Two discriminators with identical hyper-parameters, distinguished
        # only by their variable scope.
        discriminator_positive = models.get_discriminator(
            "CustomerReviews",
            scope="discriminator_positive",
            batch_size=args.batch_size,
            seq_len=config['seq_len'],
            vocab_size=config['vocabulary_size'],
            dis_emb_dim=args.dis_emb_dim,
            num_rep=args.num_rep,
            sn=args.sn)

        discriminator_negative = models.get_discriminator(
            "CustomerReviews",
            scope="discriminator_negative",
            batch_size=args.batch_size,
            seq_len=config['seq_len'],
            vocab_size=config['vocabulary_size'],
            dis_emb_dim=args.dis_emb_dim,
            num_rep=args.num_rep,
            sn=args.sn)

        customer_reviews_train(generator, discriminator_positive,
                               discriminator_negative, oracle_loader, config,
                               args)
    else:
        raise NotImplementedError('{}: unknown dataset!'.format(args.dataset))

    print("RUN FINISHED")
    return
Exemplo n.º 5
0
def real_topic_train_NoDiscr(generator, oracle_loader: RealDataTopicLoader,
                             config, args):
    """Pre-train a topic-conditioned generator by MLE only (no discriminator).

    The function runs these steps in order:

    * Builds the graph: placeholders, generator outputs, a gradient-clipped
      Adam pre-training op and summaries.
    * Tries to restore a generator checkpoint from
      ``checkpoint_folder/PretrainGenerator``.
    * Runs ``config['npre_epochs']`` pre-training epochs. Every
      ``ntest_pre`` epochs it samples sentences, writes them to the JSON and
      text sample files, evaluates the metrics and logs them to TensorBoard
      and the CSV log.
    * Exports the session with ``simple_save`` to
      ``trained_models/<timestamp>``.

    Args:
        generator: Callable that builds the generator graph. It is called
            with ``x_real``, ``temperature`` and ``x_topic`` and returns six
            tensors.
        oracle_loader: Topic data loader providing batches, ``vocab_size``
            and dictionaries.
        config: Run configuration dict (batch size, sequence length,
            directories, learning rate, ...).
        args: Parsed CLI arguments. ``args.json_file`` is used in the sample
            JSON filename, and ``str(args)`` is stored as run information.

    Raises:
        NotImplementedError: if ``config['dataset']`` is neither
            ``'image_coco'`` nor ``'emnlp_news'``.
    """
    batch_size = config['batch_size']
    num_sentences = config['num_sentences']
    vocab_size = config['vocab_size']
    seq_len = config['seq_len']
    dataset = config['dataset']
    npre_epochs = config['npre_epochs']

    # Sample/evaluate every `ntest_pre` pre-training epochs, and write at
    # most `max_text_sentences` generated sentences to the plain-text file.
    ntest_pre = 40
    max_text_sentences = 200

    # changed to match resources path
    data_dir = resources_path(config['data_dir'])
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # filenames
    oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset))
    gen_text_file = os.path.join(sample_dir, 'generator_text.txt')
    json_file = os.path.join(sample_dir,
                             'json_file{}.txt'.format(args.json_file))
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')
    if dataset == 'image_coco':
        test_file = os.path.join(data_dir, 'testdata/test_coco.txt')
    elif dataset == 'emnlp_news':
        test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt')
    else:
        raise NotImplementedError('Unknown dataset!')

    # create necessary directories
    for directory in (data_dir, sample_dir, log_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # placeholder definitions
    x_real = placeholder(tf.int32, [batch_size, seq_len],
                         name="x_real")  # tokens of oracle sequences
    # Topic vectors have oracle_loader.vocab_size + 1 entries.
    # TODO (original author's note, translated): "same thing as the +1".
    x_topic = placeholder(tf.float32,
                          [batch_size, oracle_loader.vocab_size + 1],
                          name="x_topic")
    # NOTE(review): never fed or read in this function; kept so the graph
    # still contains a node named "x_topic_random".
    x_topic_random = placeholder(tf.float32,
                                 [batch_size, oracle_loader.vocab_size + 1],
                                 name="x_topic_random")

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_onehot = tf.one_hot(x_real,
                               vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_onehot.get_shape().as_list() == [
        batch_size, seq_len, vocab_size
    ]

    # generator outputs (this trainer builds no discriminator)
    x_fake_onehot_appr, x_fake, g_pretrain_loss, gen_o, \
    lambda_values_returned, gen_x_no_lambda = generator(x_real=x_real,
                                                        temperature=temperature,
                                                        x_topic=x_topic)

    def get_train_ops(config, g_pretrain_loss):
        """Build the Adam pre-training op for the 'generator' scope variables.

        Gradients are clipped to a global norm of 5.0 before being applied.
        """
        gpre_lr = config['gpre_lr']

        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='generator')
        grad_clip = 5.0  # keep the same with the previous setting

        # generator pre-training
        pretrain_opt = tf.train.AdamOptimizer(gpre_lr,
                                              beta1=0.9,
                                              beta2=0.999,
                                              name="gen_pretrain_adam")
        pretrain_grad, _ = tf.clip_by_global_norm(
            tf.gradients(g_pretrain_loss, g_vars, name="gradients_g_pretrain"),
            grad_clip,
            name="g_pretrain_clipping")  # gradient clipping
        return pretrain_opt.apply_gradients(zip(pretrain_grad, g_vars))

    # Train ops
    g_pretrain_op = get_train_ops(config, g_pretrain_loss)

    # Metric Summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Summaries
    gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss',
                                              scope='generator')
    gen_sentences_summary = CustomSummary(name='generated_sentences',
                                          scope='generator',
                                          summary_type=tf.summary.text,
                                          item_type=tf.string)
    run_information = CustomSummary(name='run_information',
                                    scope='info',
                                    summary_type=tf.summary.text,
                                    item_type=tf.string)
    custom_summaries = [
        gen_pretrain_loss_summary, gen_sentences_summary, run_information
    ]

    # ------------- initial the graph --------------
    with init_sess() as sess:
        # Return value is unused here; the call is kept in case it has side
        # effects (not visible from this function).
        get_parameters_division()

        # The CSV log is scoped with `with` so it is closed even if training
        # raises (previously the handle was never closed).
        with open(csv_file, 'w') as log:
            sum_writer = tf.summary.FileWriter(
                os.path.join(log_dir, 'summary'), sess.graph)
            for custom_summary in custom_summaries:
                custom_summary.set_file_writer(sum_writer, sess)

            run_information.write_summary(str(args), 0)
            print("Information stored in the summary!")

            oracle_loader.create_batches(oracle_file)

            metrics = get_metrics(config, oracle_loader, test_file,
                                  gen_text_file, g_pretrain_loss, x_real,
                                  x_topic, sess, json_file)

            gc.collect()
            # Check if there is a pretrained generator saved
            model_path = resources_path(
                os.path.join("checkpoint_folder", "PretrainGenerator"))
            try:
                new_saver = tf.train.import_meta_graph(
                    os.path.join(model_path, "model.ckpt.meta"))
                new_saver.restore(sess, os.path.join(model_path, "model.ckpt"))
                print("Used saved model for generator pretrain")
            except OSError:
                print('Start pre-training...')

            progress = tqdm(range(npre_epochs))
            for epoch in progress:
                # pre-training
                g_pretrain_loss_np = pre_train_epoch(sess, g_pretrain_op,
                                                     g_pretrain_loss, x_real,
                                                     oracle_loader, x_topic)
                gen_pretrain_loss_summary.write_summary(g_pretrain_loss_np,
                                                        epoch)
                progress.set_description(
                    "Pretrain_loss: {}".format(g_pretrain_loss_np))

                # Test only every `ntest_pre` epochs.
                if epoch % ntest_pre != 0:
                    continue

                json_object = generate_sentences(sess,
                                                 x_fake,
                                                 batch_size,
                                                 num_sentences,
                                                 oracle_loader=oracle_loader,
                                                 x_topic=x_topic)
                write_json(json_file, json_object)

                # Plain-text dump capped at the first `max_text_sentences`
                # sentences (the old counter was never incremented, so the
                # cap never applied).
                with open(gen_text_file, 'w') as outfile:
                    for i, sent in enumerate(json_object['sentences']):
                        if i >= max_text_sentences:
                            break
                        outfile.write(sent['generated_sentence'] + "\n")

                # take sentences from saved files
                sent = take_sentences_json(json_object,
                                           first_elem='generated_sentence',
                                           second_elem='real_starting')
                gen_sentences_summary.write_summary(sent, epoch)

                # write summaries
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op,
                                               feed_dict=dict(
                                                   zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str, epoch)

                msg = 'pre_gen_epoch:' + str(
                    epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                metric_names = [metric.get_name() for metric in metrics]
                for (name, score) in zip(metric_names, scores):
                    msg += ', ' + name + ': %.4f' % score
                tqdm.write(msg)
                log.write(msg + '\n')

                gc.collect()

        gc.collect()
        sum_writer.close()

        save_model = True  # manual toggle for exporting the trained model
        if save_model:
            model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            model_path = os.path.join(resources_path("trained_models"),
                                      model_dir)
            simple_save(sess,
                        model_path,
                        inputs={"x_topic": x_topic},
                        outputs={"gen_x": x_fake})
            print("Model saved in path: %s" % model_path)
def amazon_attribute_train(generator: rmc_att_topic.generator,
                           discriminator: rmc_att_topic.discriminator,
                           oracle_loader: RealDataAmazonLoader, config, args):
    """Train the attribute-conditioned GAN on the Amazon data.

    Builds the TF1 graph (placeholders for real token sequences and the
    user / product / rating attributes, generator, discriminator, losses,
    train ops and summaries).  It then tries to restore a pretrained
    generator from ``checkpoint_folder/PretrainGenerator``; if that raises
    ``OSError`` it runs ``npre_epochs`` epochs of MLE pre-training instead.
    Afterwards it runs ``nadv_steps`` adversarial steps and finally exports
    the model with ``simple_save`` under ``trained_models/<timestamp>``.

    Args:
        generator: generator constructor, called with ``x_real``,
            ``temperature``, ``x_user``, ``x_product`` and ``x_rating``.
        discriminator: discriminator constructor, called with ``x_onehot``.
        oracle_loader: loader whose ``random_batch(dataset=...)`` returns a
            ``(user, product, rating, sentence)`` batch.
        config: dict of hyper-parameters and paths.  Note: this function
            sets ``config['bleu_amazon']`` and
            ``config['bleu_amazon_validation']`` to ``True``.
        args: parsed command-line arguments; only ``str(args)`` is written
            to the ``run_information`` summary.
    """
    batch_size = config['batch_size']
    vocab_size = config['vocabulary_size']
    seq_len = config['seq_len']
    npre_epochs = config['npre_epochs']
    nadv_steps = config['nadv_steps']
    temper = config['temperature']
    adapt = config['adapt']

    # changed to match resources path
    data_dir = resources_path(config['data_dir'], "Amazon_Attribute")
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # filenames used by the metrics and by the csv log
    json_file = os.path.join(sample_dir, 'json_file.txt')
    json_file_validation = os.path.join(sample_dir, 'json_file_validation.txt')
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')

    # create necessary directories
    for directory in (data_dir, sample_dir, log_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # placeholder definitions
    x_real = tf.placeholder(tf.int32, [batch_size, seq_len],
                            name="x_real")  # tokens of oracle sequences
    x_user = tf.placeholder(tf.int32, [batch_size], name="x_user")
    x_product = tf.placeholder(tf.int32, [batch_size], name="x_product")
    x_rating = tf.placeholder(tf.int32, [batch_size], name="x_rating")

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_onehot = tf.one_hot(x_real,
                               vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_onehot.get_shape().as_list() == [
        batch_size, seq_len, vocab_size
    ]

    # generator and discriminator outputs
    generator_obj = generator(x_real=x_real,
                              temperature=temperature,
                              x_user=x_user,
                              x_product=x_product,
                              x_rating=x_rating)
    discriminator_real = discriminator(
        x_onehot=x_real_onehot)  # , with_out=False)
    discriminator_fake = discriminator(
        x_onehot=generator_obj.gen_x_onehot_adv)  # , with_out=False)

    # GAN / Divergence type
    log_pg, g_loss, d_loss = get_losses(generator_obj, discriminator_real,
                                        discriminator_fake, config)

    # Global step
    global_step = tf.Variable(0, trainable=False)
    global_step_op = global_step.assign_add(1)

    # Train ops; no topic discriminator here, so its loss is passed as None
    # and the returned topic pretrain op is ignored.
    g_pretrain_op, g_train_op, d_train_op, _ = get_train_ops(
        config, generator_obj.pretrain_loss, g_loss, d_loss, None, log_pg,
        temperature, global_step)

    # Record wall clock time
    time_diff = tf.placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder
    temp_var = tf.placeholder(tf.float32)
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('adv_loss/discriminator/total', d_loss),
        tf.summary.scalar('adv_loss/generator/total_g_loss', g_loss),
        tf.summary.scalar('adv_loss/log_pg', log_pg),
        tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('adv_loss/temperature', temperature),
    ]
    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric Summaries
    config['bleu_amazon'] = True
    config['bleu_amazon_validation'] = True
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Summaries
    gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss',
                                              scope='generator')
    gen_sentences_summary = CustomSummary(name='generated_sentences',
                                          scope='generator',
                                          summary_type=tf.summary.text,
                                          item_type=tf.string)
    topic_discr_pretrain_summary = CustomSummary(name='pretrain_loss',
                                                 scope='topic_discriminator')
    topic_discr_accuracy_summary = CustomSummary(name='pretrain_accuracy',
                                                 scope='topic_discriminator')
    run_information = CustomSummary(name='run_information',
                                    scope='info',
                                    summary_type=tf.summary.text,
                                    item_type=tf.string)
    custom_summaries = [
        gen_pretrain_loss_summary, gen_sentences_summary,
        topic_discr_pretrain_summary, topic_discr_accuracy_summary,
        run_information
    ]

    # To save the trained model (the final export below uses simple_save;
    # see the commented-out saver.save call)
    saver = tf.train.Saver()
    real_shape = (batch_size, seq_len)
    # ------------- initial the graph --------------
    # The csv log is opened as a context manager so it is always closed.
    with init_sess() as sess, open(csv_file, 'w') as log:
        get_parameters_division()

        sum_writer = tf.summary.FileWriter(os.path.join(log_dir, 'summary'),
                                           sess.graph)
        for custom_summary in custom_summaries:
            custom_summary.set_file_writer(sum_writer, sess)
        run_information.write_summary(str(args), 0)
        print("Information stored in the summary!")

        metrics = get_metrics(config, oracle_loader, sess, json_file,
                              json_file_validation, generator_obj)

        gc.collect()
        # Check if there is a pretrained generator saved
        model_dir = "PretrainGenerator"
        model_path = resources_path(
            os.path.join("checkpoint_folder", model_dir))
        try:
            new_saver = tf.train.import_meta_graph(
                os.path.join(model_path, "model.ckpt.meta"))
            new_saver.restore(sess, os.path.join(model_path, "model.ckpt"))
            print("Used saved model for generator pretrain")
        except OSError:
            print('Start pre-training...')
            # Pre-train the generator using MLE, one epoch per iteration
            progress = tqdm(range(npre_epochs))
            for epoch in progress:
                g_pretrain_loss_np = generator_obj.pretrain_epoch(
                    oracle_loader, sess, g_pretrain_op=g_pretrain_op)
                gen_pretrain_loss_summary.write_summary(
                    g_pretrain_loss_np, epoch)

                # Test
                ntest_pre = 40
                if np.mod(epoch, ntest_pre) == 0:
                    generator_obj.generated_num = 200
                    json_object = generator_obj.generate_samples(
                        sess, oracle_loader, dataset="train")
                    write_json(json_file, json_object)
                    json_object = generator_obj.generate_samples(
                        sess, oracle_loader, dataset="validation")
                    write_json(json_file_validation, json_object)

                    # take sentences from the validation samples
                    sent = take_sentences_attribute(json_object)
                    gen_sentences_summary.write_summary(sent, epoch)

                    # write summaries
                    scores = [metric.get_score() for metric in metrics]
                    metrics_summary_str = sess.run(metric_summary_op,
                                                   feed_dict=dict(
                                                       zip(metrics_pl,
                                                           scores)))
                    sum_writer.add_summary(metrics_summary_str, epoch)

                    msg = 'pre_gen_epoch:' + str(
                        epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                    metric_names = [metric.get_name() for metric in metrics]
                    for (name, score) in zip(metric_names, scores):
                        msg += ', ' + name + ': %.4f' % score
                    tqdm.write(msg)
                    log.write(msg)
                    log.write('\n')

                    gc.collect()

        print('Start adversarial training...')
        progress = tqdm(range(nadv_steps))
        for adv_epoch in progress:
            gc.collect()
            niter = sess.run(global_step)

            t0 = time.time()
            # Adversarial training: generator steps only feed the attributes,
            # discriminator steps also feed the real sentences.
            for _ in range(config['gsteps']):
                batch = oracle_loader.random_batch(dataset="train")
                sess.run(g_train_op,
                         feed_dict=_amazon_attribute_feed(generator_obj, batch))
            for _ in range(config['dsteps']):
                batch = oracle_loader.random_batch(dataset="train")
                sess.run(d_train_op,
                         feed_dict=_amazon_attribute_feed(
                             generator_obj, batch, x_real, real_shape))

            t1 = time.time()
            sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0})

            # temperature
            temp_var_np = get_fixed_temperature(temper, niter, nadv_steps,
                                                adapt)
            sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np})

            batch = oracle_loader.random_batch(dataset="train")
            feed_dict = _amazon_attribute_feed(generator_obj, batch, x_real,
                                               real_shape)
            g_loss_np, d_loss_np, loss_summary_str = sess.run(
                [g_loss, d_loss, loss_summary_op], feed_dict=feed_dict)
            sum_writer.add_summary(loss_summary_str, niter)

            sess.run(global_step_op)

            progress.set_description('g_loss: %4.4f, d_loss: %4.4f' %
                                     (g_loss_np, d_loss_np))

            # Test
            if np.mod(adv_epoch, 300) == 0 or adv_epoch == nadv_steps - 1:
                generator_obj.generated_num = generator_obj.batch_size * 10
                json_object = generator_obj.generate_samples(sess,
                                                             oracle_loader,
                                                             dataset="train")
                write_json(json_file, json_object)
                json_object = generator_obj.generate_samples(
                    sess, oracle_loader, dataset="validation")
                write_json(json_file_validation, json_object)

                # take sentences from the validation samples
                sent = take_sentences_attribute(json_object)
                gen_sentences_summary.write_summary(sent,
                                                    adv_epoch + npre_epochs)

                # write summaries
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op,
                                               feed_dict=dict(
                                                   zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str,
                                       adv_epoch + npre_epochs)

                # Report the adversarial losses: g_pretrain_loss_np is not
                # defined when the pretrained generator was restored above.
                msg = ('adv_epoch: ' + str(adv_epoch) +
                       ', g_loss: %.4f, d_loss: %.4f' % (g_loss_np, d_loss_np))
                metric_names = [metric.get_name() for metric in metrics]
                for (name, score) in zip(metric_names, scores):
                    msg += ', ' + name + ': %.4f' % score
                tqdm.write(msg)
                log.write(msg)
                log.write('\n')

                gc.collect()

        sum_writer.close()

        save_model = True
        if save_model:
            model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            model_path = os.path.join(resources_path("trained_models"),
                                      model_dir)
            simple_save(sess,
                        model_path,
                        inputs={
                            "x_user": x_user,
                            "x_rating": x_rating,
                            "x_product": x_product
                        },
                        outputs={"gen_x": generator_obj.gen_x})
            # save_path = saver.save(sess, os.path.join(model_path, "model.ckpt"))
            print("Model saved in path: %s" % model_path)


def _amazon_attribute_feed(generator_obj, batch, x_real=None, shape=None):
    """Build a feed dict from an Amazon ``(user, product, rating, sentence)`` batch.

    The user / product / rating placeholders of ``generator_obj`` are always
    fed.  When ``x_real`` is given, each sentence is copied row by row into
    a zero-initialised int32 array of ``shape`` (the caller passes
    ``(batch_size, seq_len)``), so missing rows stay zero, and that array is
    fed to ``x_real``.
    """
    user, product, rating, sentence = batch
    feed_dict = {
        generator_obj.x_user: user,
        generator_obj.x_product: product,
        generator_obj.x_rating: rating,
    }
    if x_real is not None:
        real_tokens = np.zeros(shape, dtype=np.int32)
        for row, tokens in enumerate(sentence):
            real_tokens[row] = tokens
        feed_dict[x_real] = real_tokens
    return feed_dict
Exemplo n.º 7
0
def real_train(generator, discriminator, oracle_loader, config, args):
    """Train the GAN on a real text dataset (``image_coco`` or ``emnlp_news``).

    Builds the TF1 graph (real-token placeholder, generator, discriminator,
    losses, train ops and summaries), writes the oracle file and creates the
    loader batches, then runs ``npre_epochs`` epochs of MLE pre-training
    followed by ``nadv_steps`` adversarial steps.  Generated samples, metric
    scores and losses go to TensorBoard summaries under ``<log_dir>/summary``
    and metric lines to ``<log_dir>/experiment-log-rmcgan.csv``.

    Args:
        generator: generator constructor returning
            ``(x_fake_onehot_appr, x_fake, g_pretrain_loss, gen_o)``.
        discriminator: discriminator constructor, called with ``x_onehot``.
        oracle_loader: loader providing ``create_batches`` and
            ``random_batch``.
        config: dict of hyper-parameters and paths.
        args: parsed command-line arguments (not used by this function).

    Raises:
        NotImplementedError: if ``config['dataset']`` is neither
            ``image_coco`` nor ``emnlp_news``.
    """
    batch_size = config['batch_size']
    num_sentences = config['num_sentences']
    vocab_size = config['vocab_size']
    seq_len = config['seq_len']
    dataset = config['dataset']
    npre_epochs = config['npre_epochs']
    nadv_steps = config['nadv_steps']
    temper = config['temperature']
    adapt = config['adapt']

    # Number of generated sentences written to the text summary.
    n_pre_summary_sentences = 5
    n_adv_summary_sentences = 8

    # changed to match resources path
    data_dir = resources_path(config['data_dir'])
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # filename
    oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset))
    gen_file = os.path.join(sample_dir, 'generator.txt')
    gen_text_file = os.path.join(sample_dir, 'generator_text.txt')
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')
    data_file = os.path.join(data_dir, '{}.txt'.format(dataset))
    if dataset == 'image_coco':
        test_file = os.path.join(data_dir, 'testdata/test_coco.txt')
    elif dataset == 'emnlp_news':
        data_file = os.path.join(data_dir, '{}_train.txt'.format(dataset))
        test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt')
    else:
        raise NotImplementedError('Unknown dataset!')

    # create necessary directories
    for directory in (data_dir, sample_dir, log_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # placeholder definitions
    x_real = tf.placeholder(tf.int32, [batch_size, seq_len],
                            name="x_real")  # tokens of oracle sequences

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_onehot = tf.one_hot(x_real,
                               vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_onehot.get_shape().as_list() == [
        batch_size, seq_len, vocab_size
    ]

    # generator and discriminator outputs
    x_fake_onehot_appr, x_fake, g_pretrain_loss, gen_o = generator(
        x_real=x_real, temperature=temperature)
    d_out_real = discriminator(x_onehot=x_real_onehot)
    d_out_fake = discriminator(x_onehot=x_fake_onehot_appr)

    # GAN / Divergence type
    log_pg, g_loss, d_loss = get_losses(d_out_real, d_out_fake, x_real_onehot,
                                        x_fake_onehot_appr, gen_o,
                                        discriminator, config)

    # Global step
    global_step = tf.Variable(0, trainable=False)
    global_step_op = global_step.assign_add(1)

    # Train ops
    g_pretrain_op, g_train_op, d_train_op = get_train_ops(
        config, g_pretrain_loss, g_loss, d_loss, log_pg, temperature,
        global_step)

    # Record wall clock time
    time_diff = tf.placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder
    temp_var = tf.placeholder(tf.float32)
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('loss/discriminator', d_loss),
        tf.summary.scalar('loss/g_loss', g_loss),
        tf.summary.scalar('loss/log_pg', log_pg),
        tf.summary.scalar('loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('loss/temperature', temperature),
    ]
    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric Summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Placeholder for generator pretrain loss
    gen_pretrain_loss_placeholder = tf.placeholder(
        tf.float32, name='pretrain_loss_placeholder')
    gen_pretrain_loss = tf.summary.scalar('generator/pretrain_loss',
                                          gen_pretrain_loss_placeholder)

    # Placeholder for generated text
    gen_sentences_placeholder = tf.placeholder(
        tf.string, name='generated_sentences_placeholder')
    gen_sentences = tf.summary.text('generator/generated_sentences',
                                    gen_sentences_placeholder)

    # ------------- initial the graph --------------
    # The csv log is opened as a context manager so it is always closed.
    with init_sess() as sess, open(csv_file, 'w') as log:
        print("Inizia sessione")

        sum_writer = tf.summary.FileWriter(os.path.join(log_dir, 'summary'),
                                           sess.graph)

        # generate oracle data and create batches
        index_word_dict = get_oracle_file(data_file, oracle_file, seq_len)

        print("Inizio creo batch")
        oracle_loader.create_batches(oracle_file)

        metrics = get_metrics(config, oracle_loader, test_file, gen_text_file,
                              g_pretrain_loss, x_real, sess)

        print('Start pre-training...')
        for epoch in range(npre_epochs):
            print("Pretrain epoch: {}".format(epoch))
            # pre-training
            g_pretrain_loss_np = pre_train_epoch(sess, g_pretrain_op,
                                                 g_pretrain_loss, x_real,
                                                 oracle_loader)
            pretrain_summary = sess.run(
                gen_pretrain_loss,
                feed_dict={gen_pretrain_loss_placeholder: g_pretrain_loss_np})
            sum_writer.add_summary(pretrain_summary, epoch)

            # Test
            ntest_pre = 10
            if np.mod(epoch, ntest_pre) == 0:
                # generate fake data and create batches
                gen_save_file = os.path.join(
                    sample_dir, 'pre_samples_{:05d}.txt'.format(epoch))
                generate_samples(sess, x_fake, batch_size, num_sentences,
                                 gen_file)
                get_real_test_file(gen_file, gen_save_file,
                                   index_word_dict)  # keep a per-epoch copy
                get_real_test_file(gen_file, gen_text_file,
                                   index_word_dict)  # overwrite the latest

                # take sentences from saved files; random.sample raises if
                # asked for more items than exist, so clamp the count.
                sent = take_sentences(gen_text_file)
                sent = random.sample(sent,
                                     min(n_pre_summary_sentences, len(sent)))
                generated_strings_summary = sess.run(
                    gen_sentences, feed_dict={gen_sentences_placeholder: sent})
                sum_writer.add_summary(generated_strings_summary, epoch)

                # write summaries
                print("Computing Metrics and writing summaries", end=" ")
                t = time.time()
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op,
                                               feed_dict=dict(
                                                   zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str, epoch)
                print("in {} seconds".format(time.time() - t))

                msg = 'pre_gen_epoch:' + str(
                    epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                metric_names = [metric.get_name() for metric in metrics]
                for (name, score) in zip(metric_names, scores):
                    msg += ', ' + name + ': %.4f' % score
                print(msg)
                log.write(msg)
                log.write('\n')

        print('Start adversarial training...')
        progress = tqdm(range(nadv_steps))
        for adv_epoch in progress:
            niter = sess.run(global_step)

            t0 = time.time()

            # adversarial training
            for _ in range(config['gsteps']):
                sess.run(g_train_op,
                         feed_dict={x_real: oracle_loader.random_batch()})
            for _ in range(config['dsteps']):
                sess.run(d_train_op,
                         feed_dict={x_real: oracle_loader.random_batch()})

            t1 = time.time()
            sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0})

            # temperature
            temp_var_np = get_fixed_temperature(temper, niter, nadv_steps,
                                                adapt)
            sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np})

            feed = {x_real: oracle_loader.random_batch()}
            g_loss_np, d_loss_np, loss_summary_str = sess.run(
                [g_loss, d_loss, loss_summary_op], feed_dict=feed)
            sum_writer.add_summary(loss_summary_str, niter)

            sess.run(global_step_op)

            progress.set_description('g_loss: %4.4f, d_loss: %4.4f' %
                                     (g_loss_np, d_loss_np))

            # Test every 250 steps and after the last step
            if np.mod(adv_epoch, 250) == 0 or adv_epoch == nadv_steps - 1:
                summary_step = niter + npre_epochs
                # generate fake data and create batches
                gen_save_file = os.path.join(
                    sample_dir, 'adv_samples_{:05d}.txt'.format(niter))
                generate_samples(sess, x_fake, batch_size, num_sentences,
                                 gen_file)
                get_real_test_file(gen_file, gen_save_file, index_word_dict)
                get_real_test_file(gen_file, gen_text_file, index_word_dict)

                # take sentences from saved files (clamped, see above)
                sent = take_sentences(gen_text_file)
                sent = random.sample(sent,
                                     min(n_adv_summary_sentences, len(sent)))
                generated_strings_summary = sess.run(
                    gen_sentences, feed_dict={gen_sentences_placeholder: sent})
                sum_writer.add_summary(generated_strings_summary,
                                       summary_step)

                # write summaries
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op,
                                               feed_dict=dict(
                                                   zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str, summary_step)

                msg = 'adv_step: ' + str(niter)
                metric_names = [metric.get_name() for metric in metrics]
                for (name, score) in zip(metric_names, scores):
                    msg += ', ' + name + ': %.4f' % score
                print(msg)
                log.write(msg)
                log.write('\n')

        # Flush pending events and release the event file.
        sum_writer.close()
Exemplo n.º 8
0
def real_topic_train(generator_obj, discriminator_obj, topic_discriminator_obj,
                     oracle_loader: RealDataTopicLoader, config, args):
    batch_size = config['batch_size']
    num_sentences = config['num_sentences']
    vocab_size = config['vocab_size']
    seq_len = config['seq_len']
    dataset = config['dataset']
    npre_epochs = config['npre_epochs']
    n_topic_pre_epochs = config['n_topic_pre_epochs']
    nadv_steps = config['nadv_steps']
    temper = config['temperature']
    adapt = config['adapt']

    # changed to match resources path
    data_dir = resources_path(config['data_dir'])
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # filename
    oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset))
    gen_file = os.path.join(sample_dir, 'generator.txt')
    gen_text_file = os.path.join(sample_dir, 'generator_text.txt')
    json_file = os.path.join(sample_dir, 'json_file.txt')
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')
    data_file = os.path.join(data_dir, '{}.txt'.format(dataset))
    if dataset == 'image_coco':
        test_file = os.path.join(data_dir, 'testdata/test_coco.txt')
    elif dataset == 'emnlp_news':
        test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt')
    else:
        raise NotImplementedError('Unknown dataset!')

    # create necessary directories
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # placeholder definitions
    x_real = tf.placeholder(tf.int32, [batch_size, seq_len],
                            name="x_real")  # tokens of oracle sequences
    x_topic = tf.placeholder(tf.float32,
                             [batch_size, oracle_loader.vocab_size + 1],
                             name="x_topic")  # todo stessa cosa del +1
    x_topic_random = tf.placeholder(tf.float32,
                                    [batch_size, oracle_loader.vocab_size + 1],
                                    name="x_topic_random")

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_onehot = tf.one_hot(x_real,
                               vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_onehot.get_shape().as_list() == [
        batch_size, seq_len, vocab_size
    ]

    # generator and discriminator outputs
    generator = generator_obj(x_real=x_real,
                              temperature=temperature,
                              x_topic=x_topic)
    d_real = discriminator_obj(x_onehot=x_real_onehot)
    d_fake = discriminator_obj(x_onehot=generator.gen_x_onehot_adv)
    if not args.no_topic:
        d_topic_real_pos = topic_discriminator_obj(x_onehot=x_real_onehot,
                                                   x_topic=x_topic)
        d_topic_real_neg = topic_discriminator_obj(x_onehot=x_real_onehot,
                                                   x_topic=x_topic_random)
        d_topic_fake = topic_discriminator_obj(
            x_onehot=generator.gen_x_onehot_adv, x_topic=x_topic)
    else:
        d_topic_real_pos = None
        d_topic_real_neg = None
        d_topic_fake = None

    # GAN / Divergence type
    losses = get_losses(generator, d_real, d_fake, d_topic_real_pos,
                        d_topic_real_neg, d_topic_fake, config)
    if not args.no_topic:
        d_topic_loss = losses['d_topic_loss_real_pos'] + losses[
            'd_topic_loss_real_neg']  # only from real data for pretrain
        d_topic_accuracy = get_accuracy(d_topic_real_pos.logits,
                                        d_topic_real_neg.logits)
    else:
        d_topic_loss = None
        d_topic_accuracy = None

    # Global step
    global_step = tf.Variable(0, trainable=False)
    global_step_op = global_step.assign_add(1)

    # Train ops
    g_pretrain_op, g_train_op, d_train_op, d_topic_pretrain_op = get_train_ops(
        config, generator.pretrain_loss, losses['g_loss'], losses['d_loss'],
        d_topic_loss, losses['log_pg'], temperature, global_step)
    generator.g_pretrain_op = g_pretrain_op

    # Record wall clock time
    time_diff = tf.placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder
    temp_var = tf.placeholder(tf.float32)
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('adv_loss/discriminator/classic/d_loss_real',
                          losses['d_loss_real']),
        tf.summary.scalar('adv_loss/discriminator/classic/d_loss_fake',
                          losses['d_loss_fake']),
        tf.summary.scalar('adv_loss/discriminator/total', losses['d_loss']),
        tf.summary.scalar('adv_loss/generator/g_sentence_loss',
                          losses['g_sentence_loss']),
        tf.summary.scalar('adv_loss/generator/total_g_loss', losses['g_loss']),
        tf.summary.scalar('adv_loss/log_pg', losses['log_pg']),
        tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('adv_loss/temperature', temperature),
    ]
    if not args.no_topic:
        loss_summaries += [
            tf.summary.scalar(
                'adv_loss/discriminator/topic_discriminator/d_topic_loss_real_pos',
                losses['d_topic_loss_real_pos']),
            tf.summary.scalar(
                'adv_loss/discriminator/topic_discriminator/d_topic_loss_fake',
                losses['d_topic_loss_fake']),
            tf.summary.scalar('adv_loss/generator/g_topic_loss',
                              losses['g_topic_loss'])
        ]

    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric Summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Summaries
    gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss',
                                              scope='generator')
    gen_sentences_summary = CustomSummary(name='generated_sentences',
                                          scope='generator',
                                          summary_type=tf.summary.text,
                                          item_type=tf.string)
    topic_discr_pretrain_summary = CustomSummary(name='pretrain_loss',
                                                 scope='topic_discriminator')
    topic_discr_accuracy_summary = CustomSummary(name='pretrain_accuracy',
                                                 scope='topic_discriminator')
    run_information = CustomSummary(name='run_information',
                                    scope='info',
                                    summary_type=tf.summary.text,
                                    item_type=tf.string)
    custom_summaries = [
        gen_pretrain_loss_summary, gen_sentences_summary,
        topic_discr_pretrain_summary, topic_discr_accuracy_summary,
        run_information
    ]

    # To save the trained model
    saver = tf.compat.v1.train.Saver()

    # ------------- initial the graph --------------
    with init_sess() as sess:
        variables_dict = get_parameters_division()

        print("Total paramter number: {}".format(
            np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])))
        log = open(csv_file, 'w')

        now = datetime.datetime.now()
        additional_text = now.strftime(
            "%Y-%m-%d_%H-%M") + "_" + args.summary_name
        summary_dir = os.path.join(log_dir, 'summary', additional_text)
        if not os.path.exists(summary_dir):
            os.makedirs(summary_dir)
        sum_writer = tf.compat.v1.summary.FileWriter(os.path.join(summary_dir),
                                                     sess.graph)
        for custom_summary in custom_summaries:
            custom_summary.set_file_writer(sum_writer, sess)

        run_information.write_summary(str(args), 0)
        print("Information stored in the summary!")

        # generate oracle data and create batches
        oracle_loader.create_batches(oracle_file)

        metrics = get_metrics(config, oracle_loader, test_file, gen_text_file,
                              generator.pretrain_loss, x_real, x_topic, sess,
                              json_file)

        gc.collect()

        # Check if there is a pretrained generator saved
        model_dir = "PretrainGenerator"
        model_path = resources_path(
            os.path.join("checkpoint_folder", model_dir))
        try:
            new_saver = tf.train.import_meta_graph(
                os.path.join(model_path, "model.ckpt.meta"))
            new_saver.restore(sess, os.path.join(model_path, "model.ckpt"))
            print("Used saved model for generator pretrain")
        except OSError:
            print('Start pre-training...')
            progress = tqdm(range(npre_epochs), dynamic_ncols=True)
            for epoch in progress:
                # pre-training
                g_pretrain_loss_np = generator.pretrain_epoch(
                    sess, oracle_loader)
                gen_pretrain_loss_summary.write_summary(
                    g_pretrain_loss_np, epoch)

                # Test
                ntest_pre = 30
                if np.mod(epoch, ntest_pre) == 0:
                    json_object = generator.generate_samples_topic(
                        sess, oracle_loader, num_sentences)
                    write_json(json_file, json_object)

                    # take sentences from saved files
                    sent = generator.get_sentences(json_object)
                    gen_sentences_summary.write_summary(sent, epoch)

                    # write summaries
                    t = time.time()
                    scores = [metric.get_score() for metric in metrics]
                    metrics_summary_str = sess.run(metric_summary_op,
                                                   feed_dict=dict(
                                                       zip(metrics_pl,
                                                           scores)))
                    sum_writer.add_summary(metrics_summary_str, epoch)

                    msg = 'pre_gen_epoch:' + str(
                        epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                    metric_names = [metric.get_name() for metric in metrics]
                    for (name, score) in zip(metric_names, scores):
                        score = score * 1e5 if 'Earth' in name else score
                        msg += ', ' + name + ': %.4f' % score
                    progress.set_description(
                        msg + " in {:.2f} sec".format(time.time() - t))
                    log.write(msg)
                    log.write('\n')

                    gc.collect()

        if not args.no_topic:
            gc.collect()
            print('Start Topic Discriminator pre-training...')
            progress = tqdm(range(n_topic_pre_epochs))
            for epoch in progress:
                # pre-training and write loss
                # Pre-train the generator using MLE for one epoch
                supervised_g_losses = []
                supervised_accuracy = []
                oracle_loader.reset_pointer()

                for it in range(oracle_loader.num_batch):
                    text_batch, topic_batch = oracle_loader.next_batch(
                        only_text=False)
                    _, topic_loss, accuracy = sess.run(
                        [d_topic_pretrain_op, d_topic_loss, d_topic_accuracy],
                        feed_dict={
                            x_real: text_batch,
                            x_topic: topic_batch,
                            x_topic_random: oracle_loader.random_topic()
                        })
                    supervised_g_losses.append(topic_loss)
                    supervised_accuracy.append(accuracy)

                d_topic_pretrain_loss = np.mean(supervised_g_losses)
                accuracy_mean = np.mean(supervised_accuracy)
                topic_discr_pretrain_summary.write_summary(
                    d_topic_pretrain_loss, epoch)
                topic_discr_accuracy_summary.write_summary(
                    accuracy_mean, epoch)
                progress.set_description(
                    'topic_loss: %4.4f, accuracy: %4.4f' %
                    (d_topic_pretrain_loss, accuracy_mean))

        print('Start adversarial training...')
        progress = tqdm(range(nadv_steps))
        for adv_epoch in progress:
            gc.collect()
            niter = sess.run(global_step)

            t0 = time.time()
            # Adversarial training
            for _ in range(config['gsteps']):
                text_batch, topic_batch = oracle_loader.random_batch(
                    only_text=False)
                sess.run(g_train_op,
                         feed_dict={
                             x_real: text_batch,
                             x_topic: topic_batch
                         })
            for _ in range(config['dsteps']):
                # normal + topic discriminator together
                text_batch, topic_batch = oracle_loader.random_batch(
                    only_text=False)
                sess.run(d_train_op,
                         feed_dict={
                             x_real: text_batch,
                             x_topic: topic_batch,
                             x_topic_random: oracle_loader.random_topic()
                         })

            t1 = time.time()
            sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0})

            # temperature
            temp_var_np = get_fixed_temperature(temper, niter, nadv_steps,
                                                adapt)
            sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np})

            text_batch, topic_batch = oracle_loader.random_batch(
                only_text=False)
            feed = {
                x_real: text_batch,
                x_topic: topic_batch,
                x_topic_random: oracle_loader.random_topic()
            }
            g_loss_np, d_loss_np, loss_summary_str = sess.run(
                [losses['g_loss'], losses['d_loss'], loss_summary_op],
                feed_dict=feed)
            sum_writer.add_summary(loss_summary_str, niter)

            sess.run(global_step_op)

            progress.set_description('g_loss: %4.4f, d_loss: %4.4f' %
                                     (g_loss_np, d_loss_np))

            # Test
            if np.mod(adv_epoch, 400) == 0 or adv_epoch == nadv_steps - 1:
                json_object = generator.generate_samples_topic(
                    sess, oracle_loader, num_sentences)
                write_json(json_file, json_object)

                # take sentences from saved files
                sent = generator.get_sentences(json_object)
                gen_sentences_summary.write_summary(sent, adv_epoch)

                # write summaries
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op,
                                               feed_dict=dict(
                                                   zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str,
                                       niter + config['npre_epochs'])

                msg = 'adv_step: ' + str(niter)
                metric_names = [metric.get_name() for metric in metrics]
                for (name, score) in zip(metric_names, scores):
                    msg += ', ' + name + ': %.4f' % score
                tqdm.write(msg)
                log.write(msg)
                log.write('\n')

        sum_writer.close()

        save_model = False
        if save_model:
            model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            model_path = os.path.join(resources_path("trained_models"),
                                      model_dir)
            simple_save(sess,
                        model_path,
                        inputs={"x_topic": x_topic},
                        outputs={"gen_x": generator.gen_x})
            # save_path = saver.save(sess, os.path.join(model_path, "model.ckpt"))
            print("Model saved in path: %s" % model_path)
Exemplo n.º 9
0
def real_train_NoDiscr(generator_obj: rmc_att_topic.generator, oracle_loader,
                       config, args):
    """Pre-train a generator with MLE only (no discriminator) on real data.

    Builds the generator graph and runs ``config['npre_epochs']`` epochs of
    maximum-likelihood pre-training. Every ``ntest_pre`` epochs it:

    * samples sentences and writes them to the JSON sample file;
    * writes at most ``max_text_sentences`` of them to a plain-text file;
    * evaluates the configured metrics;
    * logs the results to TensorBoard and to a CSV file in ``config['log_dir']``.

    Args:
        generator_obj: Generator class; instantiated as
            ``generator_obj(x_real=..., temperature=...)``.
        oracle_loader: Real-data loader. ``create_batches`` is called on the
            oracle file, and the loader is passed to the generator for
            pre-training and sampling.
        config: Dict of hyper-parameters and paths (``batch_size``,
            ``num_sentences``, ``vocab_size``, ``seq_len``, ``dataset``,
            ``npre_epochs``, ``gpre_lr``, ``data_dir``, ``log_dir``,
            ``sample_dir`` and the metric switches read by ``get_metrics``).
        args: Parsed command-line arguments. ``args.json_file`` is the suffix
            of the JSON sample file, and ``str(args)`` is written to the
            ``run_information`` summary.

    Raises:
        NotImplementedError: If ``config['dataset']`` is neither
            ``'image_coco'`` nor ``'emnlp_news'``.
    """
    batch_size = config['batch_size']
    num_sentences = config['num_sentences']
    vocab_size = config['vocab_size']
    seq_len = config['seq_len']
    dataset = config['dataset']
    npre_epochs = config['npre_epochs']
    # Sample / evaluate every `ntest_pre` pre-training epochs.
    ntest_pre = 40
    # Maximum number of generated sentences dumped to `gen_text_file`.
    max_text_sentences = 200

    # changed to match resources path
    data_dir = resources_path(config['data_dir'])
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # filenames
    oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset))
    gen_text_file = os.path.join(sample_dir, 'generator_text.txt')
    json_file = os.path.join(sample_dir,
                             'json_file{}.txt'.format(args.json_file))
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')
    if dataset == 'image_coco':
        test_file = os.path.join(data_dir, 'testdata/test_coco.txt')
    elif dataset == 'emnlp_news':
        test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt')
    else:
        raise NotImplementedError('Unknown dataset!')

    # create necessary directories
    for directory in (data_dir, sample_dir, log_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # placeholder definitions
    x_real = placeholder(tf.int32, [batch_size, seq_len],
                         name="x_real")  # tokens of oracle sequences

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_onehot = tf.one_hot(x_real,
                               vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_onehot.get_shape().as_list() == [
        batch_size, seq_len, vocab_size
    ]

    # generator outputs
    generator = generator_obj(x_real=x_real, temperature=temperature)

    # Global step. The op is never run in this function; it is kept so the
    # built graph is unchanged.
    global_step = tf.Variable(0, trainable=False)
    # noinspection PyUnusedLocal
    global_step_op = global_step.assign_add(1)

    # A function to calculate the gradients and get training operations
    def get_train_ops(config, g_pretrain_loss):
        """Build the Adam op for generator MLE pre-training with clipping."""
        gpre_lr = config['gpre_lr']

        g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES,
                                   scope='generator')
        grad_clip = 5.0  # keep the same with the previous setting

        # generator pre-training
        pretrain_opt = tf.train.AdamOptimizer(gpre_lr,
                                              beta1=0.9,
                                              beta2=0.999,
                                              name="gen_pretrain_adam")
        pretrain_grad, _ = tf.clip_by_global_norm(
            tf.gradients(g_pretrain_loss, g_vars, name="gradients_g_pretrain"),
            grad_clip,
            name="g_pretrain_clipping")  # gradient clipping
        g_pretrain_op = pretrain_opt.apply_gradients(zip(
            pretrain_grad, g_vars))

        return g_pretrain_op

    # Train ops
    g_pretrain_op = get_train_ops(config, generator.pretrain_loss)

    # Record wall clock time (op not run here; kept so the graph is unchanged)
    time_diff = placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    # noinspection PyUnusedLocal
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder (op not run here; kept so the graph is unchanged)
    temp_var = placeholder(tf.float32)
    # noinspection PyUnusedLocal
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('adv_loss/temperature', temperature),
    ]
    # noinspection PyUnusedLocal
    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric Summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Summaries
    gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss',
                                              scope='generator')
    gen_sentences_summary = CustomSummary(name='generated_sentences',
                                          scope='generator',
                                          summary_type=tf.summary.text,
                                          item_type=tf.string)
    run_information = CustomSummary(name='run_information',
                                    scope='info',
                                    summary_type=tf.summary.text,
                                    item_type=tf.string)
    custom_summaries = [
        gen_pretrain_loss_summary, gen_sentences_summary, run_information
    ]

    # To save the trained model
    # noinspection PyUnusedLocal
    saver = tf.train.Saver()

    # ------------- initial the graph --------------
    # The CSV log is opened in the same `with` so it is closed even if
    # training raises.
    with init_sess() as sess, open(csv_file, 'w') as log:
        get_parameters_division()

        print("Total paramter number: {}".format(
            np.sum([
                np.prod(v.get_shape().as_list())
                for v in tf.trainable_variables()
            ])))
        sum_writer = tf.summary.FileWriter(os.path.join(log_dir, 'summary'),
                                           sess.graph)
        for custom_summary in custom_summaries:
            custom_summary.set_file_writer(sum_writer, sess)

        run_information.write_summary(str(args), 0)
        print("Information stored in the summary!")
        # generate oracle data and create batches
        # TODO: if the generated words make little sense, the index-to-word
        # mapping used here may be wrong.
        oracle_loader.create_batches(oracle_file)

        metrics = get_metrics(config, oracle_loader, test_file, gen_text_file,
                              generator.pretrain_loss, x_real, sess, json_file)

        gc.collect()
        # Check if there is a pretrained generator saved
        model_dir = "PretrainGenerator"
        model_path = resources_path(
            os.path.join("checkpoint_folder", model_dir))
        try:
            new_saver = tf.train.import_meta_graph(
                os.path.join(model_path, "model.ckpt.meta"))
            new_saver.restore(sess, os.path.join(model_path, "model.ckpt"))
            print("Used saved model for generator pretrain")
        except OSError:
            print('Start pre-training...')

        # NOTE(review): the pre-training loop below runs even when a saved
        # generator was restored above. The adversarial entry points in this
        # file pre-train only inside the `except` branch. Confirm which
        # behaviour is intended here.
        progress = tqdm(range(npre_epochs))
        for epoch in progress:
            # pre-training
            g_pretrain_loss_np = generator.pre_train_epoch(
                sess, g_pretrain_op, oracle_loader)
            gen_pretrain_loss_summary.write_summary(g_pretrain_loss_np, epoch)
            progress.set_description(
                "Pretrain_loss: {}".format(g_pretrain_loss_np))
            # Test
            if np.mod(epoch, ntest_pre) == 0:
                json_object = generator.generate_sentences(
                    sess,
                    batch_size,
                    num_sentences,
                    oracle_loader=oracle_loader)
                write_json(json_file, json_object)

                # Dump at most `max_text_sentences` sentences as plain text
                # (this file is passed to `get_metrics` above).
                with open(gen_text_file, 'w') as outfile:
                    for i, sent in enumerate(json_object['sentences']):
                        if i >= max_text_sentences:
                            break
                        outfile.write(sent['generated_sentence'] + "\n")

                # take sentences from saved files
                sent = take_sentences_json(json_object,
                                           first_elem='generated_sentence',
                                           second_elem=None)
                gen_sentences_summary.write_summary(sent, epoch)

                # write summaries
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op,
                                               feed_dict=dict(
                                                   zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str, epoch)

                msg = 'pre_gen_epoch:' + str(
                    epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                metric_names = [metric.get_name() for metric in metrics]
                for (name, score) in zip(metric_names, scores):
                    msg += ', ' + name + ': %.4f' % score
                tqdm.write(msg)
                log.write(msg)
                log.write('\n')

                gc.collect()

            gc.collect()
        sum_writer.close()
Exemplo n.º 10
0
def customer_reviews_train(generator: ReviewGenerator,
                           discriminator_positive: ReviewDiscriminator,
                           discriminator_negative: ReviewDiscriminator,
                           oracle_loader: RealDataCustomerReviewsLoader,
                           config, args):
    """Train a sentiment-conditioned review generator against two discriminators.

    The generator is built with ``x_real``, ``temperature`` and
    ``x_sentiment``. ``discriminator_positive`` and ``discriminator_negative``
    are each applied to three inputs:

    * the real positive batch;
    * the real negative batch;
    * the generator's ``gen_x_onehot_adv`` output.

    ``get_losses`` combines the six outputs into ``log_pg``, ``g_loss`` and
    ``d_loss``.

    Training has two phases:

    1. MLE pre-training for ``config['npre_epochs']`` epochs. This is skipped
       when a checkpoint can be restored from the
       ``checkpoint_folder/PretrainGenerator`` resource directory.
    2. ``config['nadv_steps']`` training steps. Each step runs
       ``config['gsteps']`` generator updates and ``config['dsteps']``
       discriminator updates, then updates the temperature and logs the
       losses. Samples and metrics are produced every ``ntest_adv`` steps and
       on the last step.

    Args:
        generator: Generator class; instantiated with ``x_real``,
            ``temperature`` and ``x_sentiment``.
        discriminator_positive: Discriminator callable for positive reviews.
        discriminator_negative: Discriminator callable for negative reviews.
        oracle_loader: Customer-review loader providing ``random_batch``,
            ``get_positive_negative_batch`` and ``reset_pointer``.
        config: Dict of hyper-parameters, metric switches and paths.
        args: Parsed command-line arguments; ``str(args)`` is written to the
            ``run_information`` summary.
    """
    batch_size = config['batch_size']
    vocab_size = config['vocabulary_size']
    seq_len = config['seq_len']
    npre_epochs = config['npre_epochs']
    nadv_steps = config['nadv_steps']
    temper = config['temperature']
    adapt = config['adapt']
    # Sample / evaluate every `ntest_pre` pre-training epochs and every
    # `ntest_adv` adversarial steps.
    ntest_pre = 30
    ntest_adv = 100

    # changed to match resources path
    data_dir = resources_path(config['data_dir'], "Amazon_Attribute")
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # filenames
    json_file = os.path.join(sample_dir, 'json_file.txt')
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')

    # create necessary directories
    for directory in (data_dir, sample_dir, log_dir):
        if not os.path.exists(directory):
            os.makedirs(directory)

    # placeholder definitions
    x_real = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_real")
    x_pos = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_pos")
    x_neg = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_neg")
    x_sentiment = tf.placeholder(tf.int32, [batch_size], name="x_sentiment")

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_pos_onehot = tf.one_hot(
        x_pos, vocab_size)  # batch_size x seq_len x vocab_size
    x_real_neg_onehot = tf.one_hot(
        x_neg, vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_pos_onehot.get_shape().as_list() == [
        batch_size, seq_len, vocab_size
    ]

    # generator and discriminator outputs
    generator_obj = generator(x_real=x_real,
                              temperature=temperature,
                              x_sentiment=x_sentiment)
    # discriminator for positive sentences
    discriminator_positive_real_pos = discriminator_positive(
        x_onehot=x_real_pos_onehot)
    discriminator_positive_real_neg = discriminator_positive(
        x_onehot=x_real_neg_onehot)
    discriminator_positive_fake = discriminator_positive(
        x_onehot=generator_obj.gen_x_onehot_adv)
    # discriminator for negative sentences
    discriminator_negative_real_pos = discriminator_negative(
        x_onehot=x_real_pos_onehot)
    discriminator_negative_real_neg = discriminator_negative(
        x_onehot=x_real_neg_onehot)
    discriminator_negative_fake = discriminator_negative(
        x_onehot=generator_obj.gen_x_onehot_adv)

    # GAN / Divergence type
    log_pg, g_loss, d_loss = get_losses(
        generator_obj, discriminator_positive_real_pos,
        discriminator_positive_real_neg, discriminator_positive_fake,
        discriminator_negative_real_pos, discriminator_negative_real_neg,
        discriminator_negative_fake)

    # Global step
    global_step = tf.Variable(0, trainable=False)
    global_step_op = global_step.assign_add(1)

    # Train ops. NOTE(review): `g_train_op` and `d_topic_pretrain_op` are
    # never run in this function.
    g_pretrain_op, g_train_op, d_train_op, d_topic_pretrain_op = get_train_ops(
        config, generator_obj.pretrain_loss, g_loss, d_loss, None, log_pg,
        temperature, global_step)

    # Record wall clock time
    time_diff = tf.placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder
    temp_var = tf.placeholder(tf.float32)
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('adv_loss/discriminator/total', d_loss),
        tf.summary.scalar('adv_loss/generator/total_g_loss', g_loss),
        tf.summary.scalar('adv_loss/log_pg', log_pg),
        tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('adv_loss/temperature', temperature),
    ]
    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric Summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Summaries
    gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss',
                                              scope='generator')
    gen_sentences_summary = CustomSummary(name='generated_sentences',
                                          scope='generator',
                                          summary_type=tf.summary.text,
                                          item_type=tf.string)
    run_information = CustomSummary(name='run_information',
                                    scope='info',
                                    summary_type=tf.summary.text,
                                    item_type=tf.string)
    custom_summaries = [
        gen_pretrain_loss_summary, gen_sentences_summary, run_information
    ]

    def _batch_matrix(rows):
        """Copy `rows` into a zero-initialised (batch_size, seq_len) matrix.

        If fewer rows than ``generator_obj.batch_size`` are given, the
        remaining rows stay zero.
        """
        matrix = np.zeros((generator_obj.batch_size, generator_obj.seq_len),
                          dtype=np.int32)
        for ind, row in enumerate(rows):
            matrix[ind] = row
        return matrix

    def _positive_negative_feed():
        """Sample a sentence/positive/negative batch and build its feed dict."""
        sentiment, sentence, pos, neg = (
            oracle_loader.get_positive_negative_batch())
        # zip() keeps only as many rows as the shortest of the three lists.
        triples = list(zip(sentence, pos, neg))
        return {
            x_real: _batch_matrix([s for s, _, _ in triples]),
            # only the first element of each pos / neg entry is fed
            x_pos: _batch_matrix([p[0] for _, p, _ in triples]),
            x_neg: _batch_matrix([n[0] for _, _, n in triples]),
            x_sentiment: sentiment,
        }

    # ------------- initial the graph --------------
    # The CSV log is opened in the same `with` so it is closed even if
    # training raises.
    with init_sess() as sess, open(csv_file, 'w') as log:
        summary_dir = os.path.join(log_dir, 'summary', str(time.time()))
        if not os.path.exists(summary_dir):
            os.makedirs(summary_dir)
        sum_writer = tf.summary.FileWriter(summary_dir, sess.graph)
        for custom_summary in custom_summaries:
            custom_summary.set_file_writer(sum_writer, sess)

        run_information.write_summary(str(args), 0)
        print("Information stored in the summary!")

        def get_metrics():
            """Instantiate the evaluation metrics enabled in `config`."""
            metrics = []
            if config['nll_gen']:
                nll_gen = NllReview(oracle_loader,
                                    generator_obj,
                                    sess,
                                    name='nll_gen_review')
                metrics.append(nll_gen)
            if config['KL']:
                KL_div = KL_divergence(oracle_loader,
                                       json_file,
                                       name='KL_divergence')
                metrics.append(KL_div)
            if config['jaccard_similarity']:
                Jaccard_Sim = JaccardSimilarity(oracle_loader,
                                                json_file,
                                                name='jaccard_similarity')
                metrics.append(Jaccard_Sim)
            if config['jaccard_diversity']:
                Jaccard_Div = JaccardDiversity(oracle_loader,
                                               json_file,
                                               name='jaccard_diversity')
                metrics.append(Jaccard_Div)

            return metrics

        metrics = get_metrics()
        # Hard-coded to 200 rather than config['num_sentences'].
        generator_obj.generated_num = 200

        gc.collect()
        # Check if there is a pretrained generator saved
        model_dir = "PretrainGenerator"
        model_path = resources_path(
            os.path.join("checkpoint_folder", model_dir))
        try:
            new_saver = tf.train.import_meta_graph(
                os.path.join(model_path, "model.ckpt.meta"))
            new_saver.restore(sess, os.path.join(model_path, "model.ckpt"))
            print("Used saved model for generator pretrain")
        except OSError:
            print('Start pre-training...')
            # Pre-train the generator using MLE, one epoch per iteration.
            progress = tqdm(range(npre_epochs))
            for epoch in progress:
                oracle_loader.reset_pointer()
                g_pretrain_loss_np = generator_obj.pretrain_epoch(
                    oracle_loader, sess, g_pretrain_op=g_pretrain_op)
                gen_pretrain_loss_summary.write_summary(
                    g_pretrain_loss_np, epoch)
                msg = 'pre_gen_epoch:' + str(
                    epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                progress.set_description(msg)

                # Test
                if np.mod(epoch, ntest_pre) == 0 or epoch == npre_epochs - 1:
                    json_object = generator_obj.generate_json(
                        oracle_loader, sess)
                    write_json(json_file, json_object)

                    # take sentences from saved files
                    sent = take_sentences_json(json_object)
                    gen_sentences_summary.write_summary(sent, epoch)

                    # write summaries
                    scores = [metric.get_score() for metric in metrics]
                    metrics_summary_str = sess.run(metric_summary_op,
                                                   feed_dict=dict(
                                                       zip(metrics_pl,
                                                           scores)))
                    sum_writer.add_summary(metrics_summary_str, epoch)

                    msg = 'pre_gen_epoch:' + str(
                        epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                    metric_names = [metric.get_name() for metric in metrics]
                    for (name, score) in zip(metric_names, scores):
                        msg += ', ' + name + ': %.4f' % score
                    tqdm.write(msg)
                    log.write(msg)
                    log.write('\n')

                    gc.collect()

        gc.collect()

        print('Start adversarial training...')
        progress = tqdm(range(nadv_steps))
        for adv_epoch in progress:
            gc.collect()
            niter = sess.run(global_step)

            t0 = time.time()
            # Adversarial training
            for _ in range(config['gsteps']):
                sentiment, sentence = oracle_loader.random_batch()
                # NOTE(review): this runs the MLE pre-training op, not
                # `g_train_op`, so the generator gets no adversarial gradient
                # in this loop. Switching to `g_train_op` would presumably
                # also require feeding x_pos / x_neg. Confirm before
                # changing.
                sess.run(g_pretrain_op,
                         feed_dict={
                             generator_obj.x_real: _batch_matrix(sentence),
                             generator_obj.x_sentiment: sentiment
                         })
            for _ in range(config['dsteps']):
                sess.run(d_train_op, feed_dict=_positive_negative_feed())

            t1 = time.time()
            sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0})

            # temperature
            temp_var_np = get_fixed_temperature(temper, niter, nadv_steps,
                                                adapt)
            sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np})

            g_loss_np, d_loss_np, loss_summary_str = sess.run(
                [g_loss, d_loss, loss_summary_op],
                feed_dict=_positive_negative_feed())
            sum_writer.add_summary(loss_summary_str, niter)

            sess.run(global_step_op)

            progress.set_description('g_loss: %4.4f, d_loss: %4.4f' %
                                     (g_loss_np, d_loss_np))

            # Test
            if np.mod(adv_epoch, ntest_adv) == 0 or adv_epoch == nadv_steps - 1:
                json_object = generator_obj.generate_json(oracle_loader, sess)
                write_json(json_file, json_object)

                # take sentences from saved files
                sent = take_sentences_json(json_object)
                gen_sentences_summary.write_summary(
                    sent, niter + config['npre_epochs'])

                # write summaries
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op,
                                               feed_dict=dict(
                                                   zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str,
                                       niter + config['npre_epochs'])

                msg = 'adv_step: ' + str(niter)
                metric_names = [metric.get_name() for metric in metrics]
                for (name, score) in zip(metric_names, scores):
                    msg += ', ' + name + ': %.4f' % score
                tqdm.write(msg)
                log.write(msg)
                log.write('\n')

                gc.collect()

        sum_writer.close()

        save_model = False
        if save_model:
            model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            model_path = os.path.join(resources_path("trained_models"),
                                      model_dir)
            # Use tensors that exist in this graph; the original referenced
            # undefined names (`x_topic`, `x_fake`) and would raise NameError.
            simple_save(sess,
                        model_path,
                        inputs={"x_sentiment": x_sentiment},
                        outputs={
                            "gen_x_onehot_adv": generator_obj.gen_x_onehot_adv
                        })
            print("Model saved in path: %s" % model_path)