def rollback(var_list, ckpt_folder, ckpt_file=None):
    """ Reload a checkpointed model and evaluate a list of variables once.

    A temporary session is opened, all global and local variables are
    initialised, the checkpoint is restored on top of them and var_list is
    evaluated together with the global step. The session is closed on every
    exit path, including when no checkpoint is found.

    :param var_list: the fetches to evaluate; handed to sess.run as one fetch
    :param ckpt_folder: folder that is searched for the checkpoint
    :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
    :return: tuple (var_value, global_step_value)
    :raises FileNotFoundError: if get_ckpt returns None for ckpt_folder
    """
    global_step = global_step_config()
    # register a session
    sess = tf.Session(config=tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=False))
    try:
        # initialization
        init_op = tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer())
        sess.run(init_op)
        # load the training graph
        saver = tf.compat.v1.train.Saver(max_to_keep=2)
        ckpt = get_ckpt(ckpt_folder, ckpt_file=ckpt_file)
        if ckpt is None:
            raise FileNotFoundError(
                'No ckpt Model found at {}.'.format(ckpt_folder))
        saver.restore(sess, ckpt.model_checkpoint_path)
        FLAGS.print('Model reloaded.')
        # run the session; queue runners are not started, the coordinator is
        # only created and stopped again
        coord = tf.train.Coordinator()
        var_value, global_step_value = sess.run([var_list, global_step])
        coord.request_stop()
    finally:
        # release the session even if restoring or evaluating failed
        sess.close()
    FLAGS.print('Variable calculated.')
    return var_value, global_step_value
def print_loss(loss_value, step=0, epoch=0):
    """ Print the current losses through FLAGS.print.

    :param loss_value: iterable of losses; an entry that is a list or ndarray is
        rendered element-wise with two decimals, any other entry with three
    :param step: global step shown in the message
    :param epoch: epoch shown in the message
    :return:
    """
    def _render(entry):
        # vector-valued losses are shown as a list of their components
        if isinstance(entry, (np.ndarray, list)):
            return '{}'.format(['<{:.2f}>'.format(item) for item in entry])
        return '<{:.3f}>'.format(entry)

    rendered = [_render(entry) for entry in loss_value]
    FLAGS.print('Epoch {}, global steps {}, loss_list {}'.format(epoch, step, rendered))
def run_m_times(
        self, var_list, ckpt_folder=None, ckpt_file=None, max_iter=10000,
        trace=False, ckpt_var_list=None, feed_dict=None):
    """ Evaluate var_list repeatedly, as often done in Monte Carlo analysis.

    Every iteration also runs the ops registered in the UPDATE_OPS collection.

    :param var_list: the fetches to evaluate
    :param ckpt_folder: folder handed to self._load_ckpt_
    :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
    :param max_iter: number of evaluations
    :param trace: if True, return a list with the value from every iteration;
        otherwise return only the value from the final evaluation
    :param ckpt_var_list: the variable to load in order to calculate var_list
    :param feed_dict: forwarded to every sess.run call
    :return: list of values (trace=True) or the last value (trace=False)
    """
    if ckpt_var_list is not None:
        self.ckpt_var_list = ckpt_var_list
    self._load_ckpt_(ckpt_folder, ckpt_file=ckpt_file)
    self._check_thread_()
    extra_update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)

    def _evaluate():
        # one pass over the graph; the update-op results are discarded
        value, _ = self.sess.run([var_list, extra_update_ops], feed_dict=feed_dict)
        return value

    tic = time.time()
    if trace:
        outcome = [_evaluate() for _ in range(max_iter)]
    else:
        # all but the last evaluation are run for their side effects only
        for _ in range(max_iter - 1):
            _evaluate()
        outcome = _evaluate()
    FLAGS.print('Calculation took {:.3f} sec.'.format(time.time() - tic))
    return outcome
def write_sprite(sprite_path, images, mesh_num=None, if_invert=False):
    """ Write a batch of images to a single sprite image for embedding.

    This function was taken from:
    https://github.com/oduerr/dl_tutorial/blob/master/tensorflow/debugging/embedding.ipynb

    The input image must be channels_last format. Each image is rescaled
    independently to [0, 1] before tiling; an image whose pixels are all equal
    is written as zeros (or ones when inverted) instead of producing NaN.

    :param sprite_path: file name, e.g. '...\\a_sprite.png'
    :param images: ndarray, [batch_size, height, width(, channels)]
    :param mesh_num: nums of images in the row and column, tuple or list; if
        None, a square mesh of side ceil(sqrt(batch_size)) is used and the batch
        is zero-padded to fill it
    :param if_invert: bool, if true, invert images: images = 1 - images
    :return:
    """
    if len(images.shape) == 3:  # if dimension of image is 3, extend it to 4
        images = np.tile(images[..., np.newaxis], (1, 1, 1, 3))
    if images.shape[3] == 1:  # if last dimension is 1, extend it to 3
        images = np.tile(images, (1, 1, 1, 3))

    # scale every image to range [0,1]
    images = images.astype(np.float32)
    image_min = np.min(images.reshape((images.shape[0], -1)), axis=1)
    images = images - image_min[:, np.newaxis, np.newaxis, np.newaxis]
    image_max = np.max(images.reshape((images.shape[0], -1)), axis=1)
    # a constant image has max 0 after the shift; dividing by 1 keeps it at 0
    image_max[image_max == 0] = 1.0
    images = images / image_max[:, np.newaxis, np.newaxis, np.newaxis]
    if if_invert:
        images = 1 - images

    # check mesh_num
    if mesh_num is None:
        FLAGS.print('Mesh_num will be calculated as sqrt of batch_size')
        batch_size = images.shape[0]
        sprite_size = int(np.ceil(np.sqrt(batch_size)))
        mesh_num = (sprite_size, sprite_size)
        # add paddings if batch_size is not the square of sprite_size
        padding = ((0, sprite_size ** 2 - batch_size), (0, 0), (0, 0)) \
            + ((0, 0),) * (images.ndim - 3)
        images = np.pad(images, padding, mode='constant', constant_values=0)
    elif isinstance(mesh_num, list):
        mesh_num = tuple(mesh_num)

    # Tile the individual thumbnails into an image
    new_shape = mesh_num + images.shape[1:]
    images = images.reshape(new_shape).transpose(
        (0, 2, 1, 3) + tuple(range(4, images.ndim + 1)))
    images = images.reshape(
        (mesh_num[0] * images.shape[1], mesh_num[1] * images.shape[3]) + images.shape[4:])
    images = (images * 255).astype(np.uint8)

    # save images to file; writing stays best-effort, but interpreter-level
    # signals such as KeyboardInterrupt are no longer swallowed
    try:
        from imageio import imwrite
        imwrite(sprite_path, images)
    except Exception:
        print('attempt to write image failed!')
