import json
import os

from keras import backend as K
from keras.callbacks import TensorBoard
from keras.layers import (Conv2D, Input, LSTM, Lambda, MaxPooling2D,
                          Reshape, add, concatenate)
from keras.models import load_model

# Project-local data pipeline (assumed to be importable from this package).
from data_loader import DataLoader


class GenericModel:
    """Base class for CTC-trained models; subclasses override
    build_model() and initialize_training()."""

    def __init__(self, dirpath):
        self.optimizer = None
        self.learning_rate = None
        self.batch_size = None
        self.metrics = None
        self.loss = None
        self.epochs = None
        self.dirpath = dirpath
        self.tensorboard = self.create_tb_callbacks(
            "./tensorboards/" + self.dirpath + '/' + type(self).__name__)
        self.name_model = (os.getcwd() + '/saved_models/' + self.dirpath +
                           '/' + type(self).__name__ + ".h5")
        # NOTE: batch_size is still None at this point; DataLoader must
        # tolerate that until initialize_training() sets the real value.
        self.data_gen = DataLoader(dirpath=dirpath,
                                   batch_size=self.batch_size,
                                   downsample_factor=0)
        self.data_gen.build_data()
        self.output_size = self.data_gen.get_output_size()
        img_w = self.data_gen.img_w
        img_h = self.data_gen.img_h
        if K.image_data_format() == 'channels_first':
            self.input_shape = (1, img_w, img_h)
        else:
            self.input_shape = (img_w, img_h, 1)

    @staticmethod
    def create_input(name, shape, dtype="float32"):
        return Input(name=name, shape=shape, dtype=dtype)

    @staticmethod
    def ctc_loss():
        # The CTC loss is computed inside the graph by ctc_layer(), so the
        # Keras-level loss simply passes the precomputed value through.
        return {'ctc': lambda y_true, y_pred: y_pred}

    @staticmethod
    def convolution_maxpooling(layer, conv_filters, kernel_size, name_conv,
                               name_pool, pool_size, padding='same',
                               activation='relu',
                               kernel_initializer='he_normal'):
        inner = Conv2D(conv_filters, kernel_size, padding=padding,
                       activation=activation,
                       kernel_initializer=kernel_initializer,
                       name=name_conv)(layer)
        return MaxPooling2D(pool_size=(pool_size, pool_size),
                            name=name_pool)(inner)

    @staticmethod
    def bi_lstm(layer, h_size, name, return_sequences=True,
                kernel_initializer='he_normal', merge_method="add"):
        lstm_1 = LSTM(h_size, return_sequences=return_sequences,
                      kernel_initializer=kernel_initializer,
                      name=name)(layer)
        lstm_1b = LSTM(h_size, return_sequences=return_sequences,
                       go_backwards=True,
                       kernel_initializer=kernel_initializer,
                       name=name + 'b')(layer)
        if merge_method == "add":
            return add([lstm_1, lstm_1b])
        elif merge_method == "concatenate":
            return concatenate([lstm_1, lstm_1b])
        elif merge_method is None:
            return lstm_1, lstm_1b
        else:
            raise ValueError("merge_method must be 'add', 'concatenate' or "
                             "None in order to merge the two directional "
                             "layers")

    @staticmethod
    def ctc_lambda_func(args):
        y_pred, labels, input_length, label_length = args
        # Drop the first two time steps: the early RNN outputs tend to be
        # noise while the recurrence warms up.
        y_pred = y_pred[:, 2:, :]
        return K.ctc_batch_cost(labels, y_pred, input_length, label_length)

    @staticmethod
    def from_conv_to_lstm_reshape(layer, name="reshape"):
        # Collapse (width, height, filters) into (time steps, features).
        shape = layer.get_shape().as_list()
        conv_to_rnn_dims = (shape[1], shape[2] * shape[3])
        return Reshape(target_shape=conv_to_rnn_dims, name=name)(layer)

    @staticmethod
    def ctc_layer(y_pred, max_output_len, name_input_length, name_label,
                  name_label_length, name_loss):
        labels = Input(name=name_label, shape=[max_output_len],
                       dtype='float32')
        input_length = Input(name=name_input_length, shape=[1],
                             dtype='int64')
        label_length = Input(name=name_label_length, shape=[1],
                             dtype='int64')
        loss_out = Lambda(GenericModel.ctc_lambda_func, output_shape=(1,),
                          name=name_loss)(
                              [y_pred, labels, input_length, label_length])
        return labels, input_length, label_length, loss_out

    @staticmethod
    def create_tb_callbacks(tensorboard_dir):
        return TensorBoard(log_dir=tensorboard_dir, histogram_freq=0,
                           write_graph=True, write_images=True)

    def build_model(self):
        raise NotImplementedError

    @staticmethod
    def load_model(loss, metrics, opt, name_model):
        model = load_model(name_model, compile=False)
        # compile() returns None, so compile first and return the model.
        model.compile(loss=loss, optimizer=opt, metrics=metrics)
        return model

    def initialize_training(self):
        raise NotImplementedError

    def train(self, model, tensorboard_callback, loss, metrics, nb_epochs,
              save, opt, lr):
        model.compile(loss=loss, optimizer=opt(lr), metrics=metrics)
        history = model.fit_generator(
            generator=self.data_gen.next_batch(mode="train",
                                               batch_size=self.batch_size),
            steps_per_epoch=self.data_gen.n["train"],
            epochs=nb_epochs,
            callbacks=[tensorboard_callback],
            validation_data=self.data_gen.next_batch(
                mode="test", batch_size=self.batch_size),
            validation_steps=self.data_gen.n["test"])
        if save:
            save_dir = os.getcwd() + '/saved_models/' + self.dirpath
            if not os.path.exists(save_dir):
                os.makedirs(save_dir)
            save_file = save_dir + '/' + type(self).__name__ + ".h5"
            if os.path.exists(save_file):
                raise FileExistsError("a model was already saved at " +
                                      save_file)
            print("saving model into: " + save_file)
            model.save(save_file)
        return model, history

    def run_model(self, save=False, load=False):
        try:
            self.initialize_training()
        except NotImplementedError:
            raise NotImplementedError(
                "you need to override the method initialize_training in "
                "your model")
        if self.optimizer is None:
            raise ValueError("please provide an optimizer")
        if self.learning_rate is None:
            raise ValueError("please provide a learning_rate")
        if self.metrics is None:
            raise ValueError("please provide metrics")
        if self.loss is None:
            raise ValueError("please provide a loss function")
        if self.epochs is None:
            raise ValueError("please provide a number of epochs")
        if load:
            print("Loading model")
            model = GenericModel.load_model(loss=self.loss,
                                            metrics=self.metrics,
                                            opt=self.optimizer,
                                            name_model=self.name_model)
            history = []
        else:
            model = self.build_model()
            with open(os.getcwd() + '/summaries/' + type(self).__name__ +
                      '.txt', 'w') as fh:
                model.summary(print_fn=lambda x: fh.write(x + '\n'))
            model, history = self.train(
                model=model,
                tensorboard_callback=self.tensorboard,
                loss=self.loss,
                metrics=self.metrics,
                nb_epochs=self.epochs,
                save=save,
                opt=self.optimizer,
                lr=self.learning_rate)
            # The training log is only written when a real History object
            # exists, i.e. when the model was trained rather than loaded.
            with open(os.getcwd() + '/logs/' + type(self).__name__ +
                      '.json', 'w') as log_file:
                log = {}
                log["name"] = type(self).__name__
                log["batch_size"] = self.batch_size
                log["optimizer"] = self.optimizer.__name__
                log["learning_rate"] = self.learning_rate
                log["epochs"] = self.epochs
                log["nb_train"] = self.data_gen.n["train"]
                log["nb_test"] = self.data_gen.n["test"]
                log["data_dim"] = [int(self.input_shape[0]),
                                   int(self.input_shape[1]),
                                   int(self.input_shape[2])]
                var_log = dict(history.history)
                log["train"] = {}
                for i in range(self.epochs):
                    log["train"][str(i)] = {}
                    for key in history.history:
                        log["train"][str(i)][key] = var_log[key][i]
                log["max_values"] = {}
                for key in history.history:
                    log["max_values"][key] = max(var_log[key])
                log["summary"] = model.to_json()
                json.dump(log, log_file, indent=4)
        return model, history
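
# Usage sketch (illustrative, not from the original repo): a hypothetical
# SmallCRNN subclass wiring the helpers above into a CRNN+CTC model. The
# layer sizes, the Adam optimizer, the max_output_len of 16, the input
# names and the "ocr_data" dirpath are all assumptions; a subclass only
# has to provide initialize_training() and build_model(), and the model
# must end in the ctc_layer() output so that ctc_loss() can pass the
# in-graph loss through at compile time.
if __name__ == "__main__":
    from keras.layers import Activation, Dense
    from keras.models import Model
    from keras.optimizers import Adam

    class SmallCRNN(GenericModel):
        def initialize_training(self):
            self.optimizer = Adam          # a class; train() calls opt(lr)
            self.learning_rate = 1e-3
            self.batch_size = 32
            self.metrics = ['accuracy']
            self.loss = GenericModel.ctc_loss()
            self.epochs = 10

        def build_model(self):
            data = self.create_input("the_input", self.input_shape)
            inner = self.convolution_maxpooling(data, 16, (3, 3),
                                                "conv1", "pool1", 2)
            inner = self.from_conv_to_lstm_reshape(inner)
            inner = self.bi_lstm(inner, 256, "lstm1")
            y_pred = Activation("softmax")(Dense(self.output_size)(inner))
            labels, in_len, lab_len, ctc = self.ctc_layer(
                y_pred, 16, "input_length", "the_labels",
                "label_length", "ctc")
            return Model(inputs=[data, labels, in_len, lab_len],
                         outputs=ctc)

    model, history = SmallCRNN(dirpath="ocr_data").run_model(save=True)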
import os
import pickle
from time import time

import tensorflow as tf

# Project-local modules (Gan, Generator, Discriminator, the data loaders,
# the Bleu/Nll metrics and the text utilities) are assumed to be imported
# from the surrounding package.


class LeakGan(Gan):
    def __init__(self, wi_dict_path, iw_dict_path, train_data,
                 val_data=None):
        super().__init__()
        self.vocab_size = 20
        self.emb_dim = 64
        self.hidden_dim = 64
        self.input_length = 8
        self.sequence_length = 32
        self.filter_size = [2, 3]
        self.num_filters = [100, 200]
        self.l2_reg_lambda = 0.2
        self.dropout_keep_prob = 0.75
        self.batch_size = 64
        self.generate_num = 256
        self.start_token = 0
        self.dis_embedding_dim = 64
        self.goal_size = 16

        self.save_path = 'save/model/LeakGan/LeakGan'
        self.model_path = 'save/model/LeakGan'
        self.best_path_pre = 'save/model/best-pre-gen/best-pre-gen'
        self.best_path = 'save/model/best-leak-gan/best-leak-gan'
        self.best_model_path = 'save/model/best-leak-gan'
        self.truth_file = 'save/truth.txt'
        self.generator_file = 'save/generator.txt'
        self.test_file = 'save/test_file.txt'
        self.trunc_train_file = 'save/trunc_train.txt'
        self.trunc_val_file = 'save/trunc_val.txt'

        trunc_data(train_data, self.trunc_train_file, self.input_length)
        if val_data is not None:
            trunc_data(val_data, self.trunc_val_file, self.input_length)

        if (not os.path.isfile(wi_dict_path) or
                not os.path.isfile(iw_dict_path)):
            print('Building word/index dictionaries...')
            self.sequence_length, self.vocab_size, word_index_dict, \
                index_word_dict = text_precess(train_data, val_data)
            print('Vocab Size: %d' % self.vocab_size)
            print('Saving dictionaries to ' + wi_dict_path + ' ' +
                  iw_dict_path + '...')
            with open(wi_dict_path, 'wb') as f:
                pickle.dump(word_index_dict, f)
            with open(iw_dict_path, 'wb') as f:
                pickle.dump(index_word_dict, f)
        else:
            print('Loading word/index dictionaries...')
            with open(wi_dict_path, 'rb') as f:
                word_index_dict = pickle.load(f)
            with open(iw_dict_path, 'rb') as f:
                index_word_dict = pickle.load(f)
            self.vocab_size = len(word_index_dict) + 1
            print('Vocab Size: %d' % self.vocab_size)

        self.wi_dict = word_index_dict
        self.iw_dict = index_word_dict
        self.train_data = train_data
        self.val_data = val_data

        goal_out_size = sum(self.num_filters)
        self.discriminator = Discriminator(
            sequence_length=self.sequence_length,
            num_classes=2,
            vocab_size=self.vocab_size,
            dis_emb_dim=self.dis_embedding_dim,
            filter_sizes=self.filter_size,
            num_filters=self.num_filters,
            batch_size=self.batch_size,
            hidden_dim=self.hidden_dim,
            start_token=self.start_token,
            goal_out_size=goal_out_size,
            step_size=4,
            l2_reg_lambda=self.l2_reg_lambda)
        self.generator = Generator(num_classes=2,
                                   num_vocabulary=self.vocab_size,
                                   batch_size=self.batch_size,
                                   emb_dim=self.emb_dim,
                                   dis_emb_dim=self.dis_embedding_dim,
                                   goal_size=self.goal_size,
                                   hidden_dim=self.hidden_dim,
                                   sequence_length=self.sequence_length,
                                   input_length=self.input_length,
                                   filter_sizes=self.filter_size,
                                   start_token=self.start_token,
                                   num_filters=self.num_filters,
                                   goal_out_size=goal_out_size,
                                   D_model=self.discriminator,
                                   step_size=4)
        self.saver = tf.train.Saver()
        self.best_pre_saver = tf.train.Saver()
        self.best_saver = tf.train.Saver()
        self.val_bleu1 = Bleu(real_text=self.trunc_val_file, gram=1)
        self.val_bleu2 = Bleu(real_text=self.trunc_val_file, gram=2)

    def train_discriminator(self):
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.gen_data_loader,
                             self.generator_file)
        self.dis_data_loader.load_train_data(self.truth_file,
                                             self.generator_file)
        # Draw two batches per step; only the second is fed to the
        # discriminator.
        for _ in range(3):
            self.dis_data_loader.next_batch()
            x_batch, y_batch = self.dis_data_loader.next_batch()
            feed = {
                self.discriminator.D_input_x: x_batch,
                self.discriminator.D_input_y: y_batch,
            }
            _, _ = self.sess.run(
                [self.discriminator.D_loss, self.discriminator.D_train_op],
                feed)
            self.generator.update_feature_function(self.discriminator)

    def eval(self):
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.gen_data_loader,
                             self.generator_file)
        if self.log is not None:
            if self.epoch == 0 or self.epoch == 1:
                for metric in self.metrics:
                    self.log.write(metric.get_name() + ',')
                self.log.write('\n')
            scores = super().eval()
            for score in scores:
                self.log.write(str(score) + ',')
            self.log.write('\n')
            return scores
        return super().eval()

    def init_metric(self):
        # docsim = DocEmbSim(oracle_file=self.truth_file,
        #                    generator_file=self.generator_file,
        #                    num_vocabulary=self.vocab_size)
        # self.add_metric(docsim)
        inll = Nll(data_loader=self.gen_data_loader, rnn=self.generator,
                   sess=self.sess)
        inll.set_name('nll-test')
        self.add_metric(inll)

        bleu1 = Bleu(test_text=self.test_file,
                     real_text=self.trunc_train_file, gram=1)
        bleu1.set_name('BLEU-1')
        self.add_metric(bleu1)

        bleu2 = Bleu(test_text=self.test_file,
                     real_text=self.trunc_train_file, gram=2)
        bleu2.set_name('BLEU-2')
        self.add_metric(bleu2)

    def train(self, restore=False, model_path=None):
        self.gen_data_loader = DataLoader(batch_size=self.batch_size,
                                          seq_length=self.sequence_length,
                                          input_length=self.input_length)
        self.dis_data_loader = DisDataloader(
            batch_size=self.batch_size, seq_length=self.sequence_length)
        tokens = get_tokens(self.train_data)
        with open(self.truth_file, 'w', encoding='utf-8') as outfile:
            outfile.write(
                text_to_code(tokens, self.wi_dict, self.sequence_length))
        wi_dict, iw_dict = self.wi_dict, self.iw_dict
        self.init_metric()

        def get_real_test_file(dict=iw_dict):
            codes = get_tokens(self.generator_file)
            with open(self.test_file, 'w', encoding='utf-8') as outfile:
                outfile.write(
                    code_to_text(codes=codes[self.input_length:],
                                 dict=dict))

        if restore:
            self.pre_epoch_num = 0
            if model_path is not None:
                self.model_path = model_path
            savefile = tf.train.latest_checkpoint(self.model_path)
            self.saver.restore(self.sess, savefile)
        else:
            self.sess.run(tf.global_variables_initializer())
            self.pre_epoch_num = 80
        # self.adversarial_epoch_num = 100
        self.log = open('log/experiment-log.txt', 'w', encoding='utf-8')
        self.gen_data_loader.create_batches(self.truth_file)
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             self.generate_num, self.gen_data_loader,
                             self.generator_file)
        self.gen_data_loader.reset_pointer()
        for a in range(1):
            inputs, target = self.gen_data_loader.next_batch()
            g = self.sess.run(self.generator.gen_x,
                              feed_dict={
                                  self.generator.drop_out: 1,
                                  self.generator.train: 1,
                                  self.generator.inputs: inputs
                              })

        print('start pre-train generator:')
        best = 0
        for epoch in range(self.pre_epoch_num):
            start = time()
            loss = pre_train_epoch_gen(self.sess, self.generator,
                                       self.gen_data_loader)
            end = time()
            print('epoch:' + str(self.epoch) + '\t time:' +
                  str(end - start))
            self.epoch += 1
            if epoch % 5 == 0:
                generate_samples_gen(self.sess, self.generator,
                                     self.batch_size, self.generate_num,
                                     self.gen_data_loader,
                                     self.generator_file)
                get_real_test_file()
                scores = self.eval()
                self.saver.save(self.sess, self.save_path,
                                global_step=epoch)
                if scores[3] > best:
                    print('--- Saving best-pre-gen...')
                    best = scores[3]
                    self.best_pre_saver.save(self.sess, self.best_path_pre,
                                             global_step=epoch)

        print('start pre-train discriminator:')
        # self.epoch = 0
        for epoch in range(self.pre_epoch_num):
            print('epoch:' + str(epoch))
            self.train_discriminator()
        self.saver.save(self.sess, self.save_path,
                        global_step=self.pre_epoch_num * 2)

        self.epoch = 0
        best = 0
        self.reward = Reward(model=self.generator,
                             dis=self.discriminator,
                             sess=self.sess,
                             rollout_num=4)
        for epoch in range(self.adversarial_epoch_num // 10):
            for epoch_ in range(10):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                start = time()
                for index in range(1):
                    inputs, target = self.gen_data_loader.next_batch()
                    samples = self.generator.generate(self.sess, 1,
                                                      inputs=inputs)
                    rewards = self.reward.get_reward(samples, inputs)
                    feed = {
                        self.generator.x: samples,
                        self.generator.reward: rewards,
                        self.generator.drop_out: 1,
                        self.generator.inputs: inputs
                    }
                    _, _, g_loss, w_loss = self.sess.run([
                        self.generator.manager_updates,
                        self.generator.worker_updates,
                        self.generator.goal_loss,
                        self.generator.worker_loss,
                    ], feed_dict=feed)
                    print('epoch', str(epoch), 'g_loss', g_loss,
                          'w_loss', w_loss)
                end = time()
                self.epoch += 1
                print('epoch:' + str(epoch) + '--' + str(epoch_) +
                      '\t time:' + str(end - start))
                if (self.epoch % 5 == 0 or
                        self.epoch == self.adversarial_epoch_num - 1):
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size,
                                         self.generate_num,
                                         self.gen_data_loader,
                                         self.generator_file)
                    get_real_test_file()
                    scores = self.eval()
                    print('--- Generating poem on val data... ')
                    target_file = 'save/gen_val/val_%d_%f.txt' % (
                        self.epoch, scores[1])
                    self.infer(test_data=self.val_data,
                               target_path=target_file,
                               model_path=self.model_path,
                               restore=False,
                               trunc=True)
                    self.val_bleu1.test_data = target_file
                    self.val_bleu2.test_data = target_file
                    bleu1 = self.val_bleu1.get_score()
                    bleu2 = self.val_bleu2.get_score()
                    print('--- BLEU on val data: \t bleu1: %f \t '
                          'bleu2: %f' % (bleu1, bleu2))
                    if bleu2 > best:
                        best = bleu2
                        print('--- Saving best-leak-gan...')
                        self.best_saver.save(
                            self.sess, self.best_path,
                            global_step=epoch * 10 + epoch_)
                for _ in range(15):
                    self.train_discriminator()
            self.saver.save(self.sess, self.save_path,
                            global_step=1 + epoch + self.pre_epoch_num * 2)
            for epoch_ in range(5):
                start = time()
                loss = pre_train_epoch_gen(self.sess, self.generator,
                                           self.gen_data_loader)
                end = time()
                print('epoch:' + str(epoch) + '--' + str(epoch_) +
                      '\t time:' + str(end - start))
                if epoch % 5 == 0:
                    generate_samples_gen(self.sess, self.generator,
                                         self.batch_size,
                                         self.generate_num,
                                         self.gen_data_loader,
                                         self.generator_file)
                    get_real_test_file()
                    self.eval()
            for epoch_ in range(5):
                print('epoch:' + str(epoch) + '--' + str(epoch_))
                self.train_discriminator()

    def infer(self, test_data, target_path, model_path, restore=True,
              trunc=False):
        if model_path is None:
            model_path = self.model_path
        if restore:
            savefile = tf.train.latest_checkpoint(model_path)
            self.saver.restore(self.sess, savefile)
        tokens = get_tokens(test_data)
        sentence_num = len(tokens)
        temp_file = 'save/infer_temp.txt'
        with open(temp_file, 'w', encoding='utf-8') as outfile:
            outfile.write(
                text_to_code(tokens, self.wi_dict, self.input_length))
        test_data_loader = TestDataloader(batch_size=self.batch_size,
                                          input_length=self.input_length)
        test_data_loader.create_batches(temp_file)
        generate_samples_gen(self.sess, self.generator, self.batch_size,
                             test_data_loader.num_batch * self.batch_size,
                             test_data_loader, target_path)
        codes = get_tokens(target_path)[:sentence_num]
        with open(target_path, 'w', encoding='utf-8') as outfile:
            if trunc:
                outfile.write(
                    code_to_text(codes=codes[self.input_length:],
                                 dict=self.iw_dict))
            else:
                outfile.write(code_to_text(codes=codes,
                                           dict=self.iw_dict))
        print('Finished generating %d poems to %s' %
              (sentence_num, target_path))
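
# Usage sketch (illustrative, not from the original repo): a typical
# train-then-infer run. The data and dictionary file paths are
# assumptions; the dictionaries are pickled to the *_dict_path files on
# the first run and reloaded on later runs, and infer() restores the
# latest checkpoint under model_path when restore=True.
if __name__ == "__main__":
    gan = LeakGan(wi_dict_path='save/word_index.pkl',
                  iw_dict_path='save/index_word.pkl',
                  train_data='data/train_poems.txt',
                  val_data='data/val_poems.txt')
    # Pre-train generator and discriminator, then run adversarial
    # training; pass restore=True to resume from the latest checkpoint
    # under gan.model_path instead.
    gan.train(restore=False)
    # Generate from the best adversarial checkpoint.
    gan.infer(test_data='data/test_poems.txt',
              target_path='save/generated_poems.txt',
              model_path=gan.best_model_path,
              restore=True)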