Example #1
def rollback(var_list, ckpt_folder, ckpt_file=None):
    """ This function provides a shortcut for reloading a model and calculating a list of variables

    :param var_list: list of tensors/ops to evaluate after the checkpoint is restored
    :param ckpt_folder: folder that contains the checkpoint files
    :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
    :return: (var_value, global_step_value)
    """
    global_step = global_step_config()
    # register a session
    sess = tf.Session(config=tf.ConfigProto(allow_soft_placement=True,
                                            log_device_placement=False))
    # initialization
    init_op = tf.group(tf.global_variables_initializer(),
                       tf.local_variables_initializer())
    sess.run(init_op)
    # load the training graph
    saver = tf.compat.v1.train.Saver(max_to_keep=2)
    ckpt = get_ckpt(ckpt_folder, ckpt_file=ckpt_file)
    if ckpt is None:
        raise FileNotFoundError(
            'No ckpt Model found at {}.'.format(ckpt_folder))
    saver.restore(sess, ckpt.model_checkpoint_path)
    FLAGS.print('Model reloaded.')
    # run the session
    coord = tf.train.Coordinator()
    # threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    var_value, global_step_value = sess.run([var_list, global_step])
    coord.request_stop()
    # coord.join(threads)
    sess.close()
    FLAGS.print('Variable calculated.')

    return var_value, global_step_value
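
The snippet above assumes an existing graph plus project helpers such as get_ckpt and FLAGS. A minimal self-contained sketch of the same save-then-restore pattern, using the tf.compat.v1 API and a temporary directory (the names ckpt_dir, counter and demo.ckpt are illustrative only):

import tempfile
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

ckpt_dir = tempfile.mkdtemp()
counter = tf.get_variable('counter', initializer=0)
increment = tf.assign_add(counter, 1)

# run briefly and save a checkpoint
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    sess.run(increment)
    tf.train.Saver().save(sess, ckpt_dir + '/demo.ckpt', global_step=1)

# "rollback": restore the newest checkpoint and evaluate a list of variables
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    ckpt = tf.train.get_checkpoint_state(ckpt_dir)
    tf.train.Saver().restore(sess, ckpt.model_checkpoint_path)
    print(sess.run([counter]))  # -> [1]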
Example #2
    @staticmethod
    def print_loss(loss_value, step=0, epoch=0):
        FLAGS.print('Epoch {}, global steps {}, loss_list {}'.format(
            epoch, step,
            ['{}'.format(['<{:.2f}>'.format(l_val) for l_val in l_value])
             if isinstance(l_value, (np.ndarray, list))
             else '<{:.3f}>'.format(l_value)
             for l_value in loss_value]))
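
For reference, the nested format expression above can be checked stand-alone with plain print in place of FLAGS.print (the loss values below are made up):

import numpy as np

loss_value = [0.1234, np.array([1.5, 2.25]), 3.0]
print('Epoch {}, global steps {}, loss_list {}'.format(
    2, 500,
    ['{}'.format(['<{:.2f}>'.format(l_val) for l_val in l_value])
     if isinstance(l_value, (np.ndarray, list))
     else '<{:.3f}>'.format(l_value)
     for l_value in loss_value]))
# Epoch 2, global steps 500, loss_list ['<0.123>', "['<1.50>', '<2.25>']", '<3.000>']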
Example #3
    def run_m_times(
            self, var_list, ckpt_folder=None, ckpt_file=None, max_iter=10000,
            trace=False, ckpt_var_list=None, feed_dict=None):
        """ This functions calculates var_list for multiple iterations, as often done in
        Monte Carlo analysis.

        :param var_list:
        :param ckpt_folder:
        :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
        :param max_iter:
        :param trace: if True, keep all outputs of m iterations
        :param ckpt_var_list: the variable to load in order to calculate var_list
        :param feed_dict:
        :return:
        """
        if ckpt_var_list is not None:
            self.ckpt_var_list = ckpt_var_list
        self._load_ckpt_(ckpt_folder, ckpt_file=ckpt_file)
        self._check_thread_()
        extra_update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        start_time = time.time()
        if trace:
            var_value_list = []
            for i in range(max_iter):
                var_value, _ = self.sess.run([var_list, extra_update_ops], feed_dict=feed_dict)
                var_value_list.append(var_value)
        else:
            for i in range(max_iter - 1):
                _, _ = self.sess.run([var_list, extra_update_ops], feed_dict=feed_dict)
            var_value_list, _ = self.sess.run([var_list, extra_update_ops], feed_dict=feed_dict)
        # global_step_value = self.sess.run([self.global_step])
        FLAGS.print('Calculation took {:.3f} sec.'.format(time.time() - start_time))
        return var_value_list
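
A stand-alone sketch of the Monte Carlo pattern that run_m_times implements, with a toy random tensor standing in for var_list (tf.compat.v1 API; the session plumbing such as _load_ckpt_ and _check_thread_ is omitted):

import numpy as np
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()

sample = tf.random.normal([])  # stands in for var_list
with tf.Session() as sess:
    # trace=True behaviour: keep the output of every iteration
    draws = [sess.run(sample) for _ in range(100)]
    print('mean over 100 draws: {:.3f}'.format(np.mean(draws)))
    # trace=False behaviour: only the result of the last iteration is kept
    last_draw = sess.run(sample)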
Example #4
def write_sprite(sprite_path, images, mesh_num=None, if_invert=False):
    """ This function writes images to sprite image for embedding

    This function was taken from:
    https://github.com/oduerr/dl_tutorial/blob/master/tensorflow/debugging/embedding.ipynb

    The input images must be in channels_last format.

    :param sprite_path: file name, e.g. '...\\a_sprite.png'
    :param images: ndarray, [batch_size, height, width(, channels)], values in range [0,1]
    :param if_invert: bool, if true, invert images: images = 1 - images
    :param mesh_num: number of images per row and per column; must be a tuple
    :return:
    """
    if len(images.shape) == 3:  # if dimension of image is 3, extend it to 4
        images = np.tile(images[..., np.newaxis], (1, 1, 1, 3))
    if images.shape[3] == 1:  # if last dimension is 1, extend it to 3
        images = np.tile(images, (1, 1, 1, 3))
    # scale image to range [0,1]
    images = images.astype(np.float32)
    image_min = np.min(images.reshape((images.shape[0], -1)), axis=1)
    images = (images.transpose((1, 2, 3, 0)) - image_min).transpose(
        (3, 0, 1, 2))
    image_max = np.max(images.reshape((images.shape[0], -1)), axis=1)
    images = (images.transpose((1, 2, 3, 0)) / image_max).transpose(
        (3, 0, 1, 2))
    if if_invert:
        images = 1 - images
    # check mesh_num
    if mesh_num is None:
        FLAGS.print('Mesh_num will be calculated as sqrt of batch_size')
        batch_size = images.shape[0]
        sprite_size = int(np.ceil(np.sqrt(batch_size)))
        mesh_num = (sprite_size, sprite_size)
        # add paddings if batch_size is not the square of sprite_size
        padding = ((0, sprite_size**2 - batch_size), (0, 0),
                   (0, 0)) + ((0, 0), ) * (images.ndim - 3)
        images = np.pad(images, padding, mode='constant', constant_values=0)
    elif isinstance(mesh_num, list):
        mesh_num = tuple(mesh_num)
    # Tile the individual thumbnails into an image
    new_shape = mesh_num + images.shape[1:]
    images = images.reshape(new_shape).transpose(
        (0, 2, 1, 3) + tuple(range(4, images.ndim + 1)))
    images = images.reshape((mesh_num[0] * images.shape[1],
                             mesh_num[1] * images.shape[3]) + images.shape[4:])
    images = (images * 255).astype(np.uint8)
    # save images to file
    # from scipy.misc import imsave
    # imsave(sprite_path, images)
    try:
        from imageio import imwrite
        imwrite(sprite_path, images)
    except Exception as err:
        print('Attempt to write the sprite image failed: {}'.format(err))
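
The reshape/transpose steps above are the core of the sprite layout; a numpy-only check of the same arithmetic on a toy 4-image batch arranged in a 2x2 mesh:

import numpy as np

images = np.zeros((4, 2, 3, 1), dtype=np.float32)  # [batch, height, width, channels]
mesh_num = (2, 2)
tiled = images.reshape(mesh_num + images.shape[1:])       # [2, 2, H, W, C]
tiled = tiled.transpose((0, 2, 1, 3, 4))                  # [2, H, 2, W, C]
sprite = tiled.reshape((mesh_num[0] * images.shape[1],
                        mesh_num[1] * images.shape[2]) + images.shape[3:])
print(sprite.shape)  # (4, 6, 1): a 2x2 grid of 2x3 thumbnails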
Example #5
    def __exit__(self, exc_type, exc_val, exc_tb):
        """ The exit method is called when leaving the scope of "with" statement

        :param exc_type:
        :param exc_val:
        :param exc_tb:
        :return:
        """
        FLAGS.print('Session finished.')
        if self.summary_writer is not None:
            self.summary_writer.close()
        self.coord.request_stop()
        # self.coord.join(self.threads)
        self.sess.close()
Example #6
    def __init__(
            self, do_save=False, do_trace=False, save_path=None,
            load_ckpt=False, log_device=False, ckpt_var_list=None):
        """ This class provides shortcuts for running sessions.
        It must be used as a context manager. Example:
        with MySession() as sess:
            var1_value, var2_value = sess.run_once([var1, var2])

        :param do_save:
        :param do_trace:
        :param save_path:
        :param load_ckpt:
        :param log_device:
        :param ckpt_var_list: list of variables to save / restore
        """
        # somehow it gives error: "global_step does not exist or is not created from tf.get_variable".
        # self.global_step = global_step_config()
        self.log_device = log_device
        # register a session
        # self.sess = tf.Session(config=tf.ConfigProto(
        #     allow_soft_placement=True,
        #     log_device_placement=log_device,
        #     gpu_options=tf.GPUOptions(allow_growth=True)))
        self.sess = tf.Session(config=tf.ConfigProto(
            allow_soft_placement=True,
            log_device_placement=log_device))
        # initialization
        init_op = tf.group(tf.global_variables_initializer(), tf.local_variables_initializer())
        self.sess.run(init_op)
        self.coord = None
        self.threads = None
        FLAGS.print('Graph initialization finished...')
        # configuration
        self.ckpt_var_list = ckpt_var_list
        if do_save:
            self.saver = self._get_saver_()
            self.save_path = save_path
        else:
            self.saver = None
            self.save_path = None
        self.summary_writer = None
        self.do_trace = do_trace
        self.load_ckpt = load_ckpt
Example #7
    def run_once(self, var_list, ckpt_folder=None, ckpt_file=None, ckpt_var_list=None, feed_dict=None, do_time=False):
        """ This functions calculates var_list.

        :param var_list:
        :param ckpt_folder:
        :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
        :param ckpt_var_list: the variable to load in order to calculate var_list
        :param feed_dict:
        :param do_time:
        :return:
        """
        if ckpt_var_list is not None:
            self.ckpt_var_list = ckpt_var_list
        self._load_ckpt_(ckpt_folder, ckpt_file=ckpt_file)
        self._check_thread_()

        if do_time:
            start_time = time.time()
            var_value = self.sess.run(var_list, feed_dict=feed_dict)
            FLAGS.print('Running session took {:.3f} sec.'.format(time.time() - start_time))
        else:
            var_value = self.sess.run(var_list, feed_dict=feed_dict)

        return var_value
Example #8
    def train(
            self, op_list, loss_list, global_step,
            max_step=None, step_per_epoch=None,
            summary_op=None, summary_image_op=None, imbalanced_update=None, force_print=False,
            mog_model=None):
        """ This method do the optimization process to minimizes loss_list

        :param op_list: [net0_op, net1_op, net2_op]
        :param loss_list: [loss0, loss1, loss2]
        :param global_step:
        :param max_step:
        :param step_per_epoch:
        :param summary_op:
        :param summary_image_op:
        :param imbalanced_update:
        :param force_print:
        :return:
        """
        # Check inputs
        if imbalanced_update is not None:
            self.imbalanced_update = imbalanced_update
        if self.imbalanced_update is not None:
            assert isinstance(self.imbalanced_update, (list, tuple, str, NetPicker)), \
                'imbalanced_update must be a list, tuple, str or NetPicker.'

        if self.debug is None:
            # sess = tf.Session(config=tf.ConfigProto(
            #     allow_soft_placement=True,
            #     log_device_placement=False))
            writer = tf.summary.FileWriter(logdir=self.summary_folder, graph=tf.get_default_graph())
            writer.flush()
            # graph_protobuf = str(tf.get_default_graph().as_default())
            # with open(os.path.join(self.summary_folder, 'graph'), 'w') as f:
            #     f.write(graph_protobuf)
            FLAGS.print('Graph printed.')
        elif self.debug is True:
            FLAGS.print('Debug mode is on.')
            FLAGS.print('Remember to load ckpt to check variable values.')
            with MySession(self.do_save, self.do_trace, self.save_path, self.load_ckpt, self.log_device) as sess:
                sess.debug_mode(op_list, loss_list, global_step, summary_op, self.summary_folder, self.ckpt_folder,
                                max_step=self.debug_step, print_loss=self.print_loss, query_step=self.query_step,
                                imbalanced_update=self.imbalanced_update)
        elif self.debug is False:  # ---------------------------------------------------------------------- ALMOST THERE
            # print('-------------------------- starting session for full run')
            with MySession(self.do_save, self.do_trace, self.save_path, self.load_ckpt) as sess:
                sess.full_run(op_list, loss_list, max_step, step_per_epoch, global_step, summary_op, summary_image_op,
                              self.summary_folder, self.ckpt_folder, print_loss=self.print_loss,
                              query_step=self.query_step, imbalanced_update=self.imbalanced_update,
                              force_print=force_print, mog_model=mog_model)
        else:
            raise AttributeError('Current debug mode is not supported.')
Example #9
    def _load_ckpt_(self, ckpt_folder=None, ckpt_file=None, force_print=False):
        """ This function loads a checkpoint model

        :param ckpt_folder:
        :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
        :param force_print:
        :return:
        """
        if self.load_ckpt and (ckpt_folder is not None):
            ckpt = get_ckpt(ckpt_folder, ckpt_file=ckpt_file)
            if ckpt is None:
                FLAGS.print(
                    'No ckpt Model found at {}. Model training from scratch.'.format(ckpt_folder), force_print)
            else:
                if self.saver is None:
                    self.saver = self._get_saver_()
                self.saver.restore(self.sess, ckpt.model_checkpoint_path)
                FLAGS.print('Model reloaded from {}.'.format(ckpt_folder), force_print)
        else:
            FLAGS.print('No ckpt model is loaded for current calculation.')
Example #10
    def scheduler(self,
                  batch_size=None,
                  num_epoch=None,
                  shuffle_data=True,
                  buffer_size=None,
                  skip_count=None,
                  sample_same_class=False,
                  sample_class=None):
        """ This function schedules the batching process

        :param batch_size:
        :param num_epoch:
        :param buffer_size:
        :param skip_count:
        :param sample_same_class: if the data must be sampled from the same class at one iteration
        :param sample_class: if provided, the data will be sampled from class of this label, otherwise,
            data of a random class are sampled.
        :param shuffle_data:
        :return:
        """
        if not self.scheduled:
            # update batch information
            if batch_size is not None:
                self.batch_size = batch_size
                self.batch_shape[0] = self.batch_size
            if num_epoch is not None:
                self.num_epoch = num_epoch
            if buffer_size is not None:
                self.buffer_size = buffer_size
            if skip_count is not None:
                self.skip_count = skip_count
            # skip instances
            if self.skip_count > 0:
                print('{} instances skipped.'.format(
                    self.skip_count))
                self.dataset = self.dataset.skip(self.skip_count)
            # shuffle
            if shuffle_data:
                self.dataset = self.dataset.shuffle(self.buffer_size)
            # set batching process
            if sample_same_class:
                if sample_class is None:
                    print('Caution: samples from the same class at each call.')
                    group_fun = tf.contrib.data.group_by_window(
                        key_func=lambda data_x, data_y: data_y,
                        reduce_func=lambda key, d: d.batch(self.batch_size),
                        window_size=self.batch_size)
                    self.dataset = self.dataset.apply(group_fun)
                else:
                    print(
                        'Caution: samples from class {}. This should not be used in training'
                        .format(sample_class))
                    self.dataset = self.dataset.filter(
                        lambda x, y: tf.equal(y[0], sample_class))
                    self.dataset = self.dataset.batch(self.batch_size)
            else:
                self.dataset = self.dataset.batch(self.batch_size)
            # self.dataset = self.dataset.padded_batch(batch_size)
            if self.num_epoch is None:
                self.dataset = self.dataset.repeat()
            else:
                FLAGS.print('Num_epoch set: {} epochs.'.format(self.num_epoch))
                self.dataset = self.dataset.repeat(self.num_epoch)

            self.iterator = self.dataset.make_one_shot_iterator()
            self.scheduled = True
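
tf.contrib.data.group_by_window no longer exists in TensorFlow 2.x; a rough stand-alone sketch of the same same-class batching idea with the current tf.data API (toy tensors, two classes, batch size 4):

import tensorflow as tf

batch_size = 4
xs = tf.random.normal([20, 3])
ys = tf.constant([i % 2 for i in range(20)], dtype=tf.int64)
dataset = tf.data.Dataset.from_tensor_slices((xs, ys))

# collect elements per label and emit a batch once batch_size items of one class have arrived
dataset = dataset.apply(tf.data.experimental.group_by_window(
    key_func=lambda data_x, data_y: data_y,
    reduce_func=lambda key, d: d.batch(batch_size),
    window_size=batch_size))

for x_batch, y_batch in dataset.take(2):
    print(y_batch.numpy())  # each batch holds a single class, e.g. [0 0 0 0]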
Example #11
    def eval_sampling(self,
                      filename,
                      sub_folder,
                      mesh_num=None,
                      mesh_mode=0,
                      if_invert=False,
                      code_x=None,
                      code_y=None,
                      real_sample=False,
                      sample_same_class=False,
                      get_dis_score=True,
                      do_sprite=True,
                      do_embedding=False,
                      ckpt_file=None,
                      num_threads=7):
        """ This function randomly generates samples and writes them to sprite.

        :param sample_same_class:
        :param code_y:
        :param filename:
        :param sub_folder:
        :param mesh_num:
        :param if_invert:
        :param mesh_mode:
        :param code_x: if provided, it will be used as the latent codes to generate images.
        :param num_threads:
        :param real_sample: True if real sample should also be obtained
        :param get_dis_score: bool, whether to calculate the scores from the discriminator
        :param do_sprite:
        :param do_embedding:
        :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
        :return:
        """
        # prepare folder
        ckpt_folder, summary_folder, _ = prepare_folder(filename,
                                                        sub_folder=sub_folder)
        # check inputs
        if mesh_num is None:
            mesh_num = (10, 10)
        elif code_x is not None:
            assert code_x.shape[0] == mesh_num[0] * mesh_num[1]
        batch_size = mesh_num[0] * mesh_num[1]
        if do_embedding is True:
            get_dis_score = True
            real_sample = True

        # build the network graph
        self.graph = tf.Graph()
        with self.graph.as_default():
            self.init_net()

            # get real sample
            if real_sample:
                self.sample_same_class = sample_same_class
                data_batch = self.get_data_batch(filename,
                                                 batch_size,
                                                 num_threads=num_threads)
            else:
                data_batch = {'x': tf.constant(0)}

            # sample validation instances
            if code_x is None:
                code = MeshCode(self.code_size, mesh_num=mesh_num)
                code_x = code.get_batch(mesh_mode, name='code_x')
            if code_y is None and self.sample_same_class and 'y' in data_batch:
                code_y = data_batch['y']
            code_batch = self.sample_codes(batch_size,
                                           code_x,
                                           code_y,
                                           name='code_te')
            # generate new images
            gen_batch = self.__gpu_task__(code_batch=code_batch,
                                          is_training=False)
            # do clipping
            gen_batch['x'] = tf.clip_by_value(gen_batch['x'],
                                              clip_value_min=-1,
                                              clip_value_max=1)

            # get discriminator scores
            if get_dis_score and real_sample:
                dis_out = self.Dis(self.concat_two_batches(
                    data_batch, gen_batch),
                                   is_training=False)
                s_x, s_gen = tf.split(dis_out['x'],
                                      num_or_size_splits=2,
                                      axis=0)
            else:
                s_x = tf.constant(0)
                s_gen = tf.constant(0)

            FLAGS.print('Graph configuration finished...')
            # calculate the value of x_gen
            var_list = [gen_batch['x'], data_batch['x'], s_x, s_gen]
            _temp, global_step_value = rollback(var_list,
                                                ckpt_folder,
                                                ckpt_file=ckpt_file)
            x_gen_value, x_real_value, s_x_value, s_gen_value = _temp

        # write to files
        if do_sprite:
            if real_sample:
                write_sprite_wrapper(x_real_value,
                                     mesh_num,
                                     filename,
                                     file_folder=summary_folder,
                                     file_index='_r_' + sub_folder + '_' +
                                     str(global_step_value) + '_' +
                                     str(mesh_mode),
                                     if_invert=if_invert,
                                     image_format=FLAGS.IMAGE_FORMAT)
            write_sprite_wrapper(x_gen_value,
                                 mesh_num,
                                 filename,
                                 file_folder=summary_folder,
                                 file_index='_g_' + sub_folder + '_' +
                                 str(global_step_value) + '_' + str(mesh_mode),
                                 if_invert=if_invert,
                                 image_format=FLAGS.IMAGE_FORMAT)

        # do visualization for code_value
        if do_embedding:
            # transpose image data if necessary
            if real_sample:
                x_as_image = np.transpose(
                    x_real_value,
                    axes=self.perm) if self.perm is not None else x_real_value
                x_gen_as_image = np.transpose(
                    x_gen_value,
                    axes=self.perm) if self.perm is not None else x_gen_value
                # concatenate real and generated images, codes and labels
                s_x_value = np.concatenate((s_x_value, s_gen_value), axis=0)
                x_as_image = np.concatenate((x_as_image, x_gen_as_image),
                                            axis=0)
                labels = np.concatenate(  # 1 for real, 0 for gen
                    (np.ones(batch_size, dtype=int),
                     np.zeros(batch_size, dtype=int)),
                    axis=0)
                # embedding
                mesh_num = (mesh_num[0] * 2, mesh_num[1])
                embedding_image_wrapper(s_x_value,
                                        filename,
                                        var_name='x_vs_xg',
                                        file_folder=summary_folder,
                                        file_index='_x_vs_xg_' + sub_folder +
                                        '_' + str(global_step_value) + '_' +
                                        str(mesh_mode),
                                        labels=labels,
                                        images=x_as_image,
                                        mesh_num=mesh_num,
                                        if_invert=if_invert,
                                        image_format=FLAGS.IMAGE_FORMAT)
Example #12
    def training(self,
                 filename,
                 agent,
                 num_instance,
                 lr_list,
                 end_lr=1e-7,
                 max_step=None,
                 batch_size=64,
                 sample_same_class=False,
                 num_threads=7,
                 gpu='/gpu:0'):
        """ This function defines the training process

        :param filename:
        :param agent:
        :param num_instance:
        :param lr_list:
        :param end_lr:
        :param max_step:
        :type max_step: int
        :param batch_size:
        :param sample_same_class: bool, if at each iteration the data should be sampled from the same class
        :param num_threads:
        :param gpu: which gpu to use
        :return:
        """
        self.step_per_epoch = np.floor(num_instance / batch_size).astype(
            np.int32)
        self.sample_same_class = sample_same_class
        if max_step >= self.step_per_epoch:
            from math import gcd
            file_repeat = int(batch_size / gcd(num_instance, batch_size)) \
                if self.num_class < 2 else int(batch_size / gcd(int(num_instance / self.num_class), batch_size))
            shuffle_file = False
        else:
            if isinstance(filename,
                          str) or (isinstance(filename, (list, tuple))
                                   and len(filename) == 1):
                raise AttributeError(
                    'max_step should be larger than step_per_epoch when there is a single file.'
                )
            else:
                # for large datasets, the data are stored in multiple files. If all files cannot be visited
                # within max_step steps, consider shuffling the filename list every max_step steps
                file_repeat = 1
                shuffle_file = True

        FLAGS.print(
            'Num Instance: {}; Num Class: {}; Batch: {}; File_repeat: {}'.
            format(num_instance, self.num_class, batch_size, file_repeat))

        # build the graph
        self.graph = tf.Graph()
        with self.graph.as_default(), tf.device(gpu):
            self.init_net()
            # get next batch
            data_batch = self.get_data_batch(filename, batch_size, file_repeat,
                                             num_threads, shuffle_file,
                                             'data_tr')
            FLAGS.print('Shape of input batch: {}'.format(
                data_batch['x'].get_shape().as_list()))

            # setup training process
            # with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE):
            self.global_step = global_step_config()
            _, opt_ops = multi_opt_config(lr_list,
                                          end_lr=end_lr,
                                          optimizer=self.optimizer,
                                          global_step=self.global_step)
            # assign tasks
            with tf.variable_scope(tf.get_variable_scope()):
                # calculate loss and gradients
                grads_list, loss_list = self.__gpu_task__(
                    batch_size=batch_size,
                    is_training=True,
                    data_batch=data_batch,
                    opt_op=opt_ops)
            # apply the gradient
            if agent.imbalanced_update is None:
                dis_op = opt_ops[0].apply_gradients(
                    grads_list[0], global_step=self.global_step)
                gen_op = opt_ops[1].apply_gradients(grads_list[1])
                op_list = [dis_op, gen_op]
            elif isinstance(agent.imbalanced_update, (list, tuple)):
                FLAGS.print(
                    'Imbalanced update used: dis every {} steps and gen every {} steps'
                    .format(agent.imbalanced_update[0],
                            agent.imbalanced_update[1]))
                if agent.imbalanced_update[0] == 1:
                    dis_op = opt_ops[0].apply_gradients(
                        grads_list[0], global_step=self.global_step)
                    gen_op = opt_ops[1].apply_gradients(grads_list[1])
                    op_list = [dis_op, gen_op]
                elif agent.imbalanced_update[1] == 1:
                    dis_op = opt_ops[0].apply_gradients(grads_list[0])
                    gen_op = opt_ops[1].apply_gradients(
                        grads_list[1], global_step=self.global_step)
                    op_list = [dis_op, gen_op]
                else:
                    raise AttributeError(
                        'One of the imbalanced_update must be 1.')
            elif isinstance(agent.imbalanced_update, str):
                dis_op = opt_ops[0].apply_gradients(grads_list[0])
                gen_op = opt_ops[1].apply_gradients(
                    grads_list[1], global_step=self.global_step)
                op_list = [dis_op, gen_op]
            else:
                raise AttributeError('Imbalanced_update not identified.')

            # summary op is always pinned to CPU
            # add summary for all trainable variables
            if self.do_summary:
                for grads in grads_list:
                    for var_grad, var in grads:
                        var_name = var.name.replace(':', '_')
                        tf.summary.histogram('grad_' + var_name, var_grad)
                        tf.summary.histogram(var_name, var)
                summary_op = tf.summary.merge_all()
            else:
                summary_op = None
            # add summary for final image reconstruction
            if self.do_summary_image:
                tf.get_variable_scope().reuse_variables()
                summary_image_op = self.summary_image_sampling(data_batch)
            else:
                summary_image_op = None

            # run the session
            FLAGS.print('loss_list name: {}.'.format(self.loss_names))
            agent.train(op_list,
                        loss_list,
                        self.global_step,
                        max_step,
                        self.step_per_epoch,
                        summary_op,
                        summary_image_op,
                        force_print=self.force_print)
            self.force_print = False  # force print at the first call
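
The file_repeat value above is chosen so that repeating the data file makes the number of instances divisible by the batch size; a quick stand-alone check of the arithmetic with made-up numbers:

from math import gcd

num_instance, batch_size = 50000, 64
file_repeat = batch_size // gcd(num_instance, batch_size)
print(file_repeat)                                # 4
print((num_instance * file_repeat) % batch_size)  # 0: epochs align with whole batches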
Example #13
    def debug_mode(self, op_list, loss_list, global_step, summary_op=None, summary_folder=None, ckpt_folder=None,
                   ckpt_file=None, max_step=200, print_loss=True, query_step=100, imbalanced_update=None):
        """ This function do tracing to debug the code. It will burn-in for 25 steps, then record
        the usage every 5 steps for 5 times.

        :param op_list:
        :param loss_list:
        :param global_step:
        :param summary_op:
        :param summary_folder:
        :param max_step:
        :param ckpt_folder:
        :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
        :param print_loss:
        :param query_step:
        :param imbalanced_update: a list indicating the period to update each ops in op_list;
            the first op must have period = 1 as it updates the global step
        :return:
        """
        if self.do_trace or (summary_op is not None):
            self.summary_writer = tf.compat.v1.summary.FileWriter(summary_folder, self.sess.graph)
        if self.do_trace:
            run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
            run_metadata = tf.RunMetadata()
            multi_runs_timeline = TimeLiner()
        else:
            run_options = None
            run_metadata = None
            multi_runs_timeline = None
        if query_step > max_step:
            query_step = np.minimum(max_step-1, 100)

        # run the session
        self._load_ckpt_(ckpt_folder, ckpt_file=ckpt_file)
        self._check_thread_()
        extra_update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        # print(extra_update_ops)
        start_time = time.time()
        if imbalanced_update is None:
            for step in range(max_step):
                if self.do_trace and (step >= max_step - 5):
                    # update the model in trace mode
                    loss_value, _, global_step_value, _ = self.sess.run(
                        [loss_list, op_list, global_step, extra_update_ops],
                        options=run_options, run_metadata=run_metadata)
                    # add time line
                    self.summary_writer.add_run_metadata(
                        run_metadata, tag='step%d' % global_step_value, global_step=global_step_value)
                    trace = timeline.Timeline(step_stats=run_metadata.step_stats)
                    chrome_trace = trace.generate_chrome_trace_format()
                    multi_runs_timeline.update_timeline(chrome_trace)
                else:
                    # update the model
                    loss_value, _, global_step_value, _ = self.sess.run(
                        [loss_list, op_list, global_step, extra_update_ops])

                # print(loss_value) and add summary
                if global_step_value % query_step == 1:  # at step 0, global step = 1
                    if print_loss:
                        self.print_loss(loss_value, global_step_value)
                    if summary_op is not None:
                        summary_str = self.sess.run(summary_op)
                        self.summary_writer.add_summary(summary_str, global_step=global_step_value)

                # in abnormal cases, save the model
                if self.abnormal_save(loss_value, global_step_value, summary_op):
                    break

                # save the model if the for loop completes normally
                if step == max_step - 1 and self.saver is not None:
                    self.saver.save(self.sess, save_path=self.save_path, global_step=global_step_value)

        elif isinstance(imbalanced_update, (list, tuple)):
            num_ops = len(op_list)
            assert len(imbalanced_update) == num_ops, 'Imbalanced_update length does not match ' \
                                                      'that of op_list. Expected {} got {}.'.format(
                num_ops, len(imbalanced_update))
            for step in range(max_step):
                # get update ops
                global_step_value = self.sess.run(global_step)
                # added function to take care of added negative option
                # update_ops = [op_list[i] for i in range(num_ops) if global_step_value % imbalanced_update[i] == 0]
                update_ops = select_ops_to_update(op_list, global_step_value, imbalanced_update)

                loss_value = self.do_imbalanced_update(step, max_step, loss_list, update_ops, extra_update_ops,
                                                       run_options, run_metadata, global_step_value,
                                                       multi_runs_timeline)

                # print(loss_value)
                if print_loss and (step % query_step == 0):
                    self.print_loss(loss_value, global_step_value)

                if self.summary_and_save(summary_op, global_step_value, loss_value, step, max_step) == 'break':
                    break

        elif isinstance(imbalanced_update, str) and imbalanced_update == 'dynamic':
            # This case is used for sngan_mmd_rand_g only
            mmd_average = 0.0
            for step in range(max_step):
                # get update ops
                global_step_value = self.sess.run(global_step)
                update_ops = op_list if \
                    global_step_value < 1000 or \
                    np.random.uniform(low=0.0, high=1.0) < 0.1 / np.maximum(mmd_average, 0.1) else \
                    op_list[1:]

                loss_value = self.do_imbalanced_update(step, max_step, loss_list, update_ops, extra_update_ops,
                                                       run_options, run_metadata, global_step_value,
                                                       multi_runs_timeline)

                # update mmd_average
                mmd_average = loss_value[2]

                # print(loss_value)
                if print_loss and (step % query_step == 0):
                    self.print_loss(loss_value, global_step_value)

                if self.summary_and_save(summary_op, global_step_value, loss_value, step, max_step) == 'break':
                    break

        # calculate sess duration
        duration = time.time() - start_time
        FLAGS.print('Training for {} steps took {:.3f} sec.'.format(max_step, duration))
        # save tracing file
        if self.do_trace:
            trace_file = os.path.join(summary_folder, 'timeline.json')
            multi_runs_timeline.save(trace_file)
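
select_ops_to_update is not shown in this listing; based on the commented-out line above, a plausible pure-Python sketch of the period-based selection (the real helper may also handle negative periods):

def select_ops_to_update(op_list, global_step_value, periods):
    """Pick the ops whose update period divides the current global step."""
    return [op for op, period in zip(op_list, periods)
            if global_step_value % period == 0]

# e.g. update the first op every step and the second op every 5th step
print(select_ops_to_update(['dis_op', 'gen_op'], 10, [1, 5]))  # ['dis_op', 'gen_op']
print(select_ops_to_update(['dis_op', 'gen_op'], 11, [1, 5]))  # ['dis_op']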
Example #14
    def full_run(self, op_list, loss_list, max_step, step_per_epoch, global_step, summary_op=None,
                 summary_image_op=None, summary_folder=None, ckpt_folder=None, ckpt_file=None, print_loss=True,
                 query_step=500, imbalanced_update=None, force_print=False,
                 mog_model=None):
        """ This function run the session with all monitor functions.

        :param op_list: list of update ops; the first op also updates the global step.
        :param loss_list: the first loss is used to monitor the convergence
        :param max_step:
        :param step_per_epoch:
        :param global_step:
        :param summary_op:
        :param summary_image_op:
        :param summary_folder:
        :param ckpt_folder:
        :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
        :param print_loss:
        :param query_step:
        :param imbalanced_update: a list indicating the period to update each ops in op_list;
            the first op must have period = 1 as it updates the global step
        :param force_print:
        :param mog_model:
        :return:
        """
        # prepare writer
        if (summary_op is not None) or (summary_image_op is not None):
            self.summary_writer = tf.compat.v1.summary.FileWriter(summary_folder, self.sess.graph)
        self._load_ckpt_(ckpt_folder, ckpt_file=ckpt_file, force_print=force_print)
        # run the session
        self._check_thread_()
        extra_update_ops = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.UPDATE_OPS)
        start_time = time.time()
        if imbalanced_update is None:  # ----------------------------------------------------- SIMULTANEOUS UPDATES HERE
            if mog_model is not None and mog_model.linked_gan.train_with_mog:
                mog_model.set_batch_encoding()

            global_step_value = None
            for step in range(max_step):

                # update MoG with current Dis params and current batch
                if mog_model is not None and mog_model.linked_gan.train_with_mog:
                    mog_model.update_by_batch(self.sess)

                    if mog_model.store_encodings:
                        if global_step_value is None:  # first iteration only
                            mog_model.store_encodings_and_params(self.sess, summary_folder, 0)

                        elif global_step_value % query_step == (query_step-1):
                            mog_model.store_encodings_and_params(self.sess, summary_folder, global_step_value)

                if not isinstance(loss_list, list):  # go from namedtuple to list
                    loss_list = list(loss_list)
                # update the model
                loss_value, _, _, global_step_value = self.sess.run(
                    [loss_list, op_list, extra_update_ops, global_step])
                # check if model produces nan outcome
                assert not any(np.isnan(loss_value)), \
                    'Model diverged with loss = {} at step {}'.format(loss_value, step)

                # maybe re-init mog after a few epochs, as it may have gotten lost given the rapid change of encodings
                if mog_model is not None and global_step_value == mog_model.re_init_at_step:
                    mog_model.init_np_mog()

                # add summary and print loss every query step
                if global_step_value % query_step == (query_step-1) or global_step_value == 1:
                    if mog_model is not None and mog_model.means_summary_op is not None and summary_op is not None:
                        summary_str, summary_str_means = self.sess.run([summary_op, mog_model.means_summary_op])
                        self.summary_writer.add_summary(summary_str, global_step=global_step_value)
                        self.summary_writer.add_summary(summary_str_means, global_step=global_step_value)
                    elif summary_op is not None:
                        summary_str = self.sess.run(summary_op)
                        self.summary_writer.add_summary(summary_str, global_step=global_step_value)
                    if print_loss:
                        epoch = step // step_per_epoch
                        self.print_loss(loss_value, global_step_value, epoch)

                # save model at last step
                if step == max_step - 1:
                    self.save_model(global_step_value, summary_image_op)

        elif isinstance(imbalanced_update, (list, tuple, NetPicker)):  # <-------------------- ALTERNATING TRAINING HERE

            for step in range(max_step):  # <------------------------------------------------------ ACTUAL TRAINING LOOP
                # get update ops
                global_step_value = self.sess.run(global_step)

                if False and mog_model is not None and mog_model.linked_gan.train_with_mog:
                    if mog_model.time_to_update(global_step_value, imbalanced_update):
                        mog_model.update(self.sess)
                # IF STEP VALUE INDICATES TRAINING GENERATOR:
                # - collect all data encodings
                # - update MoG parameters
                # - proceed with training, sampling from updated MoG

                # in other places:
                # - predefine MoG distribution
                # - redefine generator loss through samples from the MoG

                update_ops = select_ops_to_update(op_list, global_step_value, imbalanced_update)  # <------ OP SELECTION

                # update the model
                loss_value, _, _ = self.sess.run([loss_list, update_ops, extra_update_ops])  # <---------- WEIGHT UPDATE
                # check if model produces nan outcome
                assert not any(np.isnan(loss_value)), \
                    'Model diverged with loss = {} at step {}'.format(loss_value, step)

                # add summary and print loss every query step
                if global_step_value % query_step == (query_step - 1):
                    if summary_op is not None:
                        summary_str = self.sess.run(summary_op)
                        self.summary_writer.add_summary(summary_str, global_step=global_step_value)
                    if print_loss:
                        epoch = step // step_per_epoch
                        self.print_loss(loss_value, global_step_value, epoch)

                    # ------------------------------------------------------------ALSO TAKE MoG APPROXIMATION STATS HERE
                    if False and mog_model is not None and not mog_model.linked_gan.train_with_mog:
                        mog_model.test_mog_approx(self.sess)

                # save model at last step
                if step == max_step - 1:
                    self.save_model(global_step_value, summary_image_op)

        elif imbalanced_update == 'dynamic':
            # This case is used for sngan_mmd_rand_g only
            mmd_average = 0.0
            for step in range(max_step):
                # get update ops
                global_step_value = self.sess.run(global_step)
                update_ops = op_list if \
                    global_step_value < 1000 or \
                    np.random.uniform(low=0.0, high=1.0) < 0.1 / np.maximum(mmd_average, 0.1) else \
                    op_list[1:]

                # update the model
                loss_value, _, _ = self.sess.run([loss_list, update_ops, extra_update_ops])
                # check if model produces nan outcome
                assert not any(np.isnan(loss_value)), \
                    'Model diverged with loss = {} at step {}'.format(loss_value, step)

                # add summary and print loss every query step
                if global_step_value % query_step == (query_step - 1):
                    if summary_op is not None:
                        summary_str = self.sess.run(summary_op)
                        self.summary_writer.add_summary(summary_str, global_step=global_step_value)
                    if print_loss:
                        epoch = step // step_per_epoch
                        self.print_loss(loss_value, global_step_value, epoch)

                # save model at last step
                if step == max_step - 1:
                    self.save_model(global_step_value, summary_image_op)

        # calculate sess duration
        duration = time.time() - start_time
        FLAGS.print('Training for {} steps took {:.3f} sec.'.format(max_step, duration))
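
In the 'dynamic' branch above, all ops (including the first op) are updated during the first 1000 steps and afterwards with probability min(1, 0.1 / max(mmd_average, 0.1)); a numpy sketch of that schedule with made-up mmd values:

import numpy as np

def update_all_ops(global_step_value, mmd_average, rng):
    return (global_step_value < 1000 or
            rng.uniform(low=0.0, high=1.0) < 0.1 / np.maximum(mmd_average, 0.1))

rng = np.random.RandomState(0)
for mmd in [0.05, 0.2, 1.0]:
    hits = sum(update_all_ops(5000, mmd, rng) for _ in range(10000))
    print('mmd={:.2f}: first op updated in {:.1%} of steps'.format(mmd, hits / 10000))
# roughly 100%, 50% and 10% of the steps, respectively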
Example #15
def opt_config(initial_lr,
               lr_decay_steps=None,
               end_lr=1e-7,
               optimizer='adam',
               name_suffix='',
               global_step=None,
               target_step=1e5):
    """ This function configures optimizer.

    :param initial_lr:
    :param lr_decay_steps:
    :param end_lr:
    :param optimizer:
    :param name_suffix:
    :param global_step:
    :param target_step:
    :return:
    """

    if optimizer in ['SGD', 'sgd']:
        # sgd
        if lr_decay_steps is None:
            lr_decay_steps = np.round(target_step * np.log(0.96) /
                                      np.log(end_lr / initial_lr)).astype(
                                          np.int32)
        learning_rate = tf.train.exponential_decay(  # adaptive learning rate
            initial_lr,
            global_step=global_step,
            decay_steps=lr_decay_steps,
            decay_rate=0.96,
            staircase=False)
        opt_op = tf.train.GradientDescentOptimizer(learning_rate,
                                                   name='GradientDescent' +
                                                   name_suffix)
        FLAGS.print('GradientDescent Optimizer is used.')
    elif optimizer in ['Momentum', 'momentum']:
        # momentum
        if lr_decay_steps is None:
            lr_decay_steps = np.round(target_step * np.log(0.96) /
                                      np.log(end_lr / initial_lr)).astype(
                                          np.int32)
        learning_rate = tf.train.exponential_decay(  # adaptive learning rate
            initial_lr,
            global_step=global_step,
            decay_steps=lr_decay_steps,
            decay_rate=0.96,
            staircase=False)
        opt_op = tf.train.MomentumOptimizer(learning_rate,
                                            momentum=0.9,
                                            name='Momentum' + name_suffix)
        FLAGS.print('Momentum Optimizer is used.')
    elif optimizer in ['Adam', 'adam']:  # adam
        # Occasionally, adam optimizer may cause the objective to become nan in the first few steps
        # This is because at initialization, the gradients may be very big. Setting beta1 and beta2
        # close to 1 may prevent this.
        learning_rate = tf.constant(initial_lr)
        # opt_op = tf.train.AdamOptimizer(
        #     learning_rate, beta1=0.9, beta2=0.99, epsilon=1e-8, name='Adam'+name_suffix)
        opt_op = tf.compat.v1.train.AdamOptimizer(learning_rate,
                                                  beta1=0.5,
                                                  beta2=0.999,
                                                  epsilon=1e-8,
                                                  name='Adam' + name_suffix)
        FLAGS.print('Adam Optimizer is used.')

    elif optimizer in ['RMSProp', 'rmsprop']:
        # RMSProp
        learning_rate = tf.constant(initial_lr)
        opt_op = tf.train.RMSPropOptimizer(learning_rate,
                                           decay=0.9,
                                           momentum=0.0,
                                           epsilon=1e-10,
                                           name='RMSProp' + name_suffix)
        FLAGS.print('RMSProp Optimizer is used.')
    else:
        raise AttributeError('Optimizer {} not supported.'.format(optimizer))

    return learning_rate, opt_op
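
The lr_decay_steps formula in the SGD and Momentum branches is chosen so that the exponentially decayed learning rate reaches end_lr after roughly target_step steps; a quick numpy check with illustrative numbers:

import numpy as np

initial_lr, end_lr, target_step = 1e-3, 1e-7, 1e5
lr_decay_steps = np.round(target_step * np.log(0.96) /
                          np.log(end_lr / initial_lr)).astype(np.int32)
lr_at_target = initial_lr * 0.96 ** (target_step / lr_decay_steps)
print(lr_decay_steps, lr_at_target)  # 443, ~1e-7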
Example #16
    def inception_score_and_fid_v1(self,
                                   x_batch,
                                   y_batch,
                                   num_batch=10,
                                   ckpt_folder=None,
                                   ckpt_file=None):
        """ This function calculates inception scores and FID based on inception v1.
        Note: batch_size * num_batch needs to be larger than 2048, otherwise the covariance matrix will be
        ill-conditioned.

        According to TensorFlow v1.7 (code link below), this is actually the Inception v3 model,
        although the downloaded file says it is v1.
        code link: https://github.com/tensorflow/tensorflow/blob/r1.7/tensorflow/contrib \
        /gan/python/eval/python/classifier_metrics_impl.py

        Steps:
        1, the pool3 and logits are calculated for x_batch and y_batch with sess
        2, the pool3 and logits are passed to corresponding metrics

        :param x_batch: tensor, one batch of x in range [-1, 1]
        :param y_batch: tensor, one batch of y in range [-1, 1]
        :param num_batch:
        :param ckpt_folder: check point folder
        :param ckpt_file: in case an older ckpt file is needed, provide it here, e.g. 'cifar.ckpt-6284'
        :return:
        """
        assert self.model == 'v1', 'GenerativeModelMetric is not initialized with model="v1".'
        assert ckpt_folder is not None, 'ckpt_folder must be provided.'

        x_logits, x_pool3 = self.inception_v1(x_batch)
        y_logits, y_pool3 = self.inception_v1(y_batch)

        with MySession(load_ckpt=True) as sess:
            inception_outputs = sess.run_m_times(
                [x_logits, y_logits, x_pool3, y_pool3],
                ckpt_folder=ckpt_folder,
                ckpt_file=ckpt_file,
                max_iter=num_batch,
                trace=True)

        # get logits and pool3
        x_logits_np = np.concatenate([inc[0] for inc in inception_outputs],
                                     axis=0)
        y_logits_np = np.concatenate([inc[1] for inc in inception_outputs],
                                     axis=0)
        x_pool3_np = np.concatenate([inc[2] for inc in inception_outputs],
                                    axis=0)
        y_pool3_np = np.concatenate([inc[3] for inc in inception_outputs],
                                    axis=0)
        FLAGS.print('logits calculated. Shape = {}.'.format(x_logits_np.shape))
        FLAGS.print('pool3 calculated. Shape = {}.'.format(x_pool3_np.shape))
        # calculate scores
        inc_x = self.inception_score_from_logits(x_logits_np)
        inc_y = self.inception_score_from_logits(y_logits_np)
        xp3_1, xp3_2 = np.split(x_pool3_np, indices_or_sections=2, axis=0)
        fid_xx = self.fid_from_pool3(xp3_1, xp3_2)
        fid_xy = self.fid_from_pool3(x_pool3_np, y_pool3_np)

        with MySession() as sess:
            scores = sess.run_once([inc_x, inc_y, fid_xx, fid_xy])

        return scores
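
inception_score_from_logits is not shown in this listing; a numpy sketch of the standard inception-score formula (exp of the mean KL divergence between the per-sample class distribution and the marginal) that such a helper typically implements:

import numpy as np

def inception_score_from_logits_np(logits):
    """IS = exp( mean_x KL( p(y|x) || p(y) ) ), computed from raw logits."""
    logits = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    p_yx = np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True)
    p_y = p_yx.mean(axis=0, keepdims=True)
    kl = (p_yx * (np.log(p_yx) - np.log(p_y))).sum(axis=1)
    return np.exp(kl.mean())

# perfectly confident predictions spread evenly over 10 classes score ~10
logits = np.full((1000, 10), -50.0)
logits[np.arange(1000), np.arange(1000) % 10] = 50.0
print(inception_score_from_logits_np(logits))  # ~10.0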