def __exit__(self, exc_type, exc_val, exc_tb):
    """ The exit method is called when leaving the scope of "with" statement.

    Closes the summary writer (if one was created), asks the coordinator (if
    one was created) to stop, and closes the session. The session is closed
    even when one of the earlier clean-up steps raises. Nothing is returned,
    so an exception raised inside the "with" body keeps propagating.

    :param exc_type:
    :param exc_val:
    :param exc_tb:
    :return:
    """
    FLAGS.print('Session finished.')
    try:
        if self.summary_writer is not None:
            self.summary_writer.close()
        # coord is initialised to None and may never have been replaced
        if self.coord is not None:
            self.coord.request_stop()
    finally:
        self.sess.close()
def __init__(
        self, do_save=False, do_trace=False, save_path=None, load_ckpt=False,
        log_device=False, ckpt_var_list=None):
    """ Open a session on the default graph and initialise its variables.

    This class provides shortcuts for running sessions. It needs to be run
    under context managers. Example:
    with MySession() as sess:
        var1_value, var2_value = sess.run_once([var1, var2])

    :param do_save: if True, a saver is created and save_path is remembered;
        otherwise both are set to None
    :param do_trace: stored on the instance
    :param save_path: path used when saving; ignored unless do_save is True
    :param load_ckpt: stored on the instance; consulted when loading checkpoints
    :param log_device: forwarded to the session config as log_device_placement
    :param ckpt_var_list: list of variables to save / restore
    """
    # NOTE: configuring the global step here was abandoned because it failed with
    # "global_step does not exist or is not created from tf.get_variable".
    self.log_device = log_device

    # register a session
    session_config = tf.ConfigProto(
        allow_soft_placement=True, log_device_placement=log_device)
    self.sess = tf.Session(config=session_config)

    # initialization
    self.sess.run(tf.group(tf.global_variables_initializer(),
                           tf.local_variables_initializer()))
    self.coord = None
    self.threads = None
    FLAGS.print('Graph initialization finished...')

    # configuration; ckpt_var_list is stored before the saver is requested
    self.ckpt_var_list = ckpt_var_list
    self.saver = self._get_saver_() if do_save else None
    self.save_path = save_path if do_save else None
    self.summary_writer = None
    self.do_trace = do_trace
    self.load_ckpt = load_ckpt
def run_once(self, var_list, ckpt_folder=None, ckpt_file=None, ckpt_var_list=None,
             feed_dict=None, do_time=False):
    """ Evaluate var_list a single time.

    :param var_list: the fetches handed to sess.run
    :param ckpt_folder: folder handed to self._load_ckpt_
    :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
    :param ckpt_var_list: the variable to load in order to calculate var_list
    :param feed_dict: forwarded to sess.run
    :param do_time: if True, the wall-clock time of the run is printed
    :return: whatever sess.run returns for var_list
    """
    if ckpt_var_list is not None:
        self.ckpt_var_list = ckpt_var_list
    self._load_ckpt_(ckpt_folder, ckpt_file=ckpt_file)
    self._check_thread_()

    if not do_time:
        return self.sess.run(var_list, feed_dict=feed_dict)

    tic = time.time()
    result = self.sess.run(var_list, feed_dict=feed_dict)
    FLAGS.print('Running session took {:.3f} sec.'.format(time.time() - tic))
    return result
def train(
        self, op_list, loss_list, global_step,
        max_step=None, step_per_epoch=None,
        summary_op=None, summary_image_op=None, imbalanced_update=None,
        force_print=False, mog_model=None):
    """ This method do the optimization process to minimizes loss_list.

    The behaviour is selected by self.debug:
    None  - only write the default graph to self.summary_folder, no training;
    True  - run MySession.debug_mode for self.debug_step steps;
    False - run MySession.full_run for max_step steps.

    :param op_list: [net0_op, net1_op, net2_op]
    :param loss_list: [loss0, loss1, loss2]
    :param global_step:
    :param max_step: forwarded to full_run; not used in debug mode
    :param step_per_epoch: forwarded to full_run
    :param summary_op:
    :param summary_image_op: forwarded to full_run only
    :param imbalanced_update: if not None, replaces self.imbalanced_update;
        must be a list, tuple, str or NetPicker
    :param force_print: forwarded to full_run only
    :param mog_model: forwarded to full_run only
    :return:
    :raises AttributeError: if self.debug is not None, True or False
    """
    # Check inputs
    if imbalanced_update is not None:
        self.imbalanced_update = imbalanced_update
    if self.imbalanced_update is not None:
        assert isinstance(self.imbalanced_update, (list, tuple, str, NetPicker)), \
            'Imbalanced_update must be a list, tuple or str or netpicker.'

    if self.debug is None:
        writer = tf.summary.FileWriter(
            logdir=self.summary_folder, graph=tf.get_default_graph())
        try:
            writer.flush()
        finally:
            # the writer is only needed for this one dump; release its file handle
            writer.close()
        FLAGS.print('Graph printed.')
    elif self.debug is True:
        FLAGS.print('Debug mode is on.')
        FLAGS.print('Remember to load ckpt to check variable values.')
        with MySession(self.do_save, self.do_trace, self.save_path,
                       self.load_ckpt, self.log_device) as sess:
            sess.debug_mode(
                op_list, loss_list, global_step, summary_op,
                self.summary_folder, self.ckpt_folder,
                max_step=self.debug_step, print_loss=self.print_loss,
                query_step=self.query_step,
                imbalanced_update=self.imbalanced_update)
    elif self.debug is False:
        # NOTE(review): unlike the debug branch, self.log_device is not passed
        # to MySession here - confirm whether that is intentional
        with MySession(self.do_save, self.do_trace, self.save_path,
                       self.load_ckpt) as sess:
            sess.full_run(
                op_list, loss_list, max_step, step_per_epoch, global_step,
                summary_op, summary_image_op,
                self.summary_folder, self.ckpt_folder,
                print_loss=self.print_loss, query_step=self.query_step,
                imbalanced_update=self.imbalanced_update,
                force_print=force_print, mog_model=mog_model)
    else:
        raise AttributeError('Current debug mode is not supported.')
