def init_real_training(self, data_loc=None, with_image=True):
    """Build the config-based generator/discriminator pair and the training
    data loaders from real data, and return the vocabulary used to encode it."""
    from utils.text_process import process_train_data

    self.sequence_length, self.vocab_size, vocabulary = process_train_data(
        self.config, data_loc, has_image=with_image)
    print("sequence length:", self.sequence_length, " vocab size:", self.vocab_size)

    discriminator = Discriminator(self.config)
    self.set_discriminator(discriminator)

    generator = Generator(self.config, D_model=discriminator)
    self.set_generator(generator)

    # data loaders for the generator and the discriminator
    gen_dataloader = DataLoader(self.config, batch_size=self.batch_size,
                                seq_length=self.sequence_length)
    gen_dataloader.create_shuffled_batches(with_image)
    oracle_dataloader = None
    dis_dataloader = DisDataloader(self.config, batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
    self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader,
                         oracle_loader=oracle_dataloader)
    return vocabulary
def init_oracle_trainng(self, oracle=None):
    """Set up synthetic-data training: an LSTM oracle, the generator and
    discriminator, the three data loaders, and a GPU-limited TF session."""
    goal_out_size = sum(self.num_filters)

    if oracle is None:
        oracle = OracleLstm(num_vocabulary=self.vocab_size,
                            batch_size=self.batch_size,
                            emb_dim=self.emb_dim,
                            hidden_dim=self.hidden_dim,
                            sequence_length=self.sequence_length,
                            start_token=self.start_token)
    self.set_oracle(oracle)

    discriminator = Discriminator(sequence_length=self.sequence_length,
                                  num_classes=2,
                                  vocab_size=self.vocab_size,
                                  dis_emb_dim=self.dis_embedding_dim,
                                  filter_sizes=self.filter_size,
                                  num_filters=self.num_filters,
                                  batch_size=self.batch_size,
                                  hidden_dim=self.hidden_dim,
                                  start_token=self.start_token,
                                  goal_out_size=goal_out_size,
                                  step_size=4,
                                  l2_reg_lambda=self.l2_reg_lambda)
    self.set_discriminator(discriminator)

    generator = Generator(num_classes=2,
                          num_vocabulary=self.vocab_size,
                          batch_size=self.batch_size,
                          emb_dim=self.emb_dim,
                          dis_emb_dim=self.dis_embedding_dim,
                          goal_size=self.goal_size,
                          hidden_dim=self.hidden_dim,
                          sequence_length=self.sequence_length,
                          filter_sizes=self.filter_size,
                          start_token=self.start_token,
                          num_filters=self.num_filters,
                          goal_out_size=goal_out_size,
                          D_model=discriminator,
                          step_size=4)
    self.set_generator(generator)

    gen_dataloader = DataLoader(batch_size=self.batch_size,
                                seq_length=self.sequence_length)
    oracle_dataloader = DataLoader(batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
    dis_dataloader = DisDataloader(batch_size=self.batch_size,
                                   seq_length=self.sequence_length)

    # grow GPU memory on demand, capped at half of the device
    config = tf.compat.v1.ConfigProto()
    config.gpu_options.allow_growth = True
    config.gpu_options.per_process_gpu_memory_fraction = 0.5
    self.sess = tf.compat.v1.Session(config=config)

    self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader,
                         oracle_loader=oracle_dataloader)
def init_real_trainng(self, data_loc=None):
    """Legacy real-data initialization: preprocess the corpus, build the
    models and loaders, and write the encoded corpus to the oracle file."""
    from utils.text_process import text_precess, text_to_code
    from utils.text_process import get_tokenlized, get_word_list, get_dict

    if data_loc is None:
        data_loc = 'data/image_coco.txt'
    self.sequence_length, self.vocab_size = text_precess(data_loc)
    goal_out_size = sum(self.num_filters)

    discriminator = Discriminator(sequence_length=self.sequence_length,
                                  num_classes=2,
                                  vocab_size=self.vocab_size,
                                  dis_emb_dim=self.dis_embedding_dim,
                                  filter_sizes=self.filter_size,
                                  num_filters=self.num_filters,
                                  batch_size=self.batch_size,
                                  hidden_dim=self.hidden_dim,
                                  start_token=self.start_token,
                                  goal_out_size=goal_out_size,
                                  step_size=4,
                                  l2_reg_lambda=self.l2_reg_lambda)
    self.set_discriminator(discriminator)

    generator = Generator(num_classes=2,
                          num_vocabulary=self.vocab_size,
                          batch_size=self.batch_size,
                          emb_dim=self.emb_dim,
                          dis_emb_dim=self.dis_embedding_dim,
                          goal_size=self.goal_size,
                          hidden_dim=self.hidden_dim,
                          sequence_length=self.sequence_length,
                          filter_sizes=self.filter_size,
                          start_token=self.start_token,
                          num_filters=self.num_filters,
                          goal_out_size=goal_out_size,
                          D_model=discriminator,
                          step_size=4)
    self.set_generator(generator)

    gen_dataloader = DataLoader(batch_size=self.batch_size,
                                seq_length=self.sequence_length)
    oracle_dataloader = None
    dis_dataloader = DisDataloader(batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
    self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader,
                         oracle_loader=oracle_dataloader)

    # encode the tokenized corpus and persist it for oracle-style evaluation
    tokens = get_tokenlized(data_loc)
    word_set = get_word_list(tokens)
    word_index_dict, index_word_dict = get_dict(word_set)
    with open(self.oracle_file, 'w') as outfile:
        outfile.write(text_to_code(tokens, word_index_dict, self.sequence_length))
    return word_index_dict, index_word_dict
def init_cfg_training(self, grammar=None):
    """Set up training against a context-free-grammar oracle and return its
    word-to-index and index-to-word dictionaries."""
    from utils.oracle.OracleCfg import OracleCfg

    oracle = OracleCfg(sequence_length=self.sequence_length, cfg_grammar=grammar)
    self.set_oracle(oracle)
    self.oracle.generate_oracle()
    self.vocab_size = self.oracle.vocab_size + 1
    goal_out_size = sum(self.num_filters)

    discriminator = Discriminator(sequence_length=self.sequence_length,
                                  num_classes=2,
                                  vocab_size=self.vocab_size,
                                  dis_emb_dim=self.dis_embedding_dim,
                                  filter_sizes=self.filter_size,
                                  num_filters=self.num_filters,
                                  batch_size=self.batch_size,
                                  hidden_dim=self.hidden_dim,
                                  start_token=self.start_token,
                                  goal_out_size=goal_out_size,
                                  step_size=4,
                                  l2_reg_lambda=self.l2_reg_lambda)
    self.set_discriminator(discriminator)

    generator = Generator(num_classes=2,
                          num_vocabulary=self.vocab_size,
                          batch_size=self.batch_size,
                          emb_dim=self.emb_dim,
                          dis_emb_dim=self.dis_embedding_dim,
                          goal_size=self.goal_size,
                          hidden_dim=self.hidden_dim,
                          sequence_length=self.sequence_length,
                          filter_sizes=self.filter_size,
                          start_token=self.start_token,
                          num_filters=self.num_filters,
                          goal_out_size=goal_out_size,
                          D_model=discriminator,
                          step_size=4)
    self.set_generator(generator)

    gen_dataloader = DataLoader(batch_size=self.batch_size,
                                seq_length=self.sequence_length)
    oracle_dataloader = DataLoader(batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
    dis_dataloader = DisDataloader(batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
    self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader,
                         oracle_loader=oracle_dataloader)
    return oracle.wi_dict, oracle.iw_dict
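# A minimal usage sketch for init_cfg_training. The `runner` instance and the
# grammar string below are illustrative assumptions (neither the class name nor
# the grammar OracleCfg ships with is shown in this excerpt); any context-free
# grammar in this production-rule notation should work:
#
#     grammar = """
#         S -> S PLUS x | S SUB x | S PROD x | S DIV x | x | '(' S ')'
#     """
#     wi_dict, iw_dict = runner.init_cfg_training(grammar)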
# NOTE: this variant redefines init_real_training above with the same name;
# Python keeps only the last definition, so this one silently shadows the
# shorter config-based version.
def init_real_training(self, data_loc=None, with_image=True):
    """Real-data initialization using explicit keyword arguments for the
    models and create_batches_v2 for the generator loader."""
    from utils.text_process import process_train_data

    self.sequence_length, self.vocab_size, vocabulary = process_train_data(
        self.config, data_loc, has_image=with_image)
    print("sequence length:", self.sequence_length, " vocab size:", self.vocab_size)

    goal_out_size = sum(self.num_filters)

    discriminator = Discriminator(sequence_length=self.sequence_length,
                                  num_classes=2,
                                  vocab_size=self.vocab_size,
                                  dis_emb_dim=self.dis_embedding_dim,
                                  filter_sizes=self.filter_size,
                                  num_filters=self.num_filters,
                                  batch_size=self.batch_size,
                                  hidden_dim=self.hidden_dim,
                                  start_token=self.start_token,
                                  goal_out_size=goal_out_size,
                                  step_size=4,
                                  l2_reg_lambda=self.l2_reg_lambda)
    self.set_discriminator(discriminator)

    generator = Generator(num_classes=2,
                          num_vocabulary=self.vocab_size,
                          batch_size=self.batch_size,
                          emb_dim=self.emb_dim,
                          dis_emb_dim=self.dis_embedding_dim,
                          goal_size=self.goal_size,
                          hidden_dim=self.hidden_dim,
                          sequence_length=self.sequence_length,
                          filter_sizes=self.filter_size,
                          start_token=self.start_token,
                          num_filters=self.num_filters,
                          goal_out_size=goal_out_size,
                          D_model=discriminator,
                          step_size=4)
    self.set_generator(generator)

    # data loaders for the generator and the discriminator
    gen_dataloader = DataLoader(self.config, batch_size=self.batch_size,
                                seq_length=self.sequence_length)
    gen_dataloader.create_batches_v2(self.config, with_image)
    oracle_dataloader = None
    dis_dataloader = DisDataloader(self.config, batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
    self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader,
                         oracle_loader=oracle_dataloader)
    return vocabulary
def val(self, data_loc=None, with_image=True):
    """Generate captions for the validation set, decode them to text, and
    dump the samples, their image ids, and a JSON results file."""
    self.sequence_length, self.vocab_size, vocabulary = process_val_data(self.config)

    discriminator = Discriminator(self.config)
    self.set_discriminator(discriminator)

    generator = Generator(self.config, D_model=discriminator)
    self.set_generator(generator)

    # data loaders for the generator and the discriminator
    gen_dataloader = DataEvalLoader(self.config, batch_size=self.batch_size)
    gen_dataloader.create_batches(with_image)
    oracle_dataloader = None
    dis_dataloader = DisDataloader(self.config, batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
    self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader,
                         oracle_loader=oracle_dataloader)

    self.restore_model(self.sess)

    image_files, codes = generate_samples_gen(
        self.sess, self.generator, self.gen_data_loader, self.batch_size,
        self.config.num_eval_samples, eval=True, test=True)

    # decode each code sequence back to text and strip padding tokens
    generated_samples = []
    for code in codes:
        text = vocabulary.code_to_text([code])
        text = self.remove_padding(text)
        generated_samples.append(text)
    np.save(self.config.temp_generate_eval_file, generated_samples)

    # recover the numeric image id from each file name; 45 is a hard-coded
    # offset of the id within the dataset's absolute file paths
    ids = []
    for img in image_files:
        jpg_idx = img.find('.jpg')
        ids.append(int(img[45:jpg_idx]))
    np.save(self.config.temp_eval_id, ids)
    prepare_json(self.config)
def test(self, data_loc=None, with_image=True):
    """Generate captions for the test set, decode them, and write the
    results to the configured test result file."""
    self.sequence_length, self.vocab_size, vocabulary = process_test_data(self.config)

    discriminator = Discriminator(self.config)
    self.set_discriminator(discriminator)

    generator = Generator(self.config, D_model=discriminator)
    self.set_generator(generator)

    # data loaders for the generator and the discriminator
    gen_dataloader = DataEvalLoader(self.config, batch_size=self.batch_size)
    gen_dataloader.create_batches(with_image)
    oracle_dataloader = None
    dis_dataloader = DisDataloader(self.config, batch_size=self.batch_size,
                                   seq_length=self.sequence_length)
    self.set_data_loader(gen_loader=gen_dataloader, dis_loader=dis_dataloader,
                         oracle_loader=oracle_dataloader)

    self.restore_model(self.sess)
    self.context_file = self.config.temp_generate_eval_file

    codes = generate_samples_gen(self.sess, self.generator, self.gen_data_loader,
                                 self.batch_size, self.batch_size,
                                 self.generator_file, test=True)
    samples = vocabulary.code_to_text(codes)
    samples = self.remove_padding(samples)

    with open(self.config.test_result_file, 'w') as results_writer:
        for samp in samples:
            results_writer.write(samp)
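# A hedged end-to-end sketch of how these entry points appear to fit together.
# `Config` and `GanRunner` stand in for the configuration object and the class
# these methods belong to, and `runner.train()` for its training loop; all
# three names are hypothetical, since none of them is shown in this excerpt.
#
#     config = Config()
#     runner = GanRunner(config)
#     vocabulary = runner.init_real_training(data_loc='data/image_coco.txt',
#                                            with_image=True)
#     runner.train()  # hypothetical training entry point
#     runner.val()    # decode validation captions and prepare the JSON results
#     runner.test()   # decode test captions into config.test_result_file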