def train(self, sess: tf.Session, data: DataHelper, restore_file=None, log_path='data',
          max_iter=50000, print_every=100, test_every=3000):
    """Train the model, logging TensorBoard summaries and saving checkpoints.

    Builds the actor/critic losses and optimizer ops, runs
    ``tf.global_variables_initializer()`` and, if ``restore_file`` is given,
    calls ``self._restore`` afterwards. Then, for ``max_iter`` iterations:
    evaluates the critic loss/summary, runs one actor optimization step,
    writes both summaries and feeds the produced codes to ``data.update``.
    Every ``print_every`` iterations the ``data.hook_train()`` value is printed
    and logged under ``hook/train``. Every ``test_every`` iterations one test
    batch is encoded, ``data.hook_test()`` is logged under ``hook/test`` and a
    checkpoint is written via ``self._save``.

    Args:
        sess: session the graph is run in.
        data: data provider (``next_batch``, ``update``, ``hook_train``,
            ``hook_test``).
        restore_file: optional checkpoint handed to ``self._restore``.
        log_path: root folder; summaries go to ``<log_path>/log/<UTC timestamp>/``
            and checkpoints to ``<log_path>/model/``.
        max_iter: number of training iterations (original value: 50000).
        print_every: positive interval for the train hook / console print.
        test_every: positive interval for test evaluation and checkpointing.
    """
    actor_loss, critic_loss = self._build_loss()
    actor_opt = self.opt(actor_loss, 'actor')
    # NOTE(review): the critic optimizer op is still built (so the graph and any
    # variables it creates are unchanged) but it is never run below -- only the
    # actor is optimized. Confirm the critic update was dropped on purpose.
    self.opt(critic_loss, 'critic')

    sess.run(tf.global_variables_initializer())

    time_string = strftime("%a%d%b%Y-%H%M%S", gmtime())
    summary_path = os.path.join(log_path, 'log', time_string) + os.sep
    save_path = os.path.join(log_path, 'model') + os.sep
    # makedirs (not mkdir) so a missing log_path is created as well.
    os.makedirs(save_path, exist_ok=True)

    # Restore only after initialization, presumably so restored values are
    # not overwritten by the initializer.
    if restore_file is not None:
        self._restore(sess, restore_file)

    writer = tf.summary.FileWriter(summary_path)
    actor_summary = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES, scope='actor'))
    critic_summary = tf.summary.merge(tf.get_collection(tf.GraphKeys.SUMMARIES, scope='critic'))

    try:
        for i in range(max_iter):
            train_batch = data.next_batch('train')
            train_dict = {self.image_in: train_batch['batch_image']}

            critic_value, critic_summary_value, critic_step = sess.run(
                [critic_loss, critic_summary, self.global_step], feed_dict=train_dict)
            _, actor_value, actor_summary_value, code_value, actor_step = sess.run(
                [actor_opt, actor_loss, actor_summary, self.net['codes'], self.global_step],
                feed_dict=train_dict)

            writer.add_summary(critic_summary_value, critic_step)
            writer.add_summary(actor_summary_value, actor_step)
            data.update(code_value)

            if (i + 1) % print_every == 0:
                hook_train = data.hook_train()
                hook_summary = tf.Summary(
                    value=[tf.Summary.Value(tag='hook/train', simple_value=hook_train)])
                print('batch {}: actor {}, critic {}'.format(i, actor_value, critic_value))
                writer.add_summary(hook_summary, actor_step)

            if (i + 1) % test_every == 0:
                print('Testing!!!!!!!!')
                test_batch = data.next_batch('test')
                test_dict = {self.image_in: test_batch['batch_image']}
                test_code = sess.run(self.net['codes'], feed_dict=test_dict)
                data.update(test_code, phase='test')
                hook_test = data.hook_test()
                hook_summary = tf.Summary(
                    value=[tf.Summary.Value(tag='hook/test', simple_value=hook_test)])
                writer.add_summary(hook_summary, actor_step)
                self._save(sess, save_path, actor_step)
    finally:
        # FileWriter buffers events; closing flushes them so the final
        # summaries are not lost when training ends or raises.
        writer.close()