def _load_ckpt_(self, ckpt_folder=None, ckpt_file=None, force_print=False):
    """ Restore a checkpoint into the session, if loading is enabled.

    Nothing is restored when self.load_ckpt is falsy, when ckpt_folder is None
    or when get_ckpt finds no checkpoint; each case is only reported.

    :param ckpt_folder: folder that is searched for the checkpoint
    :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
    :param force_print: forwarded to FLAGS.print for the messages about ckpt_folder
    :return:
    """
    if not self.load_ckpt or ckpt_folder is None:
        FLAGS.print('No ckpt model is loaded for current calculation.')
        return

    ckpt = get_ckpt(ckpt_folder, ckpt_file=ckpt_file)
    if ckpt is None:
        FLAGS.print(
            'No ckpt Model found at {}. Model training from scratch.'.format(ckpt_folder),
            force_print)
        return

    # a saver is created on demand when the session was built without one
    if self.saver is None:
        self.saver = self._get_saver_()
    self.saver.restore(self.sess, ckpt.model_checkpoint_path)
    FLAGS.print('Model reloaded from {}.'.format(ckpt_folder), force_print)
def scheduler(self, batch_size=None, num_epoch=None, shuffle_data=True,
              buffer_size=None, skip_count=None, sample_same_class=False,
              sample_class=None):
    """ This function schedules the batching process.

    The pipeline is built once: skip -> shuffle -> batch -> repeat, followed by
    the creation of a one-shot iterator. Calling it again after
    self.scheduled has been set is a no-op.

    :param batch_size: if given, replaces self.batch_size and self.batch_shape[0]
    :param num_epoch: if given, replaces self.num_epoch; when self.num_epoch is
        None the dataset repeats indefinitely
    :param shuffle_data: whether to shuffle with self.buffer_size
    :param buffer_size: if given, replaces self.buffer_size
    :param skip_count: if given, replaces self.skip_count
    :param sample_same_class: if the data must be sampled from the same class at one iteration
    :param sample_class: if provided, the data will be sampled from class of this label,
        otherwise, data of a random class are sampled.
    :return:
    """
    if self.scheduled:
        return

    # update batch information
    if batch_size is not None:
        self.batch_size = batch_size
        self.batch_shape[0] = self.batch_size
    if num_epoch is not None:
        self.num_epoch = num_epoch
    if buffer_size is not None:
        self.buffer_size = buffer_size
    if skip_count is not None:
        self.skip_count = skip_count

    # skip instances
    if self.skip_count > 0:
        print('Number of {} instances skipped.'.format(self.skip_count))
        self.dataset = self.dataset.skip(self.skip_count)
    # shuffle
    if shuffle_data:
        self.dataset = self.dataset.shuffle(self.buffer_size)

    # set batching process
    if sample_same_class:
        if sample_class is None:
            print('Caution: samples from the same class at each call.')
            # windows are keyed by the label, so each batch holds a single class
            group_fun = tf.contrib.data.group_by_window(
                key_func=lambda data_x, data_y: data_y,
                reduce_func=lambda key, d: d.batch(self.batch_size),
                window_size=self.batch_size)
            self.dataset = self.dataset.apply(group_fun)
        else:
            print('Caution: samples from class {}. This should not be used in training'
                  .format(sample_class))
            self.dataset = self.dataset.filter(
                lambda x, y: tf.equal(y[0], sample_class))
            self.dataset = self.dataset.batch(self.batch_size)
    else:
        self.dataset = self.dataset.batch(self.batch_size)

    # repeat; report the effective epoch count, which may have been set on the
    # instance earlier rather than through the num_epoch argument
    if self.num_epoch is None:
        self.dataset = self.dataset.repeat()
    else:
        FLAGS.print('Num_epoch set: {} epochs.'.format(self.num_epoch))
        self.dataset = self.dataset.repeat(self.num_epoch)

    self.iterator = self.dataset.make_one_shot_iterator()
    self.scheduled = True
