Code example #1
File: train.py Project: sts-sadr/gan-2
# Assumed module-level imports for this excerpt; make_train_sub_dir,
# get_stage_info, and networks are project-local helpers defined elsewhere
# in this repository.
import tensorflow as tf
from absl import logging


def make_scaffold(stage_id, optimizer_var_list, **kwargs):
    """Makes a custom scaffold.

  The scaffold
  - restores variables from the last training stage.
  - initializes new variables in the new block.

  Args:
    stage_id: An integer of stage id.
    optimizer_var_list: A list of optimizer variables.
    **kwargs: A dictionary of
        'train_log_dir': A string of root directory of training logs.
        'num_resolutions': An integer of number of progressive resolutions.
        'stable_stage_num_images': An integer of number of training images in
          the stable stage.
        'transition_stage_num_images': An integer of number of training images
          in the transition stage.
        'total_num_images': An integer of total number of training images.

  Returns:
    A `Scaffold` object.
  """
    # Holds variables that come from the previous stage and need to be restored.
    restore_var_list = []
    prev_ckpt = None
    curr_ckpt = tf.train.latest_checkpoint(
        make_train_sub_dir(stage_id, **kwargs))
    if stage_id > 0 and curr_ckpt is None:
        prev_ckpt = tf.train.latest_checkpoint(
            make_train_sub_dir(stage_id - 1, **kwargs))

        num_blocks, _ = get_stage_info(stage_id, **kwargs)
        prev_num_blocks, _ = get_stage_info(stage_id - 1, **kwargs)

        # Holds variables created in the new block of the current stage. If the
        # current stage is a stable stage (except the initial stage), this list
        # will be empty.
        new_block_var_list = []
        for block_id in range(prev_num_blocks + 1, num_blocks + 1):
            new_block_var_list.extend(
                tf.compat.v1.get_collection(
                    tf.compat.v1.GraphKeys.GLOBAL_VARIABLES,
                    scope='.*/{}/'.format(networks.block_name(block_id))))

        # Every variable that is 1) not an optimizer variable and 2) not
        # created in the new block needs to be restored.
        restore_var_list = [
            var for var in tf.compat.v1.get_collection(
                tf.compat.v1.GraphKeys.GLOBAL_VARIABLES)
            if var not in set(optimizer_var_list + new_block_var_list)
        ]

    # Add saver op to graph. This saver is used to restore variables from the
    # previous stage.
    saver_for_restore = tf.compat.v1.train.Saver(var_list=restore_var_list,
                                                 allow_empty=True)
    # Add the op to graph that initializes all global variables.
    init_op = tf.compat.v1.global_variables_initializer()

    def _init_fn(unused_scaffold, sess):
        # First initialize all variables.
        sess.run(init_op)
        logging.info('\n'.join([var.name for var in restore_var_list]))
        # Then overwrite variables saved in the previous stage.
        if prev_ckpt is not None:
            saver_for_restore.restore(sess, prev_ckpt)

    # Use a dummy init_op here as all initialization is done in init_fn.
    return tf.compat.v1.train.Scaffold(init_op=tf.constant([]),
                                       init_fn=_init_fn)
Code example #2
File: networks_test.py Project: sts-sadr/gan-2
  def test_block_name(self):
    self.assertEqual(networks.block_name(10), 'progressive_gan_block_10')
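
The test pins down the block naming convention that make_scaffold relies on when collecting new-block variables by scope. A minimal implementation consistent with this test, written here as an illustrative sketch rather than the project's actual networks.py, would be:

def block_name(block_id):
    """Returns the scope name for a progressive GAN block (illustrative sketch)."""
    return 'progressive_gan_block_{}'.format(block_id)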