def main(): args = parser.parse_args() pp.pprint(vars(args)) config = vars(args) model_path = resources_path(os.path.join('trained_models', config['model_name'])) input_path = resources_path(os.path.join('inference_data', config['input_name'])) data_file = resources_path(args.data_dir, '{}.txt'.format(args.dataset)) sample_dir = resources_path(config['sample_dir']) oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(args.dataset)) if args.dataset == 'emnlp_news' : data_file, lda_file = create_subsample_data_file(data_file) else: lda_file = data_file seq_len, vocab_size, word_index_dict, index_word_dict = text_precess(data_file, oracle_file=oracle_file) print(index_word_dict) config['seq_len'] = seq_len config['vocab_size'] = vocab_size print('seq_len: %d, vocab_size: %d' % (seq_len, vocab_size)) if config['topic']: topic_number = config['topic_number'] oracle_loader = RealDataTopicLoader(args.batch_size, args.seq_len) oracle_loader.set_dataset(args.dataset) oracle_loader.topic_num = topic_number oracle_loader.set_dictionaries(word_index_dict, index_word_dict) oracle_loader.get_LDA(word_index_dict, index_word_dict, data_file) print(oracle_loader.model_index_word_dict) inference_main(oracle_loader, config, model_path, input_path)
def get_corpus(coco=True, datapath=None) -> List[str]: if datapath is None: if coco: fname = resources_path("data", "image_coco.txt") else: fname = resources_path("data", "emnlp_news.txt") else: fname = datapath with open(fname) as f: lines = [line.rstrip('\n') for line in f] return lines
def get_perc_sent_topic(lda, topic_num, data_file): file_path = "{}-{}-{}".format(lda.lda_model, topic_num, data_file[-10:]) file_path = resources_path("topic_models", file_path) try: sent_topics_df = load_pickle(file_path) except FileNotFoundError: print("get perc sent topic not found") ldamodel = lda.lda_model with open(data_file) as f: sentences = [line.rstrip('\n') for line in f] tmp = process_texts(sentences, lda.stops) corpus_bow = [lda.dictionary.doc2bow(i) for i in tmp] # Init output sent_topics_df = pd.DataFrame() # Get main topic in each document for i, row_list in enumerate(tqdm(ldamodel[corpus_bow])): row = row_list[0] if ldamodel.per_word_topics else row_list # print(row) # row = sorted(row, key=lambda x: (x[1]), reverse=True) # sort list to get dominant topic # Get the Dominant topic, Perc Contribution and Keywords for each document to_append = np.zeros(topic_num) for j, (topic_n, prop_topic) in enumerate(row): to_append[topic_n] = prop_topic sent_topics_df = sent_topics_df.append(pd.Series(to_append), ignore_index=True) sent_topics_df.columns = [ "Topic {}".format(topic_number) for topic_number in range(len(to_append)) ] # Add original text to the end of the output contents = pd.Series(sentences) sent_topics_df = pd.concat([sent_topics_df, contents], axis=1) sent_topics_df = sent_topics_df.reset_index() write_pickle(file_path, sent_topics_df) return sent_topics_df
def main(): args = parser.parse_args() pp.pprint(vars(args)) config = vars(args) # train with different datasets if args.dataset == 'oracle': oracle_model = OracleLstm(num_vocabulary=args.vocab_size, batch_size=args.batch_size, emb_dim=args.gen_emb_dim, hidden_dim=args.hidden_dim, sequence_length=args.seq_len, start_token=args.start_token) oracle_loader = OracleDataLoader(args.batch_size, args.seq_len) gen_loader = OracleDataLoader(args.batch_size, args.seq_len) generator = models.get_generator(args.g_architecture, vocab_size=args.vocab_size, batch_size=args.batch_size, seq_len=args.seq_len, gen_emb_dim=args.gen_emb_dim, mem_slots=args.mem_slots, head_size=args.head_size, num_heads=args.num_heads, hidden_dim=args.hidden_dim, start_token=args.start_token) discriminator = models.get_discriminator(args.d_architecture, batch_size=args.batch_size, seq_len=args.seq_len, vocab_size=args.vocab_size, dis_emb_dim=args.dis_emb_dim, num_rep=args.num_rep, sn=args.sn) oracle_train(generator, discriminator, oracle_model, oracle_loader, gen_loader, config) elif args.dataset in ['image_coco', 'emnlp_news']: # custom dataset selected data_file = resources_path(args.data_dir, '{}.txt'.format(args.dataset)) sample_dir = resources_path(config['sample_dir']) oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(args.dataset)) data_dir = resources_path(config['data_dir']) if args.dataset == 'image_coco': test_file = os.path.join(data_dir, 'testdata/test_coco.txt') elif args.dataset == 'emnlp_news': test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt') else: raise NotImplementedError('Unknown dataset!') if args.dataset == 'emnlp_news': data_file, lda_file = create_subsample_data_file(data_file) else: lda_file = data_file seq_len, vocab_size, word_index_dict, index_word_dict = text_precess( data_file, test_file, oracle_file=oracle_file) config['seq_len'] = seq_len config['vocab_size'] = vocab_size print('seq_len: %d, vocab_size: %d' % (seq_len, vocab_size)) config['topic_loss_weight'] = args.topic_loss_weight if config['LSTM']: if config['topic']: topic_number = config['topic_number'] oracle_loader = RealDataTopicLoader(args.batch_size, args.seq_len) oracle_loader.set_dataset(args.dataset) oracle_loader.set_files(data_file, lda_file) oracle_loader.topic_num = topic_number oracle_loader.set_dictionaries(word_index_dict, index_word_dict) generator = models.get_generator( args.g_architecture, vocab_size=vocab_size, batch_size=args.batch_size, seq_len=seq_len, gen_emb_dim=args.gen_emb_dim, mem_slots=args.mem_slots, head_size=args.head_size, num_heads=args.num_heads, hidden_dim=args.hidden_dim, start_token=args.start_token, TopicInMemory=args.topic_in_memory, NoTopic=args.no_topic) from real.real_gan.real_topic_train_NoDiscr import real_topic_train_NoDiscr real_topic_train_NoDiscr(generator, oracle_loader, config, args) else: generator = models.get_generator(args.g_architecture, vocab_size=vocab_size, batch_size=args.batch_size, seq_len=seq_len, gen_emb_dim=args.gen_emb_dim, mem_slots=args.mem_slots, head_size=args.head_size, num_heads=args.num_heads, hidden_dim=args.hidden_dim, start_token=args.start_token) oracle_loader = RealDataLoader(args.batch_size, args.seq_len) oracle_loader.set_dictionaries(word_index_dict, index_word_dict) oracle_loader.set_dataset(args.dataset) oracle_loader.set_files(data_file, lda_file) oracle_loader.topic_num = config['topic_number'] from real.real_gan.real_train_NoDiscr import real_train_NoDiscr real_train_NoDiscr(generator, oracle_loader, config, args) else: if 
config['topic']: topic_number = config['topic_number'] oracle_loader = RealDataTopicLoader(args.batch_size, args.seq_len) oracle_loader.set_dataset(args.dataset) oracle_loader.set_files(data_file, lda_file) oracle_loader.topic_num = topic_number oracle_loader.set_dictionaries(word_index_dict, index_word_dict) generator = models.get_generator( args.g_architecture, vocab_size=vocab_size, batch_size=args.batch_size, seq_len=seq_len, gen_emb_dim=args.gen_emb_dim, mem_slots=args.mem_slots, head_size=args.head_size, num_heads=args.num_heads, hidden_dim=args.hidden_dim, start_token=args.start_token, TopicInMemory=args.topic_in_memory, NoTopic=args.no_topic) discriminator = models.get_discriminator( args.d_architecture, batch_size=args.batch_size, seq_len=seq_len, vocab_size=vocab_size, dis_emb_dim=args.dis_emb_dim, num_rep=args.num_rep, sn=args.sn) if not args.no_topic: topic_discriminator = models.get_topic_discriminator( args.topic_architecture, batch_size=args.batch_size, seq_len=seq_len, vocab_size=vocab_size, dis_emb_dim=args.dis_emb_dim, num_rep=args.num_rep, sn=args.sn, discriminator=discriminator) else: topic_discriminator = None from real.real_gan.real_topic_train import real_topic_train real_topic_train(generator, discriminator, topic_discriminator, oracle_loader, config, args) else: generator = models.get_generator(args.g_architecture, vocab_size=vocab_size, batch_size=args.batch_size, seq_len=seq_len, gen_emb_dim=args.gen_emb_dim, mem_slots=args.mem_slots, head_size=args.head_size, num_heads=args.num_heads, hidden_dim=args.hidden_dim, start_token=args.start_token) discriminator = models.get_discriminator( args.d_architecture, batch_size=args.batch_size, seq_len=seq_len, vocab_size=vocab_size, dis_emb_dim=args.dis_emb_dim, num_rep=args.num_rep, sn=args.sn) oracle_loader = RealDataLoader(args.batch_size, args.seq_len) from real.real_gan.real_train import real_train real_train(generator, discriminator, oracle_loader, config, args) elif args.dataset in ['Amazon_Attribute']: # custom dataset selected data_dir = resources_path(config['data_dir'], "Amazon_Attribute") sample_dir = resources_path(config['sample_dir']) oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(args.dataset)) train_file = os.path.join(data_dir, 'train.csv') dev_file = os.path.join(data_dir, 'dev.csv') test_file = os.path.join(data_dir, 'test.csv') # create_tokens_files(data_files=[train_file, dev_file, test_file]) config_file = load_json(os.path.join(data_dir, 'config.json')) config = {**config, **config_file} # merge dictionaries from real.real_gan.loaders.amazon_loader import RealDataAmazonLoader oracle_loader = RealDataAmazonLoader(args.batch_size, args.seq_len) oracle_loader.create_batches( data_file=[train_file, dev_file, test_file]) oracle_loader.model_index_word_dict = load_json( join(data_dir, 'index_word_dict.json')) oracle_loader.model_word_index_dict = load_json( join(data_dir, 'word_index_dict.json')) generator = models.get_generator("amazon_attribute", vocab_size=config['vocabulary_size'], batch_size=args.batch_size, seq_len=config['seq_len'], gen_emb_dim=args.gen_emb_dim, mem_slots=args.mem_slots, head_size=args.head_size, num_heads=args.num_heads, hidden_dim=args.hidden_dim, start_token=args.start_token, user_num=config['user_num'], product_num=config['product_num'], rating_num=5) discriminator = models.get_discriminator( "amazon_attribute", batch_size=args.batch_size, seq_len=config['seq_len'], vocab_size=config['vocabulary_size'], dis_emb_dim=args.dis_emb_dim, num_rep=args.num_rep, 
sn=args.sn) from real.real_gan.amazon_attribute_train import amazon_attribute_train amazon_attribute_train(generator, discriminator, oracle_loader, config, args) elif args.dataset in ['CustomerReviews', 'imdb']: from real.real_gan.loaders.custom_reviews_loader import RealDataCustomerReviewsLoader from real.real_gan.customer_reviews_train import customer_reviews_train # custom dataset selected if args.dataset == 'CustomerReviews': data_dir = resources_path(config['data_dir'], "MovieReviews", "cr") elif args.dataset == 'imdb': data_dir = resources_path(config['data_dir'], "MovieReviews", 'movie', 'sstb') else: raise ValueError sample_dir = resources_path(config['sample_dir']) oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(args.dataset)) train_file = os.path.join(data_dir, 'train.csv') # create_tokens_files(data_files=[train_file, dev_file, test_file]) config_file = load_json(os.path.join(data_dir, 'config.json')) config = {**config, **config_file} # merge dictionaries oracle_loader = RealDataCustomerReviewsLoader(args.batch_size, args.seq_len) oracle_loader.create_batches(data_file=[train_file]) oracle_loader.model_index_word_dict = load_json( join(data_dir, 'index_word_dict.json')) oracle_loader.model_word_index_dict = load_json( join(data_dir, 'word_index_dict.json')) generator = models.get_generator("CustomerReviews", vocab_size=config['vocabulary_size'], batch_size=args.batch_size, start_token=args.start_token, seq_len=config['seq_len'], gen_emb_dim=args.gen_emb_dim, mem_slots=args.mem_slots, head_size=args.head_size, num_heads=args.num_heads, hidden_dim=args.hidden_dim, sentiment_num=config['sentiment_num']) discriminator_positive = models.get_discriminator( "CustomerReviews", scope="discriminator_positive", batch_size=args.batch_size, seq_len=config['seq_len'], vocab_size=config['vocabulary_size'], dis_emb_dim=args.dis_emb_dim, num_rep=args.num_rep, sn=args.sn) discriminator_negative = models.get_discriminator( "CustomerReviews", scope="discriminator_negative", batch_size=args.batch_size, seq_len=config['seq_len'], vocab_size=config['vocabulary_size'], dis_emb_dim=args.dis_emb_dim, num_rep=args.num_rep, sn=args.sn) customer_reviews_train(generator, discriminator_positive, discriminator_negative, oracle_loader, config, args) else: raise NotImplementedError('{}: unknown dataset!'.format(args.dataset)) print("RUN FINISHED") return
def real_topic_train_NoDiscr(generator, oracle_loader: RealDataTopicLoader, config, args): batch_size = config['batch_size'] num_sentences = config['num_sentences'] vocab_size = config['vocab_size'] seq_len = config['seq_len'] dataset = config['dataset'] npre_epochs = config['npre_epochs'] nadv_steps = config['nadv_steps'] temper = config['temperature'] adapt = config['adapt'] # changed to match resources path data_dir = resources_path(config['data_dir']) log_dir = resources_path(config['log_dir']) sample_dir = resources_path(config['sample_dir']) # filename oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset)) gen_file = os.path.join(sample_dir, 'generator.txt') gen_text_file = os.path.join(sample_dir, 'generator_text.txt') gen_text_file_print = os.path.join(sample_dir, 'gen_text_file_print.txt') json_file = os.path.join(sample_dir, 'json_file{}.txt'.format(args.json_file)) csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv') data_file = os.path.join(data_dir, '{}.txt'.format(dataset)) if dataset == 'image_coco': test_file = os.path.join(data_dir, 'testdata/test_coco.txt') elif dataset == 'emnlp_news': test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt') else: raise NotImplementedError('Unknown dataset!') # create necessary directories if not os.path.exists(data_dir): os.makedirs(data_dir) if not os.path.exists(sample_dir): os.makedirs(sample_dir) if not os.path.exists(log_dir): os.makedirs(log_dir) # placeholder definitions x_real = placeholder(tf.int32, [batch_size, seq_len], name="x_real") # tokens of oracle sequences x_topic = placeholder(tf.float32, [batch_size, oracle_loader.vocab_size + 1], name="x_topic") # todo stessa cosa del +1 x_topic_random = placeholder(tf.float32, [batch_size, oracle_loader.vocab_size + 1], name="x_topic_random") temperature = tf.Variable(1., trainable=False, name='temperature') x_real_onehot = tf.one_hot(x_real, vocab_size) # batch_size x seq_len x vocab_size assert x_real_onehot.get_shape().as_list() == [ batch_size, seq_len, vocab_size ] # generator and discriminator outputs x_fake_onehot_appr, x_fake, g_pretrain_loss, gen_o, \ lambda_values_returned, gen_x_no_lambda = generator(x_real=x_real, temperature=temperature, x_topic=x_topic) # A function to calculate the gradients and get training operations def get_train_ops(config, g_pretrain_loss): gpre_lr = config['gpre_lr'] g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') grad_clip = 5.0 # keep the same with the previous setting # generator pre-training pretrain_opt = tf.train.AdamOptimizer(gpre_lr, beta1=0.9, beta2=0.999, name="gen_pretrain_adam") pretrain_grad, _ = tf.clip_by_global_norm( tf.gradients(g_pretrain_loss, g_vars, name="gradients_g_pretrain"), grad_clip, name="g_pretrain_clipping") # gradient clipping g_pretrain_op = pretrain_opt.apply_gradients(zip( pretrain_grad, g_vars)) return g_pretrain_op # Train ops g_pretrain_op = get_train_ops(config, g_pretrain_loss) # Metric Summaries metrics_pl, metric_summary_op = get_metric_summary_op(config) # Summaries gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss', scope='generator') gen_sentences_summary = CustomSummary(name='generated_sentences', scope='generator', summary_type=tf.summary.text, item_type=tf.string) run_information = CustomSummary(name='run_information', scope='info', summary_type=tf.summary.text, item_type=tf.string) custom_summaries = [ gen_pretrain_loss_summary, gen_sentences_summary, run_information ] # ------------- initial the graph -------------- with 
init_sess() as sess: variables_dict = get_parameters_division() log = open(csv_file, 'w') sum_writer = tf.summary.FileWriter(os.path.join(log_dir, 'summary'), sess.graph) for custom_summary in custom_summaries: custom_summary.set_file_writer(sum_writer, sess) run_information.write_summary(str(args), 0) print("Information stored in the summary!") oracle_loader.create_batches(oracle_file) metrics = get_metrics(config, oracle_loader, test_file, gen_text_file, g_pretrain_loss, x_real, x_topic, sess, json_file) gc.collect() # Check if there is a pretrained generator saved model_dir = "PretrainGenerator" model_path = resources_path( os.path.join("checkpoint_folder", model_dir)) try: new_saver = tf.train.import_meta_graph( os.path.join(model_path, "model.ckpt.meta")) new_saver.restore(sess, os.path.join(model_path, "model.ckpt")) print("Used saved model for generator pretrain") except OSError: print('Start pre-training...') progress = tqdm(range(npre_epochs)) for epoch in progress: # pre-training g_pretrain_loss_np = pre_train_epoch(sess, g_pretrain_op, g_pretrain_loss, x_real, oracle_loader, x_topic) gen_pretrain_loss_summary.write_summary(g_pretrain_loss_np, epoch) progress.set_description( "Pretrain_loss: {}".format(g_pretrain_loss_np)) # Test ntest_pre = 40 if np.mod(epoch, ntest_pre) == 0: json_object = generate_sentences(sess, x_fake, batch_size, num_sentences, oracle_loader=oracle_loader, x_topic=x_topic) write_json(json_file, json_object) with open(gen_text_file, 'w') as outfile: i = 0 for sent in json_object['sentences']: if i < 200: outfile.write(sent['generated_sentence'] + "\n") else: break # take sentences from saved files sent = take_sentences_json(json_object, first_elem='generated_sentence', second_elem='real_starting') gen_sentences_summary.write_summary(sent, epoch) # write summaries scores = [metric.get_score() for metric in metrics] metrics_summary_str = sess.run(metric_summary_op, feed_dict=dict( zip(metrics_pl, scores))) sum_writer.add_summary(metrics_summary_str, epoch) msg = 'pre_gen_epoch:' + str( epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np metric_names = [metric.get_name() for metric in metrics] for (name, score) in zip(metric_names, scores): msg += ', ' + name + ': %.4f' % score tqdm.write(msg) log.write(msg) log.write('\n') gc.collect() gc.collect() sum_writer.close() save_model = True if save_model: model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S") model_path = os.path.join(resources_path("trained_models"), model_dir) simple_save(sess, model_path, inputs={"x_topic": x_topic}, outputs={"gen_x": x_fake}) # save_path = saver.save(sess, os.path.join(model_path, "model.ckpt")) print("Model saved in path: %s" % model_path)
def amazon_attribute_train(generator: rmc_att_topic.generator, discriminator: rmc_att_topic.discriminator, oracle_loader: RealDataAmazonLoader, config, args): batch_size = config['batch_size'] num_sentences = config['num_sentences'] vocab_size = config['vocabulary_size'] seq_len = config['seq_len'] dataset = config['dataset'] npre_epochs = config['npre_epochs'] n_topic_pre_epochs = config['n_topic_pre_epochs'] nadv_steps = config['nadv_steps'] temper = config['temperature'] adapt = config['adapt'] # changed to match resources path data_dir = resources_path(config['data_dir'], "Amazon_Attribute") log_dir = resources_path(config['log_dir']) sample_dir = resources_path(config['sample_dir']) # filename oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset)) gen_file = os.path.join(sample_dir, 'generator.txt') gen_text_file = os.path.join(sample_dir, 'generator_text.txt') gen_text_file_print = os.path.join(sample_dir, 'gen_text_file_print.txt') json_file = os.path.join(sample_dir, 'json_file.txt') json_file_validation = os.path.join(sample_dir, 'json_file_validation.txt') csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv') data_file = os.path.join(data_dir, '{}.txt'.format(dataset)) test_file = os.path.join(data_dir, 'test.csv') # create necessary directories if not os.path.exists(data_dir): os.makedirs(data_dir) if not os.path.exists(sample_dir): os.makedirs(sample_dir) if not os.path.exists(log_dir): os.makedirs(log_dir) # placeholder definitions x_real = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_real") # tokens of oracle sequences x_user = tf.placeholder(tf.int32, [batch_size], name="x_user") x_product = tf.placeholder(tf.int32, [batch_size], name="x_product") x_rating = tf.placeholder(tf.int32, [batch_size], name="x_rating") temperature = tf.Variable(1., trainable=False, name='temperature') x_real_onehot = tf.one_hot(x_real, vocab_size) # batch_size x seq_len x vocab_size assert x_real_onehot.get_shape().as_list() == [ batch_size, seq_len, vocab_size ] # generator and discriminator outputs generator_obj = generator(x_real=x_real, temperature=temperature, x_user=x_user, x_product=x_product, x_rating=x_rating) discriminator_real = discriminator( x_onehot=x_real_onehot) # , with_out=False) discriminator_fake = discriminator( x_onehot=generator_obj.gen_x_onehot_adv) # , with_out=False) # GAN / Divergence type log_pg, g_loss, d_loss = get_losses(generator_obj, discriminator_real, discriminator_fake, config) # Global step global_step = tf.Variable(0, trainable=False) global_step_op = global_step.assign_add(1) # Train ops g_pretrain_op, g_train_op, d_train_op, d_topic_pretrain_op = get_train_ops( config, generator_obj.pretrain_loss, g_loss, d_loss, None, log_pg, temperature, global_step) # Record wall clock time time_diff = tf.placeholder(tf.float32) Wall_clock_time = tf.Variable(0., trainable=False) update_Wall_op = Wall_clock_time.assign_add(time_diff) # Temperature placeholder temp_var = tf.placeholder(tf.float32) update_temperature_op = temperature.assign(temp_var) # Loss summaries loss_summaries = [ tf.summary.scalar('adv_loss/discriminator/total', d_loss), tf.summary.scalar('adv_loss/generator/total_g_loss', g_loss), tf.summary.scalar('adv_loss/log_pg', log_pg), tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time), tf.summary.scalar('adv_loss/temperature', temperature), ] loss_summary_op = tf.summary.merge(loss_summaries) # Metric Summaries config['bleu_amazon'] = True config['bleu_amazon_validation'] = True metrics_pl, metric_summary_op = 
get_metric_summary_op(config) # Summaries gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss', scope='generator') gen_sentences_summary = CustomSummary(name='generated_sentences', scope='generator', summary_type=tf.summary.text, item_type=tf.string) topic_discr_pretrain_summary = CustomSummary(name='pretrain_loss', scope='topic_discriminator') topic_discr_accuracy_summary = CustomSummary(name='pretrain_accuracy', scope='topic_discriminator') run_information = CustomSummary(name='run_information', scope='info', summary_type=tf.summary.text, item_type=tf.string) custom_summaries = [ gen_pretrain_loss_summary, gen_sentences_summary, topic_discr_pretrain_summary, topic_discr_accuracy_summary, run_information ] # To save the trained model saver = tf.train.Saver() # ------------- initial the graph -------------- with init_sess() as sess: variables_dict = get_parameters_division() log = open(csv_file, 'w') sum_writer = tf.summary.FileWriter(os.path.join(log_dir, 'summary'), sess.graph) for custom_summary in custom_summaries: custom_summary.set_file_writer(sum_writer, sess) run_information.write_summary(str(args), 0) print("Information stored in the summary!") metrics = get_metrics(config, oracle_loader, sess, json_file, json_file_validation, generator_obj) gc.collect() # Check if there is a pretrained generator saved model_dir = "PretrainGenerator" model_path = resources_path( os.path.join("checkpoint_folder", model_dir)) try: new_saver = tf.train.import_meta_graph( os.path.join(model_path, "model.ckpt.meta")) new_saver.restore(sess, os.path.join(model_path, "model.ckpt")) print("Used saved model for generator pretrain") except OSError: print('Start pre-training...') # pre-training # Pre-train the generator using MLE for one epoch progress = tqdm(range(npre_epochs)) for epoch in progress: g_pretrain_loss_np = generator_obj.pretrain_epoch( oracle_loader, sess, g_pretrain_op=g_pretrain_op) gen_pretrain_loss_summary.write_summary( g_pretrain_loss_np, epoch) # Test ntest_pre = 40 if np.mod(epoch, ntest_pre) == 0: generator_obj.generated_num = 200 json_object = generator_obj.generate_samples( sess, oracle_loader, dataset="train") write_json(json_file, json_object) json_object = generator_obj.generate_samples( sess, oracle_loader, dataset="validation") write_json(json_file_validation, json_object) # take sentences from saved files sent = generator.get_sentences(json_object) sent = take_sentences_attribute(json_object) gen_sentences_summary.write_summary(sent, epoch) # write summaries scores = [metric.get_score() for metric in metrics] metrics_summary_str = sess.run(metric_summary_op, feed_dict=dict( zip(metrics_pl, scores))) sum_writer.add_summary(metrics_summary_str, epoch) msg = 'pre_gen_epoch:' + str( epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np metric_names = [metric.get_name() for metric in metrics] for (name, score) in zip(metric_names, scores): msg += ', ' + name + ': %.4f' % score tqdm.write(msg) log.write(msg) log.write('\n') gc.collect() print('Start adversarial training...') progress = tqdm(range(nadv_steps)) for adv_epoch in progress: gc.collect() niter = sess.run(global_step) t0 = time.time() # Adversarial training for _ in range(config['gsteps']): user, product, rating, sentence = oracle_loader.random_batch( dataset="train") feed_dict = { generator_obj.x_user: user, generator_obj.x_product: product, generator_obj.x_rating: rating } sess.run(g_train_op, feed_dict=feed_dict) for _ in range(config['dsteps']): user, product, rating, sentence = oracle_loader.random_batch( 
dataset="train") n = np.zeros((batch_size, seq_len)) for ind, el in enumerate(sentence): n[ind] = el feed_dict = { generator_obj.x_user: user, generator_obj.x_product: product, generator_obj.x_rating: rating, x_real: n } sess.run(d_train_op, feed_dict=feed_dict) t1 = time.time() sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0}) # temperature temp_var_np = get_fixed_temperature(temper, niter, nadv_steps, adapt) sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np}) user, product, rating, sentence = oracle_loader.random_batch( dataset="train") n = np.zeros((batch_size, seq_len)) for ind, el in enumerate(sentence): n[ind] = el feed_dict = { generator_obj.x_user: user, generator_obj.x_product: product, generator_obj.x_rating: rating, x_real: n } g_loss_np, d_loss_np, loss_summary_str = sess.run( [g_loss, d_loss, loss_summary_op], feed_dict=feed_dict) sum_writer.add_summary(loss_summary_str, niter) sess.run(global_step_op) progress.set_description('g_loss: %4.4f, d_loss: %4.4f' % (g_loss_np, d_loss_np)) # Test if np.mod(adv_epoch, 300) == 0 or adv_epoch == nadv_steps - 1: generator_obj.generated_num = generator_obj.batch_size * 10 json_object = generator_obj.generate_samples(sess, oracle_loader, dataset="train") write_json(json_file, json_object) json_object = generator_obj.generate_samples( sess, oracle_loader, dataset="validation") write_json(json_file_validation, json_object) # take sentences from saved files sent = take_sentences_attribute(json_object) gen_sentences_summary.write_summary(sent, adv_epoch + npre_epochs) # write summaries scores = [metric.get_score() for metric in metrics] metrics_summary_str = sess.run(metric_summary_op, feed_dict=dict( zip(metrics_pl, scores))) sum_writer.add_summary(metrics_summary_str, adv_epoch + npre_epochs) # tqdm.write("in {} seconds".format(time.time() - t)) msg = 'pre_gen_epoch:' + str( adv_epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np metric_names = [metric.get_name() for metric in metrics] for (name, score) in zip(metric_names, scores): msg += ', ' + name + ': %.4f' % score tqdm.write(msg) log.write(msg) log.write('\n') gc.collect() sum_writer.close() save_model = True if save_model: model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S") model_path = os.path.join(resources_path("trained_models"), model_dir) simple_save(sess, model_path, inputs={ "x_user": x_user, "x_rating": x_rating, "x_product": x_product }, outputs={"gen_x": generator_obj.gen_x}) # save_path = saver.save(sess, os.path.join(model_path, "model.ckpt")) print("Model saved in path: %s" % model_path)
def real_train(generator, discriminator, oracle_loader, config, args): batch_size = config['batch_size'] num_sentences = config['num_sentences'] vocab_size = config['vocab_size'] seq_len = config['seq_len'] dataset = config['dataset'] npre_epochs = config['npre_epochs'] nadv_steps = config['nadv_steps'] temper = config['temperature'] adapt = config['adapt'] # changed to match resources path data_dir = resources_path(config['data_dir']) log_dir = resources_path(config['log_dir']) sample_dir = resources_path(config['sample_dir']) # filename oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset)) gen_file = os.path.join(sample_dir, 'generator.txt') gen_text_file = os.path.join(sample_dir, 'generator_text.txt') csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv') data_file = os.path.join(data_dir, '{}.txt'.format(dataset)) if dataset == 'image_coco': test_file = os.path.join(data_dir, 'testdata/test_coco.txt') elif dataset == 'emnlp_news': data_file = os.path.join(data_dir, '{}_train.txt'.format(dataset)) test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt') else: raise NotImplementedError('Unknown dataset!') # create necessary directories if not os.path.exists(data_dir): os.makedirs(data_dir) if not os.path.exists(sample_dir): os.makedirs(sample_dir) if not os.path.exists(log_dir): os.makedirs(log_dir) # placeholder definitions x_real = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_real") # tokens of oracle sequences temperature = tf.Variable(1., trainable=False, name='temperature') x_real_onehot = tf.one_hot(x_real, vocab_size) # batch_size x seq_len x vocab_size assert x_real_onehot.get_shape().as_list() == [ batch_size, seq_len, vocab_size ] # generator and discriminator outputs x_fake_onehot_appr, x_fake, g_pretrain_loss, gen_o = generator( x_real=x_real, temperature=temperature) d_out_real = discriminator(x_onehot=x_real_onehot) d_out_fake = discriminator(x_onehot=x_fake_onehot_appr) # GAN / Divergence type log_pg, g_loss, d_loss = get_losses(d_out_real, d_out_fake, x_real_onehot, x_fake_onehot_appr, gen_o, discriminator, config) # Global step global_step = tf.Variable(0, trainable=False) global_step_op = global_step.assign_add(1) # Train ops g_pretrain_op, g_train_op, d_train_op = get_train_ops( config, g_pretrain_loss, g_loss, d_loss, log_pg, temperature, global_step) # Record wall clock time time_diff = tf.placeholder(tf.float32) Wall_clock_time = tf.Variable(0., trainable=False) update_Wall_op = Wall_clock_time.assign_add(time_diff) # Temperature placeholder temp_var = tf.placeholder(tf.float32) update_temperature_op = temperature.assign(temp_var) # Loss summaries loss_summaries = [ tf.summary.scalar('loss/discriminator', d_loss), tf.summary.scalar('loss/g_loss', g_loss), tf.summary.scalar('loss/log_pg', log_pg), tf.summary.scalar('loss/Wall_clock_time', Wall_clock_time), tf.summary.scalar('loss/temperature', temperature), ] loss_summary_op = tf.summary.merge(loss_summaries) # Metric Summaries metrics_pl, metric_summary_op = get_metric_summary_op(config) # Placeholder for generator pretrain loss gen_pretrain_loss_placeholder = tf.placeholder( tf.float32, name='pretrain_loss_placeholder') gen_pretrain_loss = tf.summary.scalar('generator/pretrain_loss', gen_pretrain_loss_placeholder) # Placeholder for generated text gen_sentences_placeholder = tf.placeholder( tf.string, name='generated_sentences_placeholder') gen_sentences = tf.summary.text('generator/generated_sentences', gen_sentences_placeholder) saver = tf.train.Saver() # ------------- 
initial the graph -------------- with init_sess() as sess: print("Inizia sessione") restore_model = False log = open(csv_file, 'w') sum_writer = tf.summary.FileWriter(os.path.join(log_dir, 'summary'), sess.graph) # generate oracle data and create batches index_word_dict = get_oracle_file(data_file, oracle_file, seq_len) print("Inizio creo batch") oracle_loader.create_batches(oracle_file) metrics = get_metrics(config, oracle_loader, test_file, gen_text_file, g_pretrain_loss, x_real, sess) print('Start pre-training...') for epoch in range(npre_epochs): print("Pretrain epoch: {}".format(epoch)) # pre-training g_pretrain_loss_np = pre_train_epoch(sess, g_pretrain_op, g_pretrain_loss, x_real, oracle_loader) pretrain_summary = sess.run( gen_pretrain_loss, feed_dict={gen_pretrain_loss_placeholder: g_pretrain_loss_np}) sum_writer.add_summary(pretrain_summary, epoch) # Test ntest_pre = 10 if np.mod(epoch, ntest_pre) == 0: # generate fake data and create batches gen_save_file = os.path.join( sample_dir, 'pre_samples_{:05d}.txt'.format(epoch)) generate_samples(sess, x_fake, batch_size, num_sentences, gen_file) get_real_test_file(gen_file, gen_save_file, index_word_dict) # qua salvo ogni volta get_real_test_file(gen_file, gen_text_file, index_word_dict) # qua sovrascrivo l'ultima # take sentences from saved files sent = take_sentences(gen_text_file) sent = random.sample(sent, 5) # pick just one sentence generated_strings_summary = sess.run( gen_sentences, feed_dict={gen_sentences_placeholder: sent}) sum_writer.add_summary(generated_strings_summary, epoch) # write summaries print("Computing Metrics and writing summaries", end=" ") t = time.time() scores = [metric.get_score() for metric in metrics] metrics_summary_str = sess.run(metric_summary_op, feed_dict=dict( zip(metrics_pl, scores))) sum_writer.add_summary(metrics_summary_str, epoch) print("in {} seconds".format(time.time() - t)) msg = 'pre_gen_epoch:' + str( epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np metric_names = [metric.get_name() for metric in metrics] for (name, score) in zip(metric_names, scores): msg += ', ' + name + ': %.4f' % score print(msg) log.write(msg) log.write('\n') print('Start adversarial training...') progress = tqdm(range(nadv_steps)) for adv_epoch in progress: niter = sess.run(global_step) # print("Adv_epoch: {}".format(adv_epoch)) t0 = time.time() # adversarial training for _ in range(config['gsteps']): sess.run(g_train_op, feed_dict={x_real: oracle_loader.random_batch()}) for _ in range(config['dsteps']): sess.run(d_train_op, feed_dict={x_real: oracle_loader.random_batch()}) t1 = time.time() sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0}) # temperature temp_var_np = get_fixed_temperature(temper, niter, nadv_steps, adapt) sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np}) feed = {x_real: oracle_loader.random_batch()} g_loss_np, d_loss_np, loss_summary_str = sess.run( [g_loss, d_loss, loss_summary_op], feed_dict=feed) sum_writer.add_summary(loss_summary_str, niter) sess.run(global_step_op) progress.set_description('g_loss: %4.4f, d_loss: %4.4f' % (g_loss_np, d_loss_np)) # Test # print("N_iter: {}, test every {} epochs".format(niter, config['ntest'])) if np.mod(adv_epoch, 250) == 0: # generate fake data and create batches gen_save_file = os.path.join( sample_dir, 'adv_samples_{:05d}.txt'.format(niter)) generate_samples(sess, x_fake, batch_size, num_sentences, gen_file) get_real_test_file(gen_file, gen_save_file, index_word_dict) get_real_test_file(gen_file, gen_text_file, index_word_dict) # take 
sentences from saved files sent = take_sentences(gen_text_file) sent = random.sample(sent, 8) # pick just one sentence generated_strings_summary = sess.run( gen_sentences, feed_dict={gen_sentences_placeholder: sent}) sum_writer.add_summary(generated_strings_summary, niter + npre_epochs) # write summaries scores = [metric.get_score() for metric in metrics] metrics_summary_str = sess.run(metric_summary_op, feed_dict=dict( zip(metrics_pl, scores))) sum_writer.add_summary(metrics_summary_str, niter + config['npre_epochs']) msg = 'adv_step: ' + str(niter) metric_names = [metric.get_name() for metric in metrics] for (name, score) in zip(metric_names, scores): msg += ', ' + name + ': %.4f' % score print(msg) log.write(msg) log.write('\n')
def real_topic_train(generator_obj, discriminator_obj, topic_discriminator_obj, oracle_loader: RealDataTopicLoader, config, args): batch_size = config['batch_size'] num_sentences = config['num_sentences'] vocab_size = config['vocab_size'] seq_len = config['seq_len'] dataset = config['dataset'] npre_epochs = config['npre_epochs'] n_topic_pre_epochs = config['n_topic_pre_epochs'] nadv_steps = config['nadv_steps'] temper = config['temperature'] adapt = config['adapt'] # changed to match resources path data_dir = resources_path(config['data_dir']) log_dir = resources_path(config['log_dir']) sample_dir = resources_path(config['sample_dir']) # filename oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset)) gen_file = os.path.join(sample_dir, 'generator.txt') gen_text_file = os.path.join(sample_dir, 'generator_text.txt') json_file = os.path.join(sample_dir, 'json_file.txt') csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv') data_file = os.path.join(data_dir, '{}.txt'.format(dataset)) if dataset == 'image_coco': test_file = os.path.join(data_dir, 'testdata/test_coco.txt') elif dataset == 'emnlp_news': test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt') else: raise NotImplementedError('Unknown dataset!') # create necessary directories if not os.path.exists(data_dir): os.makedirs(data_dir) if not os.path.exists(sample_dir): os.makedirs(sample_dir) if not os.path.exists(log_dir): os.makedirs(log_dir) # placeholder definitions x_real = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_real") # tokens of oracle sequences x_topic = tf.placeholder(tf.float32, [batch_size, oracle_loader.vocab_size + 1], name="x_topic") # todo stessa cosa del +1 x_topic_random = tf.placeholder(tf.float32, [batch_size, oracle_loader.vocab_size + 1], name="x_topic_random") temperature = tf.Variable(1., trainable=False, name='temperature') x_real_onehot = tf.one_hot(x_real, vocab_size) # batch_size x seq_len x vocab_size assert x_real_onehot.get_shape().as_list() == [ batch_size, seq_len, vocab_size ] # generator and discriminator outputs generator = generator_obj(x_real=x_real, temperature=temperature, x_topic=x_topic) d_real = discriminator_obj(x_onehot=x_real_onehot) d_fake = discriminator_obj(x_onehot=generator.gen_x_onehot_adv) if not args.no_topic: d_topic_real_pos = topic_discriminator_obj(x_onehot=x_real_onehot, x_topic=x_topic) d_topic_real_neg = topic_discriminator_obj(x_onehot=x_real_onehot, x_topic=x_topic_random) d_topic_fake = topic_discriminator_obj( x_onehot=generator.gen_x_onehot_adv, x_topic=x_topic) else: d_topic_real_pos = None d_topic_real_neg = None d_topic_fake = None # GAN / Divergence type losses = get_losses(generator, d_real, d_fake, d_topic_real_pos, d_topic_real_neg, d_topic_fake, config) if not args.no_topic: d_topic_loss = losses['d_topic_loss_real_pos'] + losses[ 'd_topic_loss_real_neg'] # only from real data for pretrain d_topic_accuracy = get_accuracy(d_topic_real_pos.logits, d_topic_real_neg.logits) else: d_topic_loss = None d_topic_accuracy = None # Global step global_step = tf.Variable(0, trainable=False) global_step_op = global_step.assign_add(1) # Train ops g_pretrain_op, g_train_op, d_train_op, d_topic_pretrain_op = get_train_ops( config, generator.pretrain_loss, losses['g_loss'], losses['d_loss'], d_topic_loss, losses['log_pg'], temperature, global_step) generator.g_pretrain_op = g_pretrain_op # Record wall clock time time_diff = tf.placeholder(tf.float32) Wall_clock_time = tf.Variable(0., trainable=False) update_Wall_op = 
Wall_clock_time.assign_add(time_diff) # Temperature placeholder temp_var = tf.placeholder(tf.float32) update_temperature_op = temperature.assign(temp_var) # Loss summaries loss_summaries = [ tf.summary.scalar('adv_loss/discriminator/classic/d_loss_real', losses['d_loss_real']), tf.summary.scalar('adv_loss/discriminator/classic/d_loss_fake', losses['d_loss_fake']), tf.summary.scalar('adv_loss/discriminator/total', losses['d_loss']), tf.summary.scalar('adv_loss/generator/g_sentence_loss', losses['g_sentence_loss']), tf.summary.scalar('adv_loss/generator/total_g_loss', losses['g_loss']), tf.summary.scalar('adv_loss/log_pg', losses['log_pg']), tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time), tf.summary.scalar('adv_loss/temperature', temperature), ] if not args.no_topic: loss_summaries += [ tf.summary.scalar( 'adv_loss/discriminator/topic_discriminator/d_topic_loss_real_pos', losses['d_topic_loss_real_pos']), tf.summary.scalar( 'adv_loss/discriminator/topic_discriminator/d_topic_loss_fake', losses['d_topic_loss_fake']), tf.summary.scalar('adv_loss/generator/g_topic_loss', losses['g_topic_loss']) ] loss_summary_op = tf.summary.merge(loss_summaries) # Metric Summaries metrics_pl, metric_summary_op = get_metric_summary_op(config) # Summaries gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss', scope='generator') gen_sentences_summary = CustomSummary(name='generated_sentences', scope='generator', summary_type=tf.summary.text, item_type=tf.string) topic_discr_pretrain_summary = CustomSummary(name='pretrain_loss', scope='topic_discriminator') topic_discr_accuracy_summary = CustomSummary(name='pretrain_accuracy', scope='topic_discriminator') run_information = CustomSummary(name='run_information', scope='info', summary_type=tf.summary.text, item_type=tf.string) custom_summaries = [ gen_pretrain_loss_summary, gen_sentences_summary, topic_discr_pretrain_summary, topic_discr_accuracy_summary, run_information ] # To save the trained model saver = tf.compat.v1.train.Saver() # ------------- initial the graph -------------- with init_sess() as sess: variables_dict = get_parameters_division() print("Total paramter number: {}".format( np.sum([ np.prod(v.get_shape().as_list()) for v in tf.trainable_variables() ]))) log = open(csv_file, 'w') now = datetime.datetime.now() additional_text = now.strftime( "%Y-%m-%d_%H-%M") + "_" + args.summary_name summary_dir = os.path.join(log_dir, 'summary', additional_text) if not os.path.exists(summary_dir): os.makedirs(summary_dir) sum_writer = tf.compat.v1.summary.FileWriter(os.path.join(summary_dir), sess.graph) for custom_summary in custom_summaries: custom_summary.set_file_writer(sum_writer, sess) run_information.write_summary(str(args), 0) print("Information stored in the summary!") # generate oracle data and create batches oracle_loader.create_batches(oracle_file) metrics = get_metrics(config, oracle_loader, test_file, gen_text_file, generator.pretrain_loss, x_real, x_topic, sess, json_file) gc.collect() # Check if there is a pretrained generator saved model_dir = "PretrainGenerator" model_path = resources_path( os.path.join("checkpoint_folder", model_dir)) try: new_saver = tf.train.import_meta_graph( os.path.join(model_path, "model.ckpt.meta")) new_saver.restore(sess, os.path.join(model_path, "model.ckpt")) print("Used saved model for generator pretrain") except OSError: print('Start pre-training...') progress = tqdm(range(npre_epochs), dynamic_ncols=True) for epoch in progress: # pre-training g_pretrain_loss_np = generator.pretrain_epoch( 
sess, oracle_loader) gen_pretrain_loss_summary.write_summary( g_pretrain_loss_np, epoch) # Test ntest_pre = 30 if np.mod(epoch, ntest_pre) == 0: json_object = generator.generate_samples_topic( sess, oracle_loader, num_sentences) write_json(json_file, json_object) # take sentences from saved files sent = generator.get_sentences(json_object) gen_sentences_summary.write_summary(sent, epoch) # write summaries t = time.time() scores = [metric.get_score() for metric in metrics] metrics_summary_str = sess.run(metric_summary_op, feed_dict=dict( zip(metrics_pl, scores))) sum_writer.add_summary(metrics_summary_str, epoch) msg = 'pre_gen_epoch:' + str( epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np metric_names = [metric.get_name() for metric in metrics] for (name, score) in zip(metric_names, scores): score = score * 1e5 if 'Earth' in name else score msg += ', ' + name + ': %.4f' % score progress.set_description( msg + " in {:.2f} sec".format(time.time() - t)) log.write(msg) log.write('\n') gc.collect() if not args.no_topic: gc.collect() print('Start Topic Discriminator pre-training...') progress = tqdm(range(n_topic_pre_epochs)) for epoch in progress: # pre-training and write loss # Pre-train the generator using MLE for one epoch supervised_g_losses = [] supervised_accuracy = [] oracle_loader.reset_pointer() for it in range(oracle_loader.num_batch): text_batch, topic_batch = oracle_loader.next_batch( only_text=False) _, topic_loss, accuracy = sess.run( [d_topic_pretrain_op, d_topic_loss, d_topic_accuracy], feed_dict={ x_real: text_batch, x_topic: topic_batch, x_topic_random: oracle_loader.random_topic() }) supervised_g_losses.append(topic_loss) supervised_accuracy.append(accuracy) d_topic_pretrain_loss = np.mean(supervised_g_losses) accuracy_mean = np.mean(supervised_accuracy) topic_discr_pretrain_summary.write_summary( d_topic_pretrain_loss, epoch) topic_discr_accuracy_summary.write_summary( accuracy_mean, epoch) progress.set_description( 'topic_loss: %4.4f, accuracy: %4.4f' % (d_topic_pretrain_loss, accuracy_mean)) print('Start adversarial training...') progress = tqdm(range(nadv_steps)) for adv_epoch in progress: gc.collect() niter = sess.run(global_step) t0 = time.time() # Adversarial training for _ in range(config['gsteps']): text_batch, topic_batch = oracle_loader.random_batch( only_text=False) sess.run(g_train_op, feed_dict={ x_real: text_batch, x_topic: topic_batch }) for _ in range(config['dsteps']): # normal + topic discriminator together text_batch, topic_batch = oracle_loader.random_batch( only_text=False) sess.run(d_train_op, feed_dict={ x_real: text_batch, x_topic: topic_batch, x_topic_random: oracle_loader.random_topic() }) t1 = time.time() sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0}) # temperature temp_var_np = get_fixed_temperature(temper, niter, nadv_steps, adapt) sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np}) text_batch, topic_batch = oracle_loader.random_batch( only_text=False) feed = { x_real: text_batch, x_topic: topic_batch, x_topic_random: oracle_loader.random_topic() } g_loss_np, d_loss_np, loss_summary_str = sess.run( [losses['g_loss'], losses['d_loss'], loss_summary_op], feed_dict=feed) sum_writer.add_summary(loss_summary_str, niter) sess.run(global_step_op) progress.set_description('g_loss: %4.4f, d_loss: %4.4f' % (g_loss_np, d_loss_np)) # Test if np.mod(adv_epoch, 400) == 0 or adv_epoch == nadv_steps - 1: json_object = generator.generate_samples_topic( sess, oracle_loader, num_sentences) write_json(json_file, json_object) # take 
sentences from saved files sent = generator.get_sentences(json_object) gen_sentences_summary.write_summary(sent, adv_epoch) # write summaries scores = [metric.get_score() for metric in metrics] metrics_summary_str = sess.run(metric_summary_op, feed_dict=dict( zip(metrics_pl, scores))) sum_writer.add_summary(metrics_summary_str, niter + config['npre_epochs']) msg = 'adv_step: ' + str(niter) metric_names = [metric.get_name() for metric in metrics] for (name, score) in zip(metric_names, scores): msg += ', ' + name + ': %.4f' % score tqdm.write(msg) log.write(msg) log.write('\n') sum_writer.close() save_model = False if save_model: model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S") model_path = os.path.join(resources_path("trained_models"), model_dir) simple_save(sess, model_path, inputs={"x_topic": x_topic}, outputs={"gen_x": generator.gen_x}) # save_path = saver.save(sess, os.path.join(model_path, "model.ckpt")) print("Model saved in path: %s" % model_path)
def real_train_NoDiscr(generator_obj: rmc_att_topic.generator, oracle_loader, config, args): batch_size = config['batch_size'] num_sentences = config['num_sentences'] vocab_size = config['vocab_size'] seq_len = config['seq_len'] dataset = config['dataset'] npre_epochs = config['npre_epochs'] # noinspection PyUnusedLocal n_topic_pre_epochs = config['n_topic_pre_epochs'] # noinspection PyUnusedLocal nadv_steps = config['nadv_steps'] # noinspection PyUnusedLocal temper = config['temperature'] # noinspection PyUnusedLocal adapt = config['adapt'] # changed to match resources path data_dir = resources_path(config['data_dir']) log_dir = resources_path(config['log_dir']) sample_dir = resources_path(config['sample_dir']) # filename oracle_file = os.path.join(sample_dir, 'oracle_{}.txt'.format(dataset)) # noinspection PyUnusedLocal gen_file = os.path.join(sample_dir, 'generator.txt') gen_text_file = os.path.join(sample_dir, 'generator_text.txt') # noinspection PyUnusedLocal gen_text_file_print = os.path.join(sample_dir, 'gen_text_file_print.txt') json_file = os.path.join(sample_dir, 'json_file{}.txt'.format(args.json_file)) csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv') # noinspection PyUnusedLocal data_file = os.path.join(data_dir, '{}.txt'.format(dataset)) if dataset == 'image_coco': test_file = os.path.join(data_dir, 'testdata/test_coco.txt') elif dataset == 'emnlp_news': test_file = os.path.join(data_dir, 'testdata/test_emnlp.txt') else: raise NotImplementedError('Unknown dataset!') # create necessary directories if not os.path.exists(data_dir): os.makedirs(data_dir) if not os.path.exists(sample_dir): os.makedirs(sample_dir) if not os.path.exists(log_dir): os.makedirs(log_dir) # placeholder definitions x_real = placeholder(tf.int32, [batch_size, seq_len], name="x_real") # tokens of oracle sequences temperature = tf.Variable(1., trainable=False, name='temperature') x_real_onehot = tf.one_hot(x_real, vocab_size) # batch_size x seq_len x vocab_size assert x_real_onehot.get_shape().as_list() == [ batch_size, seq_len, vocab_size ] # generator and discriminator outputs generator = generator_obj(x_real=x_real, temperature=temperature) # Global step global_step = tf.Variable(0, trainable=False) # noinspection PyUnusedLocal global_step_op = global_step.assign_add(1) # A function to calculate the gradients and get training operations def get_train_ops(config, g_pretrain_loss): gpre_lr = config['gpre_lr'] g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, scope='generator') grad_clip = 5.0 # keep the same with the previous setting # generator pre-training pretrain_opt = tf.train.AdamOptimizer(gpre_lr, beta1=0.9, beta2=0.999, name="gen_pretrain_adam") pretrain_grad, _ = tf.clip_by_global_norm( tf.gradients(g_pretrain_loss, g_vars, name="gradients_g_pretrain"), grad_clip, name="g_pretrain_clipping") # gradient clipping g_pretrain_op = pretrain_opt.apply_gradients(zip( pretrain_grad, g_vars)) return g_pretrain_op # Train ops # noinspection PyUnusedLocal g_pretrain_op = get_train_ops(config, generator.pretrain_loss) # Record wall clock time time_diff = placeholder(tf.float32) Wall_clock_time = tf.Variable(0., trainable=False) # noinspection PyUnusedLocal update_Wall_op = Wall_clock_time.assign_add(time_diff) # Temperature placeholder temp_var = placeholder(tf.float32) # noinspection PyUnusedLocal update_temperature_op = temperature.assign(temp_var) # Loss summaries loss_summaries = [ tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time), 
tf.summary.scalar('adv_loss/temperature', temperature), ] # noinspection PyUnusedLocal loss_summary_op = tf.summary.merge(loss_summaries) # Metric Summaries metrics_pl, metric_summary_op = get_metric_summary_op(config) # Summaries gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss', scope='generator') gen_sentences_summary = CustomSummary(name='generated_sentences', scope='generator', summary_type=tf.summary.text, item_type=tf.string) run_information = CustomSummary(name='run_information', scope='info', summary_type=tf.summary.text, item_type=tf.string) custom_summaries = [ gen_pretrain_loss_summary, gen_sentences_summary, run_information ] # To save the trained model # noinspection PyUnusedLocal saver = tf.train.Saver() # ------------- initial the graph -------------- with init_sess() as sess: variables_dict = get_parameters_division() print("Total paramter number: {}".format( np.sum([ np.prod(v.get_shape().as_list()) for v in tf.trainable_variables() ]))) log = open(csv_file, 'w') # file_suffix = "date: {}, normal RelGAN, pretrain epochs: {}, adv epochs: {}".format(datetime.datetime.now(), # npre_epochs, nadv_steps) sum_writer = tf.summary.FileWriter(os.path.join( log_dir, 'summary'), sess.graph) # , filename_suffix=file_suffix) for custom_summary in custom_summaries: custom_summary.set_file_writer(sum_writer, sess) run_information.write_summary(str(args), 0) print("Information stored in the summary!") # generate oracle data and create batches # todo se le parole hanno poco senso potrebbe essere perchè qua ho la corrispondenza indice-parola sbagliata # noinspection PyUnusedLocal oracle_loader.create_batches(oracle_file) metrics = get_metrics(config, oracle_loader, test_file, gen_text_file, generator.pretrain_loss, x_real, sess, json_file) gc.collect() # Check if there is a pretrained generator saved model_dir = "PretrainGenerator" model_path = resources_path( os.path.join("checkpoint_folder", model_dir)) try: new_saver = tf.train.import_meta_graph( os.path.join(model_path, "model.ckpt.meta")) new_saver.restore(sess, os.path.join(model_path, "model.ckpt")) print("Used saved model for generator pretrain") except OSError: print('Start pre-training...') progress = tqdm(range(npre_epochs)) for epoch in progress: # pre-training g_pretrain_loss_np = generator.pre_train_epoch( sess, g_pretrain_op, oracle_loader) gen_pretrain_loss_summary.write_summary(g_pretrain_loss_np, epoch) progress.set_description( "Pretrain_loss: {}".format(g_pretrain_loss_np)) # Test ntest_pre = 40 if np.mod(epoch, ntest_pre) == 0: json_object = generator.generate_sentences( sess, batch_size, num_sentences, oracle_loader=oracle_loader) write_json(json_file, json_object) with open(gen_text_file, 'w') as outfile: i = 0 for sent in json_object['sentences']: if i < 200: outfile.write(sent['generated_sentence'] + "\n") else: break # take sentences from saved files sent = take_sentences_json(json_object, first_elem='generated_sentence', second_elem=None) gen_sentences_summary.write_summary(sent, epoch) # write summaries scores = [metric.get_score() for metric in metrics] metrics_summary_str = sess.run(metric_summary_op, feed_dict=dict( zip(metrics_pl, scores))) sum_writer.add_summary(metrics_summary_str, epoch) # tqdm.write("in {} seconds".format(time.time() - t)) msg = 'pre_gen_epoch:' + str( epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np metric_names = [metric.get_name() for metric in metrics] for (name, score) in zip(metric_names, scores): msg += ', ' + name + ': %.4f' % score tqdm.write(msg) log.write(msg) 
log.write('\n') gc.collect() gc.collect() sum_writer.close()
def customer_reviews_train(generator: ReviewGenerator,
                           discriminator_positive: ReviewDiscriminator,
                           discriminator_negative: ReviewDiscriminator,
                           oracle_loader: RealDataCustomerReviewsLoader,
                           config, args):
    batch_size = config['batch_size']
    num_sentences = config['num_sentences']
    vocab_size = config['vocabulary_size']
    seq_len = config['seq_len']
    dataset = config['dataset']
    npre_epochs = config['npre_epochs']
    nadv_steps = config['nadv_steps']
    temper = config['temperature']
    adapt = config['adapt']

    # changed to match resources path
    data_dir = resources_path(config['data_dir'], "Amazon_Attribute")
    log_dir = resources_path(config['log_dir'])
    sample_dir = resources_path(config['sample_dir'])

    # filenames
    json_file = os.path.join(sample_dir, 'json_file.txt')
    csv_file = os.path.join(log_dir, 'experiment-log-rmcgan.csv')

    # create necessary directories
    if not os.path.exists(data_dir):
        os.makedirs(data_dir)
    if not os.path.exists(sample_dir):
        os.makedirs(sample_dir)
    if not os.path.exists(log_dir):
        os.makedirs(log_dir)

    # placeholder definitions
    x_real = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_real")
    x_pos = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_pos")
    x_neg = tf.placeholder(tf.int32, [batch_size, seq_len], name="x_neg")
    x_sentiment = tf.placeholder(tf.int32, [batch_size], name="x_sentiment")

    temperature = tf.Variable(1., trainable=False, name='temperature')

    x_real_pos_onehot = tf.one_hot(x_pos, vocab_size)  # batch_size x seq_len x vocab_size
    x_real_neg_onehot = tf.one_hot(x_neg, vocab_size)  # batch_size x seq_len x vocab_size
    assert x_real_pos_onehot.get_shape().as_list() == [batch_size, seq_len, vocab_size]

    # generator and discriminator outputs
    generator_obj = generator(x_real=x_real, temperature=temperature, x_sentiment=x_sentiment)
    # discriminator for positive sentences
    discriminator_positive_real_pos = discriminator_positive(x_onehot=x_real_pos_onehot)
    discriminator_positive_real_neg = discriminator_positive(x_onehot=x_real_neg_onehot)
    discriminator_positive_fake = discriminator_positive(x_onehot=generator_obj.gen_x_onehot_adv)
    # discriminator for negative sentences
    discriminator_negative_real_pos = discriminator_negative(x_onehot=x_real_pos_onehot)
    discriminator_negative_real_neg = discriminator_negative(x_onehot=x_real_neg_onehot)
    discriminator_negative_fake = discriminator_negative(x_onehot=generator_obj.gen_x_onehot_adv)

    # GAN / Divergence type
    log_pg, g_loss, d_loss = get_losses(generator_obj,
                                        discriminator_positive_real_pos,
                                        discriminator_positive_real_neg,
                                        discriminator_positive_fake,
                                        discriminator_negative_real_pos,
                                        discriminator_negative_real_neg,
                                        discriminator_negative_fake)

    # Global step
    global_step = tf.Variable(0, trainable=False)
    global_step_op = global_step.assign_add(1)

    # Train ops
    g_pretrain_op, g_train_op, d_train_op, d_topic_pretrain_op = get_train_ops(
        config, generator_obj.pretrain_loss, g_loss, d_loss, None, log_pg, temperature, global_step)

    # Record wall clock time
    time_diff = tf.placeholder(tf.float32)
    Wall_clock_time = tf.Variable(0., trainable=False)
    update_Wall_op = Wall_clock_time.assign_add(time_diff)

    # Temperature placeholder
    temp_var = tf.placeholder(tf.float32)
    update_temperature_op = temperature.assign(temp_var)

    # Loss summaries
    loss_summaries = [
        tf.summary.scalar('adv_loss/discriminator/total', d_loss),
        tf.summary.scalar('adv_loss/generator/total_g_loss', g_loss),
        tf.summary.scalar('adv_loss/log_pg', log_pg),
        tf.summary.scalar('adv_loss/Wall_clock_time', Wall_clock_time),
        tf.summary.scalar('adv_loss/temperature', temperature),
    ]
    loss_summary_op = tf.summary.merge(loss_summaries)

    # Metric summaries
    metrics_pl, metric_summary_op = get_metric_summary_op(config)

    # Custom summaries
    gen_pretrain_loss_summary = CustomSummary(name='pretrain_loss', scope='generator')
    gen_sentences_summary = CustomSummary(name='generated_sentences', scope='generator',
                                          summary_type=tf.summary.text, item_type=tf.string)
    run_information = CustomSummary(name='run_information', scope='info',
                                    summary_type=tf.summary.text, item_type=tf.string)
    custom_summaries = [gen_pretrain_loss_summary, gen_sentences_summary, run_information]

    # ------------- initialize the graph --------------
    with init_sess() as sess:
        log = open(csv_file, 'w')
        summary_dir = os.path.join(log_dir, 'summary', str(time.time()))
        if not os.path.exists(summary_dir):
            os.makedirs(summary_dir)
        sum_writer = tf.summary.FileWriter(summary_dir, sess.graph)
        for custom_summary in custom_summaries:
            custom_summary.set_file_writer(sum_writer, sess)
        run_information.write_summary(str(args), 0)
        print("Information stored in the summary!")

        def get_metrics():
            # set up evaluation metrics
            metrics = []
            if config['nll_gen']:
                nll_gen = NllReview(oracle_loader, generator_obj, sess, name='nll_gen_review')
                metrics.append(nll_gen)
            if config['KL']:
                KL_div = KL_divergence(oracle_loader, json_file, name='KL_divergence')
                metrics.append(KL_div)
            if config['jaccard_similarity']:
                jaccard_sim = JaccardSimilarity(oracle_loader, json_file, name='jaccard_similarity')
                metrics.append(jaccard_sim)
            if config['jaccard_diversity']:
                jaccard_div = JaccardDiversity(oracle_loader, json_file, name='jaccard_diversity')
                metrics.append(jaccard_div)
            return metrics

        metrics = get_metrics()
        generator_obj.generated_num = 200  # num_sentences
        gc.collect()

        # Check if there is a pretrained generator saved
        model_dir = "PretrainGenerator"
        model_path = resources_path(os.path.join("checkpoint_folder", model_dir))
        try:
            new_saver = tf.train.import_meta_graph(os.path.join(model_path, "model.ckpt.meta"))
            new_saver.restore(sess, os.path.join(model_path, "model.ckpt"))
            print("Used saved model for generator pretrain")
        except OSError:
            print('Start pre-training...')
            # pre-training: train the generator with MLE
            progress = tqdm(range(npre_epochs))
            for epoch in progress:
                oracle_loader.reset_pointer()
                g_pretrain_loss_np = generator_obj.pretrain_epoch(oracle_loader, sess,
                                                                  g_pretrain_op=g_pretrain_op)
                gen_pretrain_loss_summary.write_summary(g_pretrain_loss_np, epoch)
                msg = 'pre_gen_epoch:' + str(epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                progress.set_description(msg)

                # Test
                ntest_pre = 30
                if np.mod(epoch, ntest_pre) == 0 or epoch == npre_epochs - 1:
                    json_object = generator_obj.generate_json(oracle_loader, sess)
                    write_json(json_file, json_object)

                    # take sentences from the saved file
                    sent = take_sentences_json(json_object)
                    gen_sentences_summary.write_summary(sent, epoch)

                    # write summaries
                    scores = [metric.get_score() for metric in metrics]
                    metrics_summary_str = sess.run(metric_summary_op,
                                                   feed_dict=dict(zip(metrics_pl, scores)))
                    sum_writer.add_summary(metrics_summary_str, epoch)

                    msg = 'pre_gen_epoch:' + str(epoch) + ', g_pre_loss: %.4f' % g_pretrain_loss_np
                    metric_names = [metric.get_name() for metric in metrics]
                    for (name, score) in zip(metric_names, scores):
                        msg += ', ' + name + ': %.4f' % score
                    tqdm.write(msg)
                    log.write(msg)
                    log.write('\n')
                    gc.collect()
                gc.collect()

        print('Start adversarial training...')
        progress = tqdm(range(nadv_steps))
        for adv_epoch in progress:
            gc.collect()
            niter = sess.run(global_step)
            t0 = time.time()

            # Adversarial training
            # generator updates (uses the MLE pretrain op with the conditional inputs)
            for _ in range(config['gsteps']):
                sentiment, sentence = oracle_loader.random_batch()
                n = np.zeros((generator_obj.batch_size, generator_obj.seq_len))
                for ind, el in enumerate(sentence):
                    n[ind] = el
                sess.run(g_pretrain_op, feed_dict={generator_obj.x_real: n,
                                                   generator_obj.x_sentiment: sentiment})
            # discriminator updates on real / positive / negative review batches
            for _ in range(config['dsteps']):
                sentiment, sentence, pos, neg = oracle_loader.get_positive_negative_batch()
                n1 = np.zeros((generator_obj.batch_size, generator_obj.seq_len))
                n2 = np.zeros((generator_obj.batch_size, generator_obj.seq_len))
                n3 = np.zeros((generator_obj.batch_size, generator_obj.seq_len))
                for ind, (sent_ids, pos_ids, neg_ids) in enumerate(zip(sentence, pos, neg)):
                    n1[ind] = sent_ids
                    n2[ind] = pos_ids[0]
                    n3[ind] = neg_ids[0]
                feed_dict = {x_real: n1, x_pos: n2, x_neg: n3, x_sentiment: sentiment}
                sess.run(d_train_op, feed_dict=feed_dict)

            t1 = time.time()
            sess.run(update_Wall_op, feed_dict={time_diff: t1 - t0})

            # anneal the temperature
            temp_var_np = get_fixed_temperature(temper, niter, nadv_steps, adapt)
            sess.run(update_temperature_op, feed_dict={temp_var: temp_var_np})

            # evaluate the current losses on a fresh batch
            sentiment, sentence, pos, neg = oracle_loader.get_positive_negative_batch()
            n1 = np.zeros((generator_obj.batch_size, generator_obj.seq_len))
            n2 = np.zeros((generator_obj.batch_size, generator_obj.seq_len))
            n3 = np.zeros((generator_obj.batch_size, generator_obj.seq_len))
            for ind, (sent_ids, pos_ids, neg_ids) in enumerate(zip(sentence, pos, neg)):
                n1[ind] = sent_ids
                n2[ind] = pos_ids[0]
                n3[ind] = neg_ids[0]
            feed_dict = {x_real: n1, x_pos: n2, x_neg: n3, x_sentiment: sentiment}
            g_loss_np, d_loss_np, loss_summary_str = sess.run([g_loss, d_loss, loss_summary_op],
                                                              feed_dict=feed_dict)
            sum_writer.add_summary(loss_summary_str, niter)

            sess.run(global_step_op)
            progress.set_description('g_loss: %4.4f, d_loss: %4.4f' % (g_loss_np, d_loss_np))

            # Test
            # print("N_iter: {}, test every {} epochs".format(niter, config['ntest']))
            if np.mod(adv_epoch, 100) == 0 or adv_epoch == nadv_steps - 1:
                json_object = generator_obj.generate_json(oracle_loader, sess)
                write_json(json_file, json_object)

                # take sentences from the saved file
                sent = take_sentences_json(json_object)
                gen_sentences_summary.write_summary(sent, niter + config['npre_epochs'])

                # write summaries
                scores = [metric.get_score() for metric in metrics]
                metrics_summary_str = sess.run(metric_summary_op,
                                               feed_dict=dict(zip(metrics_pl, scores)))
                sum_writer.add_summary(metrics_summary_str, niter + config['npre_epochs'])

                msg = 'adv_step: ' + str(niter)
                metric_names = [metric.get_name() for metric in metrics]
                for (name, score) in zip(metric_names, scores):
                    msg += ', ' + name + ': %.4f' % score
                tqdm.write(msg)
                log.write(msg)
                log.write('\n')
            gc.collect()

        sum_writer.close()

        # Optionally export the trained graph (disabled by default)
        save_model = False
        if save_model:
            model_dir = datetime.datetime.now().strftime("%Y%m%d%H%M%S")
            model_path = os.path.join(resources_path("trained_models"), model_dir)
            # export the placeholders and the generator output tensor defined in this function
            simple_save(sess, model_path,
                        inputs={"x_real": x_real, "x_sentiment": x_sentiment},
                        outputs={"gen_x_onehot_adv": generator_obj.gen_x_onehot_adv})
            # save_path = saver.save(sess, os.path.join(model_path, "model.ckpt"))
            print("Model saved in path: %s" % model_path)
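

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the original code): the adversarial loop
# above repeatedly stacks a batch of token-id sequences into a dense
# [batch_size, seq_len] numpy array before feeding the placeholders. The
# loader appears to return rows already padded to seq_len (whole rows are
# assigned above); the hypothetical helper below captures that stacking step
# and additionally truncates or zero-pads rows defensively.
# ---------------------------------------------------------------------------
import numpy as np  # already imported at module level; repeated here so the sketch is self-contained


def pad_id_batch(sequences, batch_size, seq_len, pad_id=0):
    """Stack `sequences` (iterables of token ids) into a [batch_size, seq_len] int array."""
    batch = np.full((batch_size, seq_len), pad_id, dtype=np.int32)
    for row, seq in enumerate(sequences[:batch_size]):
        ids = list(seq)[:seq_len]
        batch[row, :len(ids)] = ids
    return batch
# Example (hypothetical): n1 = pad_id_batch(sentence, batch_size, seq_len)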