def eval_sampling(self, filename, sub_folder, mesh_num=None, mesh_mode=0, if_invert=False,
                  code_x=None, code_y=None, real_sample=False, sample_same_class=False,
                  get_dis_score=True, do_sprite=True, do_embedding=False,
                  ckpt_file=None, num_threads=7):
    """ This function randomly generates samples and writes them to sprite.

    A fresh graph is built, the checkpoint is restored through rollback, and the
    generated (and optionally real) batches are written as sprite images and/or
    as an embedding of discriminator scores.

    :param filename: forwarded to prepare_folder, get_data_batch and the writers
    :param sub_folder: forwarded to prepare_folder; also part of the output file index
    :param mesh_num: (rows, columns) of the sprite; defaults to (10, 10). The
        batch size is rows * columns
    :param mesh_mode: forwarded to MeshCode.get_batch; also part of the output file index
    :param if_invert: forwarded to the writers
    :param code_x: if provided, z_batch will be used to generate images.
    :param code_y: if None while self.sample_same_class is set and the data batch
        has a 'y' entry, that entry is used
    :param real_sample: True if real sample should also be obtained
    :param sample_same_class: stored in self.sample_same_class when real samples are read
    :param get_dis_score: bool, whether to calculate the scores from the discriminator
    :param do_sprite: whether to write sprite images
    :param do_embedding: whether to write the embedding; forces get_dis_score
        and real_sample to True
    :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
    :param num_threads: forwarded to get_data_batch
    :return:
    """
    # prepare folder
    ckpt_folder, summary_folder, _ = prepare_folder(filename, sub_folder=sub_folder)
    # check inputs
    if mesh_num is None:
        mesh_num = (10, 10)
    elif code_x is not None:
        assert code_x.shape[0] == mesh_num[0] * mesh_num[1]
    batch_size = mesh_num[0] * mesh_num[1]
    if do_embedding is True:
        get_dis_score = True
        real_sample = True

    # build the network graph
    self.graph = tf.Graph()
    with self.graph.as_default():
        self.init_net()
        # get real sample
        if real_sample:
            self.sample_same_class = sample_same_class
            data_batch = self.get_data_batch(filename, batch_size, num_threads=num_threads)
        else:
            # placeholder so that data_batch['x'] can still be fetched below
            data_batch = {'x': tf.constant(0)}
        # sample validation instances
        if code_x is None:
            code = MeshCode(self.code_size, mesh_num=mesh_num)
            code_x = code.get_batch(mesh_mode, name='code_x')
        if code_y is None and self.sample_same_class and 'y' in data_batch:
            code_y = data_batch['y']
        code_batch = self.sample_codes(batch_size, code_x, code_y, name='code_te')
        # generate new images
        gen_batch = self.__gpu_task__(code_batch=code_batch, is_training=False)
        # do clipping
        gen_batch['x'] = tf.clip_by_value(gen_batch['x'], clip_value_min=-1, clip_value_max=1)
        # get discriminator scores
        if get_dis_score and real_sample:
            dis_out = self.Dis(self.concat_two_batches(data_batch, gen_batch),
                               is_training=False)
            s_x, s_gen = tf.split(dis_out['x'], num_or_size_splits=2, axis=0)
        else:
            s_x = tf.constant(0)
            s_gen = tf.constant(0)
        FLAGS.print('Graph configuration finished...')

        # calculate the value of x_gen
        var_list = [gen_batch['x'], data_batch['x'], s_x, s_gen]
        _temp, global_step_value = rollback(var_list, ckpt_folder, ckpt_file=ckpt_file)
        x_gen_value, x_real_value, s_x_value, s_gen_value = _temp

    # shared tail of every output file index
    index_tail = sub_folder + '_' + str(global_step_value) + '_' + str(mesh_mode)

    # write to files
    if do_sprite:
        if real_sample:
            write_sprite_wrapper(
                x_real_value, mesh_num, filename, file_folder=summary_folder,
                file_index='_r_' + index_tail,
                if_invert=if_invert, image_format=FLAGS.IMAGE_FORMAT)
        write_sprite_wrapper(
            x_gen_value, mesh_num, filename, file_folder=summary_folder,
            file_index='_g_' + index_tail,
            if_invert=if_invert, image_format=FLAGS.IMAGE_FORMAT)

    # do visualization for code_value
    if do_embedding:
        # transpose image data if necessary
        if real_sample:
            x_as_image = np.transpose(x_real_value, axes=self.perm) \
                if self.perm is not None else x_real_value
        x_gen_as_image = np.transpose(x_gen_value, axes=self.perm) \
            if self.perm is not None else x_gen_value
        # concatenate real and generated images, codes and labels
        s_x_value = np.concatenate((s_x_value, s_gen_value), axis=0)
        x_as_image = np.concatenate((x_as_image, x_gen_as_image), axis=0)
        # 1 for real, 0 for gen; builtin int replaces the removed np.int alias
        labels = np.concatenate(
            (np.ones(batch_size, dtype=int), np.zeros(batch_size, dtype=int)),
            axis=0)
        # embedding; the mesh is twice as tall because real and generated are stacked
        mesh_num = (mesh_num[0] * 2, mesh_num[1])
        embedding_image_wrapper(
            s_x_value, filename, var_name='x_vs_xg', file_folder=summary_folder,
            file_index='_x_vs_xg_' + index_tail,
            labels=labels, images=x_as_image, mesh_num=mesh_num,
            if_invert=if_invert, image_format=FLAGS.IMAGE_FORMAT)
