def make_scaffold(stage_id, optimizer_var_list, **kwargs):
  """Makes a custom scaffold.

  The scaffold
  - restores variables from the last training stage.
  - initializes new variables in the new block.

  Restoring from the previous stage only happens when `stage_id > 0` and the
  current stage's training sub-directory has no checkpoint yet. Otherwise
  `restore_var_list` stays empty and no previous-stage checkpoint is restored
  by this scaffold's `init_fn`.

  Args:
    stage_id: An integer of stage id.
    optimizer_var_list: A list (or any iterable) of optimizer variables. These
      variables are excluded from restoration.
    **kwargs: A dictionary of
        'train_log_dir': A string of root directory of training logs.
        'num_resolutions': An integer of number of progressive resolutions.
        'stable_stage_num_images': An integer of number of training images in
          the stable stage.
        'transition_stage_num_images': An integer of number of training images
          in the transition stage.
        'total_num_images': An integer of total number of training images.

  Returns:
    A `Scaffold` object. Its `init_fn` first runs the global variables
    initializer. If a previous-stage checkpoint was found, it then restores
    every global variable that is neither an optimizer variable nor part of a
    block added in the current stage.
  """
  # Holds variables from the previous stage that need to be restored.
  restore_var_list = []
  prev_ckpt = None
  curr_ckpt = tf.train.latest_checkpoint(
      make_train_sub_dir(stage_id, **kwargs))
  if stage_id > 0 and curr_ckpt is None:
    prev_ckpt = tf.train.latest_checkpoint(
        make_train_sub_dir(stage_id - 1, **kwargs))

    num_blocks, _ = get_stage_info(stage_id, **kwargs)
    prev_num_blocks, _ = get_stage_info(stage_id - 1, **kwargs)

    # Holds variables created in the new block(s) of the current stage, i.e.
    # blocks prev_num_blocks + 1 .. num_blocks. If the current stage is a
    # stable stage (except the initial stage), this list will be empty.
    new_block_var_list = []
    for block_id in range(prev_num_blocks + 1, num_blocks + 1):
      new_block_var_list.extend(
          tf.compat.v1.get_collection(
              tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,
              scope='.*/{}/'.format(networks.block_name(block_id))))

    # Every variable that is 1) not for optimizers and 2) not from a new block
    # needs to be restored. Build the exclusion set once, instead of
    # rebuilding it for every variable inside the comprehension's filter.
    excluded_vars = set(optimizer_var_list).union(new_block_var_list)
    restore_var_list = [
        var for var in tf.compat.v1.get_collection(
            tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
        if var not in excluded_vars
    ]

  # Add saver op to graph. This saver is used to restore variables from the
  # previous stage. allow_empty=True keeps construction valid when there is
  # nothing to restore (e.g. stage 0, or resuming inside the current stage).
  saver_for_restore = tf.compat.v1.train.Saver(
      var_list=restore_var_list, allow_empty=True)
  # Add the op to graph that initializes all global variables.
  init_op = tf.compat.v1.global_variables_initializer()

  def _init_fn(unused_scaffold, sess):
    # First initialize every variable.
    sess.run(init_op)
    logging.info('\n'.join([var.name for var in restore_var_list]))
    # Then overwrite variables saved in the previous stage.
    if prev_ckpt is not None:
      saver_for_restore.restore(sess, prev_ckpt)

  # Use a dummy init_op here as all initialization is done in init_fn.
  return tf.compat.v1.train.Scaffold(
      init_op=tf.constant([]), init_fn=_init_fn)
def test_block_name(self):
  """Block names are the 'progressive_gan_block_' prefix plus the block id."""
  expected_name = 'progressive_gan_block_10'
  actual_name = networks.block_name(10)
  self.assertEqual(actual_name, expected_name)