def extract_image(self, sess: tf.Session, data: DataHelper,
                  save_file='/home/ymcidence/Workspace/CodeGeass/MatlabWorkspace/sgh_code/mnist_sgh.mat'):
    """Decode the training split and save the results with labels to a .mat file.

    Runs ``self.net['decode_result']`` on ``data.training_data.batch_num + 1``
    consecutive training batches, places each output into rows
    ``[batch_start:batch_end]`` of a result matrix with ``self.input_length``
    columns, then writes ``{'set_image': <matrix>, 'set_label':
    data.training_data.label}`` with ``scipy.io.savemat``.

    Args:
        sess: session used to run the decoder.
        data: data provider; only its training split is read.
        save_file: destination .mat path. Defaults to the previously
            hard-coded location, so existing callers are unaffected.
    """
    import numpy as np
    import scipy.io as sio

    train_set = data.training_data
    rslt = np.zeros((train_set.batch_num * train_set.batch_size, self.input_length))
    for _ in range(train_set.batch_num + 1):
        batch = data.next_batch('train')
        decoded = sess.run(self.net['decode_result'],
                           feed_dict={self.image_in: batch['batch_image']})
        start, end = batch['batch_start'], batch['batch_end']
        # The buffer is pre-sized to batch_num * batch_size rows, but the extra
        # (+1) iteration can reach past that (e.g. a partial final batch), where
        # the slice assignment would hit a shape mismatch. Grow instead.
        if end > rslt.shape[0]:
            grown = np.zeros((end, rslt.shape[1]), dtype=rslt.dtype)
            grown[:rslt.shape[0]] = rslt
            rslt = grown
        rslt[start:end] = decoded
    sio.savemat(save_file, {'set_image': rslt, 'set_label': train_set.label})
def extract(self, sess: tf.Session, data: DataHelper, log_path='data', task='cifar'):
    """Encode every training and test batch, then persist the codes.

    For each split, ``batch_num + 1`` batches are pulled via
    ``data.next_batch``, pushed through ``self.net['codes']`` and handed to
    ``data.update`` (test codes with the extra ``'test'`` argument). Finally
    ``data.save(task, self.code_length, folder=log_path)`` is called.
    """
    def _encode_split(split, batch_count, *update_args):
        # One pass over `split`; extra update_args are forwarded unchanged.
        for _ in range(batch_count + 1):
            batch = data.next_batch(split)
            codes = sess.run(self.net['codes'],
                             feed_dict={self.image_in: batch['batch_image']})
            data.update(codes, *update_args)

    _encode_split('train', data.training_data.batch_num)
    _encode_split('test', data.test_data.batch_num, 'test')
    data.save(task, self.code_length, folder=log_path)
if (i + 1) % 3000 == 0: self._save(sess, save_path, actor_step) if __name__ == '__main__': from util.data.dataset import MatDataset batch_size = 200 code_length = 32 train_file = '/home/ymcidence/Workspace/CodeGeass/MatlabWorkspace/train_mnist.mat' test_file = '/home/ymcidence/Workspace/CodeGeass/MatlabWorkspace/test_mnist.mat' model_config = {'batch_size': batch_size, 'code_length': code_length, 'input_length': 784} train_config = {'batch_size': batch_size, 'code_length': code_length, 'file_name': train_file, 'phase': 'train'} train_config2 = {'batch_size': batch_size, 'code_length': code_length, 'file_name': train_file, 'phase': 'test'} test_config = {'batch_size': batch_size, 'code_length': code_length, 'file_name': test_file, 'phase': 'test'} this_sess = tf.Session() model = SGH(**model_config) train_data = MatDataset(**train_config) test_data = MatDataset(**test_config) data_helper = DataHelper(train_data, test_data) train_data2 = MatDataset(**train_config2) data_helper2 = DataHelper(train_data2, test_data) model.train(this_sess, data_helper) model.extract_image(this_sess, data_helper2)
'code_length': code_length, 'file_name': train_file, 'phase': 'train' } test_config = { 'batch_size': batch_size, 'code_length': code_length, 'file_name': test_file, 'phase': 'train' } base_config = { 'batch_size': batch_size, 'code_length': code_length, 'file_name': base_file, 'phase': 'train' } sess = tf.Session() model = BasicModel(**model_config) train_data = H5Dataset(**train_config) test_data = H5Dataset(**test_config) test_data1 = H5Dataset(**test_config) base_data = H5Dataset(**base_config) data_helper = DataHelper(train_data, test_data) base_helper = DataHelper(base_data, test_data1) model.train(sess, data_helper, log_path='../../Log') model.extract(sess, base_helper, log_path='../../Log', task=task)