def training(self, filename, agent, num_instance, lr_list, end_lr=1e-7, max_step=None,
             batch_size=64, sample_same_class=False, num_threads=7, gpu='/gpu:0'):
    """ Build the training graph and hand it to the agent for optimisation.

    :param filename: a single file name or a list / tuple of file names
    :param agent: object whose imbalanced_update attribute selects which
        optimizer advances the global step and whose train method runs the session
    :param num_instance: number of training instances
    :param lr_list: forwarded to multi_opt_config
    :param end_lr: forwarded to multi_opt_config
    :param max_step: number of training steps; it is compared with
        step_per_epoch, so an int is required despite the None default
    :type max_step: int
    :param batch_size:
    :param sample_same_class: bool, if at each iteration the data should be sampled from the same class
    :param num_threads: forwarded to get_data_batch
    :param gpu: which gpu to use
    :return:
    :raises AttributeError: if max_step is below step_per_epoch for a single
        file, or if agent.imbalanced_update has an unsupported value
    """
    from math import gcd

    self.step_per_epoch = np.floor(num_instance / batch_size).astype(np.int32)
    self.sample_same_class = sample_same_class

    if max_step >= self.step_per_epoch:
        if self.num_class < 2:
            group_size = num_instance
        else:
            group_size = int(num_instance / self.num_class)
        file_repeat = int(batch_size / gcd(group_size, batch_size))
        shuffle_file = False
    else:
        single_file = isinstance(filename, str) or (
            isinstance(filename, (list, tuple)) and len(filename) == 1)
        if single_file:
            raise AttributeError(
                'max_step should be larger than step_per_epoch when there is a single file.')
        # for large dataset, the data are stored in multiple files. If all files cannot be
        # visited within max_step, consider shuffle the filename list every max_step
        file_repeat = 1
        shuffle_file = True
    FLAGS.print('Num Instance: {}; Num Class: {}; Batch: {}; File_repeat: {}'.format(
        num_instance, self.num_class, batch_size, file_repeat))

    # build the graph
    self.graph = tf.Graph()
    with self.graph.as_default(), tf.device(gpu):
        self.init_net()
        # get next batch
        data_batch = self.get_data_batch(
            filename, batch_size, file_repeat, num_threads, shuffle_file, 'data_tr')
        FLAGS.print('Shape of input batch: {}'.format(
            data_batch['x'].get_shape().as_list()))

        # setup training process
        self.global_step = global_step_config()
        _, opt_ops = multi_opt_config(
            lr_list, end_lr=end_lr, optimizer=self.optimizer,
            global_step=self.global_step)

        # assign tasks
        with tf.variable_scope(tf.get_variable_scope()):
            # calculate loss and gradients
            grads_list, loss_list = self.__gpu_task__(
                batch_size=batch_size, is_training=True,
                data_batch=data_batch, opt_op=opt_ops)

            # decide which of the two optimizers advances the global step
            if agent.imbalanced_update is None:
                step_on_dis = True
            elif isinstance(agent.imbalanced_update, (list, tuple)):
                FLAGS.print(
                    'Imbalanced update used: dis per {} run and gen per {} run'
                    .format(agent.imbalanced_update[0], agent.imbalanced_update[1]))
                if agent.imbalanced_update[0] == 1:
                    step_on_dis = True
                elif agent.imbalanced_update[1] == 1:
                    step_on_dis = False
                else:
                    raise AttributeError('One of the imbalanced_update must be 1.')
            elif isinstance(agent.imbalanced_update, str):
                step_on_dis = False
            else:
                raise AttributeError('Imbalanced_update not identified.')

            # apply the gradient
            if step_on_dis:
                dis_op = opt_ops[0].apply_gradients(
                    grads_list[0], global_step=self.global_step)
                gen_op = opt_ops[1].apply_gradients(grads_list[1])
            else:
                dis_op = opt_ops[0].apply_gradients(grads_list[0])
                gen_op = opt_ops[1].apply_gradients(
                    grads_list[1], global_step=self.global_step)
            op_list = [dis_op, gen_op]

            # summary op is always pinned to CPU
            # add summary for all trainable variables
            summary_op = None
            if self.do_summary:
                for grads in grads_list:
                    for var_grad, var in grads:
                        var_name = var.name.replace(':', '_')
                        tf.summary.histogram('grad_' + var_name, var_grad)
                        tf.summary.histogram(var_name, var)
                summary_op = tf.summary.merge_all()

            # add summary for final image reconstruction
            summary_image_op = None
            if self.do_summary_image:
                tf.get_variable_scope().reuse_variables()
                summary_image_op = self.summary_image_sampling(data_batch)

        # run the session
        FLAGS.print('loss_list name: {}.'.format(self.loss_names))
        agent.train(
            op_list, loss_list, self.global_step, max_step, self.step_per_epoch,
            summary_op, summary_image_op, force_print=self.force_print)
        self.force_print = False  # force print at the first call
def debug_mode(self, op_list, loss_list, global_step, summary_op=None, summary_folder=None,
               ckpt_folder=None, ckpt_file=None, max_step=200, print_loss=True,
               query_step=100, imbalanced_update=None):
    """ Run a short training loop for debugging, optionally with tracing.

    When self.do_trace is set and imbalanced_update is None, the last five
    steps are run with full tracing and their timelines are merged into
    'timeline.json' inside summary_folder. In the other update modes the
    tracing arguments are handed to self.do_imbalanced_update.

    :param op_list: the training ops
    :param loss_list: the losses fetched at every step
    :param global_step:
    :param summary_op: if given, summaries are written to summary_folder
    :param summary_folder: target folder for summaries and the trace file
    :param max_step: number of steps to run
    :param ckpt_folder: folder handed to self._load_ckpt_
    :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
    :param print_loss: whether to print the losses at the query steps
    :param query_step: period of loss printing; if larger than max_step it is
        replaced by min(max_step - 1, 100)
    :param imbalanced_update: None to run every op at each step; a list or tuple
        (one entry per op) that is passed to select_ops_to_update; or the
        string 'dynamic'. Any other value runs no training steps.
    :return:
    """
    # a writer is needed for trace metadata as well as for summaries
    if self.do_trace or (summary_op is not None):
        self.summary_writer = tf.compat.v1.summary.FileWriter(summary_folder, self.sess.graph)
    if self.do_trace:
        run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
        run_metadata = tf.RunMetadata()
        multi_runs_timeline = TimeLiner()
    else:
        run_options = None
        run_metadata = None
        multi_runs_timeline = None
    if query_step > max_step:
        query_step = np.minimum(max_step-1, 100)

    # run the session
    self._load_ckpt_(ckpt_folder, ckpt_file=ckpt_file)
    self._check_thread_()
    extra_update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
    start_time = time.time()
    if imbalanced_update is None:
        for step in range(max_step):
            if self.do_trace and (step >= max_step - 5):
                # update the model in trace mode (last five steps only)
                loss_value, _, global_step_value, _ = self.sess.run(
                    [loss_list, op_list, global_step, extra_update_ops],
                    options=run_options, run_metadata=run_metadata)
                # add time line
                self.summary_writer.add_run_metadata(
                    run_metadata, tag='step%d' % global_step_value,
                    global_step=global_step_value)
                trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                chrome_trace = trace.generate_chrome_trace_format()
                multi_runs_timeline.update_timeline(chrome_trace)
            else:
                # update the model
                loss_value, _, global_step_value, _ = self.sess.run(
                    [loss_list, op_list, global_step, extra_update_ops])

            # print loss and add summary
            if global_step_value % query_step == 1:  # at step 0, global step = 1
                if print_loss:
                    self.print_loss(loss_value, global_step_value)
                if summary_op is not None:
                    summary_str = self.sess.run(summary_op)
                    self.summary_writer.add_summary(summary_str, global_step=global_step_value)

            # in abnormal cases, save the model
            if self.abnormal_save(loss_value, global_step_value, summary_op):
                break

            # save the model if the for loop completes normally
            if step == max_step - 1 and self.saver is not None:
                self.saver.save(self.sess, save_path=self.save_path, global_step=global_step_value)

    elif isinstance(imbalanced_update, (list, tuple)):
        num_ops = len(op_list)
        assert len(imbalanced_update) == num_ops, 'Imbalanced_update length does not match ' \
            'that of op_list. Expected {} got {}.'.format(
                num_ops, len(imbalanced_update))
        for step in range(max_step):
            # the global step read before the update decides which ops run
            global_step_value = self.sess.run(global_step)
            update_ops = select_ops_to_update(op_list, global_step_value, imbalanced_update)
            loss_value = self.do_imbalanced_update(step, max_step, loss_list, update_ops,
                                                   extra_update_ops, run_options, run_metadata,
                                                   global_step_value, multi_runs_timeline)
            # print loss
            if print_loss and (step % query_step == 0):
                self.print_loss(loss_value, global_step_value)
            if self.summary_and_save(summary_op, global_step_value, loss_value, step,
                                     max_step) == 'break':
                break

    elif isinstance(imbalanced_update, str) and imbalanced_update == 'dynamic':
        # This case is used for sngan_mmd_rand_g only
        mmd_average = 0.0
        for step in range(max_step):
            # get update ops: all ops during the first 1000 global steps, afterwards
            # all ops with probability 0.1 / max(mmd_average, 0.1), else all but the first
            global_step_value = self.sess.run(global_step)
            update_ops = op_list if \
                global_step_value < 1000 or \
                np.random.uniform(low=0.0, high=1.0) < 0.1 / np.maximum(mmd_average, 0.1) else \
                op_list[1:]
            loss_value = self.do_imbalanced_update(step, max_step, loss_list, update_ops,
                                                   extra_update_ops, run_options, run_metadata,
                                                   global_step_value, multi_runs_timeline)
            # update mmd_average with the third loss (presumably the MMD term - TODO confirm)
            mmd_average = loss_value[2]
            # print loss
            if print_loss and (step % query_step == 0):
                self.print_loss(loss_value, global_step_value)
            if self.summary_and_save(summary_op, global_step_value, loss_value, step,
                                     max_step) == 'break':
                break

    # calculate sess duration
    duration = time.time() - start_time
    FLAGS.print('Training for {} steps took {:.3f} sec.'.format(max_step, duration))
    # save tracing file
    if self.do_trace:
        trace_file = os.path.join(summary_folder, 'timeline.json')
        multi_runs_timeline.save(trace_file)
def full_run(self, op_list, loss_list, max_step, step_per_epoch, global_step,
             summary_op=None, summary_image_op=None, summary_folder=None,
             ckpt_folder=None, ckpt_file=None, print_loss=True, query_step=500,
             imbalanced_update=None, force_print=False, mog_model=None):
    """ This function run the session with all monitor functions.

    Three update modes are supported, chosen by imbalanced_update:
    None - every op in op_list runs at each step;
    list / tuple / NetPicker - select_ops_to_update picks the ops per step;
    'dynamic' - all ops, or all but the first, are chosen at random per step.
    The model is saved through self.save_model at the last step.

    :param op_list: the training ops
    :param loss_list: the losses fetched at every step; training aborts with an
        AssertionError as soon as one of them is NaN
    :param max_step: number of steps to run
    :param step_per_epoch: used to derive the epoch shown when printing losses
    :param global_step:
    :param summary_op: if given, summaries are written every query step
    :param summary_image_op: forwarded to self.save_model
    :param summary_folder: target folder of the summary writer
    :param ckpt_folder: folder handed to self._load_ckpt_
    :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
    :param print_loss: whether to print the losses every query step
    :param query_step: period (in global steps) of summaries and loss printing
    :param imbalanced_update: see above
    :param force_print: forwarded to self._load_ckpt_
    :param mog_model: optional model that is updated alongside training; only
        used when imbalanced_update is None
    :return:
    """
    # prepare writer
    if (summary_op is not None) or (summary_image_op is not None):
        self.summary_writer = tf.compat.v1.summary.FileWriter(summary_folder, self.sess.graph)
    self._load_ckpt_(ckpt_folder, ckpt_file=ckpt_file, force_print=force_print)
    # run the session
    self._check_thread_()
    extra_update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
    start_time = time.time()

    if imbalanced_update is None:
        # simultaneous updates
        if mog_model is not None and mog_model.linked_gan.train_with_mog:
            mog_model.set_batch_encoding()
        if not isinstance(loss_list, list):  # go from namedtuple to list
            loss_list = list(loss_list)
        global_step_value = None
        for step in range(max_step):
            # update MoG with current Dis params and current batch
            if mog_model is not None and mog_model.linked_gan.train_with_mog:
                mog_model.update_by_batch(self.sess)
                if mog_model.store_encodings:
                    if global_step_value is None:  # first iteration only
                        mog_model.store_encodings_and_params(self.sess, summary_folder, 0)
                    elif global_step_value % query_step == (query_step - 1):
                        mog_model.store_encodings_and_params(
                            self.sess, summary_folder, global_step_value)
            # update the model
            loss_value, _, _, global_step_value = self.sess.run(
                [loss_list, op_list, extra_update_ops, global_step])
            # check if model produces nan outcome
            assert not any(np.isnan(loss_value)), \
                'Model diverged with loss = {} at step {}'.format(loss_value, step)
            # maybe re-init mog after a few epochs, as it may have gotten lost given the
            # rapid change of encodings
            if mog_model is not None and global_step_value == mog_model.re_init_at_step:
                mog_model.init_np_mog()
            # add summary and print loss every query step
            if global_step_value % query_step == (query_step - 1) or global_step_value == 1:
                if mog_model is not None and mog_model.means_summary_op is not None \
                        and summary_op is not None:
                    summary_str, summary_str_means = self.sess.run(
                        [summary_op, mog_model.means_summary_op])
                    self.summary_writer.add_summary(summary_str, global_step=global_step_value)
                    self.summary_writer.add_summary(
                        summary_str_means, global_step=global_step_value)
                elif summary_op is not None:
                    summary_str = self.sess.run(summary_op)
                    self.summary_writer.add_summary(summary_str, global_step=global_step_value)
                if print_loss:
                    epoch = step // step_per_epoch
                    self.print_loss(loss_value, global_step_value, epoch)
            # save model at last step
            if step == max_step - 1:
                self.save_model(global_step_value, summary_image_op)

    elif isinstance(imbalanced_update, (list, tuple, NetPicker)):
        # alternating training
        for step in range(max_step):
            # the global step read before the update decides which ops run
            global_step_value = self.sess.run(global_step)
            # deliberately disabled MoG hook, kept as a design note:
            # IF STEP VALUE INDICATES TRAINING GENERATOR:
            # - collect all data encodings
            # - update MoG parameters
            # - proceed with training, sampling from updated MoG
            # in other places:
            # - predefine MoG distribution
            # - redefine generator loss through samples from the MoG
            if False and mog_model is not None and mog_model.linked_gan.train_with_mog:
                if mog_model.time_to_update(global_step_value, imbalanced_update):
                    mog_model.update(self.sess)
            update_ops = select_ops_to_update(op_list, global_step_value, imbalanced_update)
            # update the model
            loss_value, _, _ = self.sess.run([loss_list, update_ops, extra_update_ops])
            # check if model produces nan outcome
            assert not any(np.isnan(loss_value)), \
                'Model diverged with loss = {} at step {}'.format(loss_value, step)
            # add summary and print loss every query step
            if global_step_value % query_step == (query_step - 1):
                if summary_op is not None:
                    summary_str = self.sess.run(summary_op)
                    self.summary_writer.add_summary(summary_str, global_step=global_step_value)
                if print_loss:
                    epoch = step // step_per_epoch
                    self.print_loss(loss_value, global_step_value, epoch)
                # deliberately disabled: MoG approximation stats
                if False and mog_model is not None and not mog_model.linked_gan.train_with_mog:
                    mog_model.test_mog_approx(self.sess)
            # save model at last step
            if step == max_step - 1:
                self.save_model(global_step_value, summary_image_op)

    elif imbalanced_update == 'dynamic':
        # This case is used for sngan_mmd_rand_g only
        # NOTE(review): mmd_average is never updated in this loop (debug_mode sets it
        # to loss_value[2]), so the random choice below always selects op_list - confirm
        mmd_average = 0.0
        for step in range(max_step):
            # get update ops
            global_step_value = self.sess.run(global_step)
            update_ops = op_list if \
                global_step_value < 1000 or \
                np.random.uniform(low=0.0, high=1.0) < 0.1 / np.maximum(mmd_average, 0.1) else \
                op_list[1:]
            # update the model; three fetches give three results, the global step
            # was already read above
            loss_value, _, _ = self.sess.run([loss_list, update_ops, extra_update_ops])
            # check if model produces nan outcome
            assert not any(np.isnan(loss_value)), \
                'Model diverged with loss = {} at step {}'.format(loss_value, step)
            # add summary and print loss every query step
            if global_step_value % query_step == (query_step - 1):
                if summary_op is not None:
                    summary_str = self.sess.run(summary_op)
                    self.summary_writer.add_summary(summary_str, global_step=global_step_value)
                if print_loss:
                    epoch = step // step_per_epoch
                    self.print_loss(loss_value, global_step_value, epoch)
            # save model at last step
            if step == max_step - 1:
                self.save_model(global_step_value, summary_image_op)

    # calculate sess duration
    duration = time.time() - start_time
    FLAGS.print('Training for {} steps took {:.3f} sec.'.format(max_step, duration))
def opt_config(initial_lr, lr_decay_steps=None, end_lr=1e-7, optimizer='adam',
               name_suffix='', global_step=None, target_step=1e5):
    """ This function configures the optimizer and its learning rate.

    SGD and Momentum use an exponentially decayed learning rate (decay rate 0.96, no staircase);
    Adam and RMSProp use the constant initial_lr.

    :param initial_lr: learning rate at step 0
    :param lr_decay_steps: decay period passed to exponential_decay (SGD/Momentum only). If None,
        it is derived so that the decayed rate is approximately end_lr at target_step.
    :param end_lr: learning rate to reach at target_step; only read when lr_decay_steps is None
    :param optimizer: 'SGD'/'sgd', 'Momentum'/'momentum', 'Adam'/'adam' or 'RMSProp'/'rmsprop'
    :param name_suffix: string appended to the optimizer name
    :param global_step: step tensor driving the decay; only read by SGD/Momentum
    :param target_step: step at which the decayed rate should reach end_lr
    :return: (learning_rate, opt_op), the learning rate tensor and the optimizer object
    :raises AttributeError: if optimizer is not one of the supported names
    :raises ValueError: if lr_decay_steps has to be derived but initial_lr, end_lr and target_step
        do not give a usable (finite, non-zero) value
    """
    decay_rate = 0.96

    def decayed_learning_rate():
        """ Build the exponentially decayed learning rate shared by SGD and Momentum. """
        decay_steps = lr_decay_steps
        if decay_steps is None:
            lr_ratio = end_lr / initial_lr
            # ratio == 1 divides by log(1) = 0 and a non-positive ratio has no real log;
            # both used to end up as a garbage int32 handed to the graph
            if not lr_ratio > 0 or lr_ratio == 1:
                raise ValueError(
                    'Cannot derive lr_decay_steps from initial_lr={} and end_lr={}; '
                    'provide lr_decay_steps explicitly.'.format(initial_lr, end_lr))
            # solve initial_lr * decay_rate ** (target_step / decay_steps) = end_lr
            decay_steps = np.round(
                target_step * np.log(decay_rate) / np.log(lr_ratio)).astype(np.int32)
            if decay_steps == 0:
                # a zero period would make the decay exponent global_step / 0
                raise ValueError(
                    'Derived lr_decay_steps is 0 for target_step={}; '
                    'provide lr_decay_steps explicitly.'.format(target_step))
        return tf.compat.v1.train.exponential_decay(  # adaptive learning rate
            initial_lr, global_step=global_step, decay_steps=decay_steps,
            decay_rate=decay_rate, staircase=False)

    # tf.compat.v1.train is used in every branch (previously only in Adam) so that all
    # optimizers resolve through the same API namespace.
    if optimizer in ['SGD', 'sgd']:  # sgd
        learning_rate = decayed_learning_rate()
        opt_op = tf.compat.v1.train.GradientDescentOptimizer(
            learning_rate, name='GradientDescent' + name_suffix)
        FLAGS.print('GradientDescent Optimizer is used.')
    elif optimizer in ['Momentum', 'momentum']:  # momentum
        learning_rate = decayed_learning_rate()
        opt_op = tf.compat.v1.train.MomentumOptimizer(
            learning_rate, momentum=0.9, name='Momentum' + name_suffix)
        FLAGS.print('Momentum Optimizer is used.')
    elif optimizer in ['Adam', 'adam']:  # adam
        # Occasionally, adam optimizer may cause the objective to become nan in the first few steps
        # This is because at initialization, the gradients may be very big. Setting beta1 and beta2
        # close to 1 may prevent this.
        learning_rate = tf.constant(initial_lr)
        opt_op = tf.compat.v1.train.AdamOptimizer(
            learning_rate, beta1=0.5, beta2=0.999, epsilon=1e-8, name='Adam' + name_suffix)
        FLAGS.print('Adam Optimizer is used.')
    elif optimizer in ['RMSProp', 'rmsprop']:  # RMSProp
        learning_rate = tf.constant(initial_lr)
        opt_op = tf.compat.v1.train.RMSPropOptimizer(
            learning_rate, decay=0.9, momentum=0.0, epsilon=1e-10, name='RMSProp' + name_suffix)
        FLAGS.print('RMSProp Optimizer is used.')
    else:
        raise AttributeError('Optimizer {} not supported.'.format(optimizer))

    return learning_rate, opt_op
def inception_score_and_fid_v1(self, x_batch, y_batch, num_batch=10, ckpt_folder=None, ckpt_file=None):
    """ This function calculates inception scores and FID based on inception v1.
    Note: batch_size * num_batch needs to be larger than 2048, otherwise the covariance matrix
    will be ill-conditioned.

    According to TensorFlow v1.7 (below), this is actually inception v3 model.
    Somehow the downloaded file says it's v1.
    code link:
    https://github.com/tensorflow/tensorflow/blob/r1.7/tensorflow/contrib/gan/python/eval/python/classifier_metrics_impl.py

    Steps:
    1, the pool3 and logits are calculated for x_batch and y_batch with sess
    2, the pool3 and logits are passed to corresponding metrics

    :param x_batch: tensor, one batch of x in range [-1, 1]
    :param y_batch: tensor, one batch of y in range [-1, 1]
    :param num_batch: number of iterations over which logits and pool3 are collected
    :param ckpt_folder: check point folder
    :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
    :return: the evaluated [inception score of x, inception score of y, FID between the two
        halves of x, FID between x and y]
    :raises AssertionError: if self.model is not 'v1' or ckpt_folder is None
    """
    # Explicit raises instead of assert statements so the checks survive python -O;
    # AssertionError is kept so existing callers see the same exception type.
    if self.model != 'v1':
        raise AssertionError('GenerativeModelMetric is not initialized with model="v1".')
    if ckpt_folder is None:
        raise AssertionError('ckpt_folder must be provided.')

    x_logits, x_pool3 = self.inception_v1(x_batch)
    y_logits, y_pool3 = self.inception_v1(y_batch)

    # trace=True keeps the outputs of every iteration, one entry per batch
    with MySession(load_ckpt=True) as sess:
        inception_outputs = sess.run_m_times(
            [x_logits, y_logits, x_pool3, y_pool3],
            ckpt_folder=ckpt_folder, ckpt_file=ckpt_file,
            max_iter=num_batch, trace=True)

    # get logits and pool3
    x_logits_np = np.concatenate([inc[0] for inc in inception_outputs], axis=0)
    y_logits_np = np.concatenate([inc[1] for inc in inception_outputs], axis=0)
    x_pool3_np = np.concatenate([inc[2] for inc in inception_outputs], axis=0)
    y_pool3_np = np.concatenate([inc[3] for inc in inception_outputs], axis=0)
    FLAGS.print('logits calculated. Shape = {}.'.format(x_logits_np.shape))
    FLAGS.print('pool3 calculated. Shape = {}.'.format(x_pool3_np.shape))

    # calculate scores
    inc_x = self.inception_score_from_logits(x_logits_np)
    inc_y = self.inception_score_from_logits(y_logits_np)
    # Split x into two equal halves for the x-vs-x FID. np.split(..., 2) raised ValueError
    # on an odd sample count; slicing gives the same halves for even counts and leaves the
    # last sample out for odd counts.
    half = x_pool3_np.shape[0] // 2
    xp3_1 = x_pool3_np[:half]
    xp3_2 = x_pool3_np[half:2 * half]
    fid_xx = self.fid_from_pool3(xp3_1, xp3_2)
    fid_xy = self.fid_from_pool3(x_pool3_np, y_pool3_np)

    with MySession() as sess:
        scores = sess.run_once([inc_x, inc_y, fid_xx, fid_xy])

    return scores