Example #1
    def __init__(self,
                 dtype=DEFAULT_DTYPE,
                 use_xla=DEFAULT_USE_XLA,
                 data_format=DEFAULT_DATA_FORMAT,
                 name=None):
        super(Bias, self).__init__(dtype=dtype, name=name)
        validate_data_format(data_format)
        self.data_format = data_format
        self.use_xla = use_xla
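The constructor above only stores configuration. For context, a minimal sketch of what such a data_format-aware bias layer typically does (the constant value and broadcast shape below are assumptions, not taken from this repository):

import tensorflow as tf

NCHW_FORMAT = 'NCHW'  # assumed constant value

class SimpleBias(tf.keras.layers.Layer):
    """Minimal sketch of a data_format-aware bias layer."""

    def __init__(self, data_format=NCHW_FORMAT, **kwargs):
        super().__init__(**kwargs)
        self.data_format = data_format

    def build(self, input_shape):
        channel_axis = 1 if self.data_format == NCHW_FORMAT else -1
        self.bias = self.add_weight(name='bias',
                                    shape=(input_shape[channel_axis],),
                                    initializer='zeros')

    def call(self, x):
        if self.data_format == NCHW_FORMAT:
            # Reshape so the bias broadcasts over the channel axis.
            return x + tf.reshape(self.bias, [1, -1, 1, 1])
        return x + self.bias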
Example #2
    def __init__(self,
                 group_size,
                 dtype=DEFAULT_DTYPE,
                 use_xla=DEFAULT_USE_XLA,
                 data_format=DEFAULT_DATA_FORMAT,
                 name=None):
        super(Minibatch_StdDev, self).__init__(dtype=dtype, name=name)
        validate_data_format(data_format)
        self.data_format = data_format
        self.group_size = group_size
        self.use_xla = use_xla
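The layer's call method is not shown; a hedged sketch of the group-wise standard-deviation feature it computes in ProGAN (NCHW layout assumed, mirroring the original NVIDIA implementation):

import tensorflow as tf

def minibatch_stddev(x, group_size=4):
    # Append one feature map holding the average per-group standard
    # deviation, so the discriminator sees minibatch statistics.
    _, c, h, w = x.shape
    g = tf.minimum(group_size, tf.shape(x)[0])   # group size must divide N
    y = tf.reshape(x, [g, -1, c, h, w])          # [G, M, C, H, W]
    y -= tf.reduce_mean(y, axis=0, keepdims=True)
    y = tf.sqrt(tf.reduce_mean(tf.square(y), axis=0) + 1e-8)  # std over group
    y = tf.reduce_mean(y, axis=[1, 2, 3], keepdims=True)      # [M, 1, 1, 1]
    y = tf.tile(y, [g, 1, h, w])                 # back to [N, 1, H, W]
    return tf.concat([x, y], axis=1)             # one extra feature map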
Example #3
def load_and_preprocess_image(file_path, res,
                              mirror_augment=DEFAULT_MIRROR_AUGMENT,
                              dtype=DEFAULT_DTYPE,
                              data_format=DEFAULT_DATA_FORMAT):
    validate_data_format(data_format)
    image = load_image(file_path, dtype)
    image = preprocess_image(image, res, mirror_augment, dtype)
    image = normalize_images(image)
    if data_format == NCHW_FORMAT:
        image = tf.transpose(image, [2, 0, 1])
    return image
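A hedged sketch of how this function would typically feed a tf.data pipeline (the path, resolution, and batch size are illustrative; whether res is the resolution itself or its log2 depends on preprocess_image, which is not shown):

import tensorflow as tf

files = tf.data.Dataset.list_files('datasets/faces/*.png', shuffle=True)  # hypothetical path
res = 128  # assumed: target resolution for the current stage

dataset = (files
           .map(lambda p: load_and_preprocess_image(p, res),
                num_parallel_calls=tf.data.AUTOTUNE)
           .batch(16)
           .prefetch(tf.data.AUTOTUNE))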
Example #4
def convert_outputs_to_images(net_outputs, target_single_image_size, data_format=DEFAULT_DATA_FORMAT):
    # Note: should work for linear and tanh activations
    validate_data_format(data_format)
    x = restore_images(net_outputs)
    if data_format == NCHW_FORMAT:
        x = tf.transpose(x, [0, 2, 3, 1])
    x = tf.image.resize(
        x,
        size=(target_single_image_size, target_single_image_size),
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR
    )
    return x
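A hedged usage sketch (the generator call and sizes are illustrative; the output range of restore_images is not shown, so the clipping below is an assumption):

import tensorflow as tf

fakes = generator(latents, training=False)  # hypothetical generator call
grid = convert_outputs_to_images(fakes, target_single_image_size=128)

# The result is NHWC after the optional transpose; cast for saving
# or for TensorBoard image summaries.
grid_uint8 = tf.cast(tf.clip_by_value(grid, 0.0, 255.0), tf.uint8)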
Example #5
    def __init__(self,
                 dtype=DEFAULT_DTYPE,
                 use_xla=DEFAULT_USE_XLA,
                 data_format=DEFAULT_DATA_FORMAT,
                 name=None):
        super(Pixel_Norm, self).__init__(dtype=dtype, name=name)
        validate_data_format(data_format)
        self.data_format = data_format
        if self.data_format == NCHW_FORMAT:
            self.channel_axis = 1
        elif self.data_format == NHWC_FORMAT:
            self.channel_axis = 3
        self.epsilon = 1e-8 if self._dtype_policy.compute_dtype == 'float32' else 1e-4
        self.use_xla = use_xla
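The stored channel_axis and epsilon point at ProGAN's pixelwise feature normalization; a minimal sketch of the corresponding call (the layer's actual call method is not shown):

import tensorflow as tf

def pixel_norm(x, channel_axis=1, epsilon=1e-8):
    # Scale every spatial position to unit average square over channels.
    return x * tf.math.rsqrt(
        tf.reduce_mean(tf.square(x), axis=channel_axis, keepdims=True)
        + epsilon)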
Example #6
    def __init__(self,
                 factor,
                 dtype=DEFAULT_DTYPE,
                 use_xla=DEFAULT_USE_XLA,
                 data_format=DEFAULT_DATA_FORMAT,
                 name=None):
        super(Downscale2d, self).__init__(dtype=dtype, name=name)
        validate_data_format(data_format)
        self.data_format = data_format
        if self.data_format == NCHW_FORMAT:
            self.ksize = [1, 1, factor, factor]
        elif self.data_format == NHWC_FORMAT:
            self.ksize = [1, factor, factor, 1]
        self.factor = factor
        self.use_xla = use_xla
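A ksize of [1, 1, factor, factor] (NCHW) or [1, factor, factor, 1] (NHWC) is exactly the window shape tf.nn.avg_pool2d expects, so the call likely reduces to average pooling; a sketch under that assumption:

import tensorflow as tf

def downscale2d(x, ksize, data_format='NCHW'):
    # Non-overlapping average pooling: strides equal the window size.
    return tf.nn.avg_pool2d(x, ksize=ksize, strides=ksize,
                            padding='VALID', data_format=data_format)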
Example #7
    def __init__(self,
                 units,
                 gain=HE_GAIN,
                 use_wscale=True,
                 truncate_weights=DEFAULT_TRUNCATE_WEIGHTS,
                 dtype=DEFAULT_DTYPE,
                 use_xla=DEFAULT_USE_XLA,
                 data_format=DEFAULT_DATA_FORMAT,
                 name=None):
        super(Scaled_Linear, self).__init__(dtype=dtype, name=name)
        validate_data_format(data_format)
        self.data_format = data_format
        self.units = units
        self.gain = gain
        self.use_wscale = use_wscale
        self.truncate_weights = truncate_weights
        self.use_xla = use_xla
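The gain/use_wscale pair is ProGAN's equalized-learning-rate trick: kernels are initialized from N(0, 1) and multiplied by a He-style constant at runtime, instead of baking the scale into the initializer. A sketch (HE_GAIN = sqrt(2) is an assumption):

import numpy as np
import tensorflow as tf

def weight_scale(fan_in, gain=np.sqrt(2.0), use_wscale=True):
    # Returns (init_stddev, runtime_multiplier).
    he_std = gain / np.sqrt(fan_in)
    return (1.0, he_std) if use_wscale else (he_std, 1.0)

init_std, coef = weight_scale(fan_in=512)
w = tf.random.normal([512, 256], stddev=init_std)  # dense kernel, N(0, 1)
# Multiply w by coef on every forward pass: same output scale as He init,
# but Adam sees identically-scaled weights across all layers.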
Example #8
    def __init__(self,
                 fmaps,
                 kernel_size,
                 gain=HE_GAIN,
                 use_wscale=True,
                 truncate_weights=DEFAULT_TRUNCATE_WEIGHTS,
                 use_xla=DEFAULT_USE_XLA,
                 dtype=DEFAULT_DTYPE,
                 data_format=DEFAULT_DATA_FORMAT,
                 name=None):
        super(Fused_Scaled_Conv2d_Downscale2d, self).__init__(dtype=dtype,
                                                              name=name)
        assert kernel_size >= 1 and kernel_size % 2 == 1
        validate_data_format(data_format)
        self.data_format = data_format
        self.fmaps = fmaps
        self.kernel_size = kernel_size
        self.gain = gain
        self.use_wscale = use_wscale
        self.truncate_weights = truncate_weights
        self.use_xla = use_xla
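In the original ProGAN, a fused layer like this replaces conv2d followed by a 2x downscale with a single stride-2 convolution over a blurred kernel; a sketch of that trick (NCHW layout and a [kh, kw, in, out] kernel are assumed):

import tensorflow as tf

def fused_conv2d_downscale2d(x, w):
    # Pad the kernel and average four shifted copies, then convolve with
    # stride 2: equivalent to conv2d followed by 2x average pooling.
    w = tf.pad(w, [[1, 1], [1, 1], [0, 0], [0, 0]])
    w = tf.add_n([w[1:, 1:], w[:-1, 1:], w[1:, :-1], w[:-1, :-1]]) * 0.25
    return tf.nn.conv2d(x, w, strides=[1, 1, 2, 2], padding='SAME',
                        data_format='NCHW')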
Example #9
    def __init__(self,
                 fmaps,
                 kernel_size,
                 stride=1,
                 gain=HE_GAIN,
                 use_wscale=True,
                 truncate_weights=DEFAULT_TRUNCATE_WEIGHTS,
                 dtype=DEFAULT_DTYPE,
                 use_xla=DEFAULT_USE_XLA,
                 data_format=DEFAULT_DATA_FORMAT,
                 name=None):
        super(Scaled_Conv2d, self).__init__(dtype=dtype, name=name)
        validate_data_format(data_format)
        self.data_format = data_format
        self.fmaps = fmaps
        self.kernel_size = kernel_size
        self.stride = stride
        self.gain = gain
        self.use_wscale = use_wscale
        self.truncate_weights = truncate_weights
        self.use_xla = use_xla
Example #10
    def __init__(self, config):

        self.target_resolution = config[TARGET_RESOLUTION]
        self.resolution_log2 = int(np.log2(self.target_resolution))
        assert self.target_resolution == 2**self.resolution_log2 and self.target_resolution >= 4

        self.start_resolution = config.get(START_RESOLUTION,
                                           DEFAULT_START_RESOLUTION)
        self.start_resolution_log2 = int(np.log2(self.start_resolution))
        assert self.start_resolution == 2**self.start_resolution_log2 and self.start_resolution >= 4

        self.min_lod = level_of_details(self.resolution_log2,
                                        self.resolution_log2)
        self.max_lod = level_of_details(self.start_resolution_log2,
                                        self.resolution_log2)

        self.data_format = config.get(DATA_FORMAT, DEFAULT_DATA_FORMAT)
        validate_data_format(self.data_format)

        self.latent_size = config[LATENT_SIZE]
        self.normalize_latents = config.get(NORMALIZE_LATENTS,
                                            DEFAULT_NORMALIZE_LATENTS)
        self.use_bias = config.get(USE_BIAS, DEFAULT_USE_BIAS)
        self.use_wscale = config.get(USE_WSCALE, DEFAULT_USE_WSCALE)
        self.use_pixelnorm = config.get(USE_PIXELNORM, DEFAULT_USE_PIXELNORM)
        self.override_projecting_gain = config.get(
            OVERRIDE_G_PROJECTING_GAIN, DEFAULT_OVERRIDE_G_PROJECTING_GAIN)
        self.projecting_gain_correction = 4. if self.override_projecting_gain else 1.
        self.truncate_weights = config.get(TRUNCATE_WEIGHTS,
                                           DEFAULT_TRUNCATE_WEIGHTS)
        self.G_fused_scale = config.get(G_FUSED_SCALE, DEFAULT_G_FUSED_SCALE)
        self.G_kernel_size = config.get(G_KERNEL_SIZE, DEFAULT_G_KERNEL_SIZE)
        self.G_fmap_base = config.get(G_FMAP_BASE, DEFAULT_FMAP_BASE)
        self.G_fmap_decay = config.get(G_FMAP_DECAY, DEFAULT_FMAP_DECAY)
        self.G_fmap_max = config.get(G_FMAP_MAX, DEFAULT_FMAP_MAX)
        self.batch_sizes = config[BATCH_SIZES]

        self.G_act_name = config.get(G_ACTIVATION, DEFAULT_G_ACTIVATION)
        if self.G_act_name in ACTIVATION_FUNS_DICT:
            self.G_act = ACTIVATION_FUNS_DICT[self.G_act_name]
        else:
            assert False, f"Generator activation '{self.G_act_name}' not supported. See ACTIVATION_FUNS_DICT"

        self.use_mixed_precision = config.get(USE_MIXED_PRECISION,
                                              DEFAULT_USE_MIXED_PRECISION)
        if self.use_mixed_precision:
            self.policy = mixed_precision.Policy('mixed_float16')
            self.act_dtype = 'float32' if self.G_act_name in FP32_ACTIVATIONS else self.policy
            self.compute_dtype = self.policy.compute_dtype
        else:
            self.policy = 'float32'
            self.act_dtype = 'float32'
            self.compute_dtype = 'float32'

        self.weights_init_mode = config.get(G_WEIGHTS_INIT_MODE, None)
        if self.weights_init_mode is None:
            self.gain = GAIN_ACTIVATION_FUNS_DICT[self.G_act_name]
        else:
            self.gain = GAIN_INIT_MODE_DICT[self.weights_init_mode]

        # This constant is taken from the original implementation
        self.projecting_mult = 4
        if self.data_format == NCHW_FORMAT:
            self.z_dim = (self.latent_size, 1, 1)
            self.projecting_target_shape = (self.G_n_filters(2 - 1),
                                            self.projecting_mult,
                                            self.projecting_mult)
        elif self.data_format == NHWC_FORMAT:
            self.z_dim = (1, 1, self.latent_size)
            self.projecting_target_shape = (self.projecting_mult,
                                            self.projecting_mult,
                                            self.G_n_filters(2 - 1))

        self.create_model_layers()
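G_n_filters is not shown in this snippet; in the original ProGAN the per-stage filter count is nf(stage) = min(int(fmap_base / 2**(stage * fmap_decay)), fmap_max), which matches the three G_fmap_* settings read above. A sketch with illustrative defaults:

def n_filters(stage, fmap_base=8192, fmap_decay=1.0, fmap_max=512):
    # Filter counts halve as resolution grows, capped at fmap_max.
    return min(int(fmap_base / (2.0 ** (stage * fmap_decay))), fmap_max)

print([n_filters(s) for s in range(1, 9)])
# [512, 512, 512, 512, 256, 128, 64, 32]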
Example #11
    def __init__(self, config):

        self.target_resolution = config[TARGET_RESOLUTION]
        self.resolution_log2 = int(np.log2(self.target_resolution))
        assert self.target_resolution == 2**self.resolution_log2 and self.target_resolution >= 4

        self.start_resolution = config.get(START_RESOLUTION,
                                           DEFAULT_START_RESOLUTION)
        self.start_resolution_log2 = int(np.log2(self.start_resolution))
        assert self.start_resolution == 2**self.start_resolution_log2 and self.start_resolution >= 4

        self.min_lod = level_of_details(self.resolution_log2,
                                        self.resolution_log2)
        self.max_lod = level_of_details(self.start_resolution_log2,
                                        self.resolution_log2)

        self.data_format = config.get(DATA_FORMAT, DEFAULT_DATA_FORMAT)
        validate_data_format(self.data_format)

        self.use_bias = config.get(USE_BIAS, DEFAULT_USE_BIAS)
        self.use_wscale = config.get(USE_WSCALE, DEFAULT_USE_WSCALE)
        self.truncate_weights = config.get(TRUNCATE_WEIGHTS,
                                           DEFAULT_TRUNCATE_WEIGHTS)
        self.mbstd_group_size = config[MBSTD_GROUP_SIZE]
        self.D_fused_scale = config.get(D_FUSED_SCALE, DEFAULT_D_FUSED_SCALE)
        self.D_kernel_size = config.get(D_KERNEL_SIZE, DEFAULT_D_KERNEL_SIZE)
        self.D_fmap_base = config.get(D_FMAP_BASE, DEFAULT_FMAP_BASE)
        self.D_fmap_decay = config.get(D_FMAP_DECAY, DEFAULT_FMAP_DECAY)
        self.D_fmap_max = config.get(D_FMAP_MAX, DEFAULT_FMAP_MAX)
        self.batch_sizes = config[BATCH_SIZES]

        self.D_act_name = config.get(D_ACTIVATION, DEFAULT_D_ACTIVATION)
        if self.D_act_name in ACTIVATION_FUNS_DICT:
            self.D_act = ACTIVATION_FUNS_DICT[self.D_act_name]
        else:
            assert False, f"Discriminator activation '{self.D_act_name}' not supported. See ACTIVATION_FUNS_DICT"

        self.use_mixed_precision = config.get(USE_MIXED_PRECISION,
                                              DEFAULT_USE_MIXED_PRECISION)
        if self.use_mixed_precision:
            self.policy = mixed_precision.Policy('mixed_float16')
            self.act_dtype = 'float32' if self.D_act_name in FP32_ACTIVATIONS else self.policy
            self.compute_dtype = self.policy.compute_dtype
        else:
            self.policy = 'float32'
            self.act_dtype = 'float32'
            self.compute_dtype = 'float32'

        self.weights_init_mode = config.get(D_WEIGHTS_INIT_MODE, None)
        if self.weights_init_mode is None:
            self.gain = GAIN_ACTIVATION_FUNS_DICT[self.D_act_name]
        else:
            self.gain = GAIN_INIT_MODE_DICT[self.weights_init_mode]

        # Might be useful to override the number of units in the projecting
        # layer when the latent size is not 512, so that models keep almost
        # the same number of trainable parameters
        self.projecting_nf = config.get(D_PROJECTING_NF,
                                        self.D_n_filters(2 - 2))

        self.create_model_layers()
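level_of_details is not defined in these snippets. A definition consistent with how min_lod and max_lod are computed above (an assumption, not the repository's code) would be:

def level_of_details(res_log2, resolution_log2):
    # Assumed: lod 0 at the final resolution, +1 per halving, so that
    # min_lod == 0 and max_lod == resolution_log2 - start_resolution_log2.
    return resolution_log2 - res_log2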
Example #12
    def __init__(self,
                 config,
                 mode=DEFAULT_MODE,
                 images_paths=None,
                 res=None,
                 stage=None,
                 single_process_training=False):

        self.target_resolution = config[TARGET_RESOLUTION]
        self.resolution_log2 = int(np.log2(self.target_resolution))
        assert self.target_resolution == 2**self.resolution_log2 and self.target_resolution >= 4

        self.start_resolution = config.get(START_RESOLUTION,
                                           DEFAULT_START_RESOLUTION)
        self.start_resolution_log2 = int(np.log2(self.start_resolution))
        assert self.start_resolution == 2**self.start_resolution_log2 and self.start_resolution >= 4

        self.min_lod = level_of_details(self.resolution_log2,
                                        self.resolution_log2)
        self.max_lod = level_of_details(self.start_resolution_log2,
                                        self.resolution_log2)

        self.data_format = config.get(DATA_FORMAT, DEFAULT_DATA_FORMAT)
        validate_data_format(self.data_format)

        self.latent_size = config[LATENT_SIZE]
        if self.data_format == NCHW_FORMAT:
            self.z_dim = (self.latent_size, 1, 1)
        elif self.data_format == NHWC_FORMAT:
            self.z_dim = (1, 1, self.latent_size)

        self.model_name = config[MODEL_NAME]
        self.storage_path = config.get(STORAGE_PATH, DEFAULT_STORAGE_PATH)
        self.max_models_to_keep = config.get(MAX_MODELS_TO_KEEP,
                                             DEFAULT_MAX_MODELS_TO_KEEP)
        self.summary_every = config.get(SUMMARY_EVERY, DEFAULT_SUMMARY_EVERY)
        self.save_model_every = config.get(SAVE_MODEL_EVERY,
                                           DEFAULT_SAVE_MODEL_EVERY)
        self.save_images_every = config.get(SAVE_IMAGES_EVERY,
                                            DEFAULT_SAVE_IMAGES_EVERY)
        self.batch_sizes = config[BATCH_SIZES]
        self.fade_in_images = config[FADE_IN_IMAGES]
        self.stabilization_images = config[STABILIZATION_IMAGES]
        self.use_mixed_precision = config.get(USE_MIXED_PRECISION,
                                              DEFAULT_USE_MIXED_PRECISION)
        self.compute_dtype = 'float16' if self.use_mixed_precision else 'float32'
        self.use_Gs = config.get(USE_G_SMOOTHING, DEFAULT_USE_G_SMOOTHING)
        # Even in mixed-precision training, weights are stored in fp32
        self.smoothed_beta = tf.constant(config.get(G_SMOOTHING_BETA,
                                                    DEFAULT_G_SMOOTHING_BETA),
                                         dtype='float32')
        self.use_gpu_for_Gs = config.get(USE_GPU_FOR_GS,
                                         DEFAULT_USE_GPU_FOR_GS)
        self.shuffle_dataset = config.get(SHUFFLE_DATASET,
                                          DEFAULT_SHUFFLE_DATASET)
        self.dataset_n_parallel_calls = config.get(
            DATASET_N_PARALLEL_CALLS, DEFAULT_DATASET_N_PARALLEL_CALLS)
        self.dataset_n_prefetched_batches = config.get(
            DATASET_N_PREFETCHED_BATCHES, DEFAULT_DATASET_N_PREFETCHED_BATCHES)
        self.dataset_n_max_images = config.get(DATASET_N_MAX_IMAGES,
                                               DEFAULT_DATASET_N_MAX_IMAGES)
        self.dataset_max_cache_res = config.get(DATASET_MAX_CACHE_RES,
                                                DEFAULT_DATASET_MAX_CACHE_RES)
        self.mirror_augment = config.get(MIRROR_AUGMENT,
                                         DEFAULT_MIRROR_AUGMENT)

        self.G_learning_rate = config.get(G_LEARNING_RATE,
                                          DEFAULT_G_LEARNING_RATE)
        self.D_learning_rate = config.get(D_LEARNING_RATE,
                                          DEFAULT_D_LEARNING_RATE)
        self.G_learning_rate_dict = config.get(G_LEARNING_RATE_DICT,
                                               DEFAULT_G_LEARNING_RATE_DICT)
        self.D_learning_rate_dict = config.get(D_LEARNING_RATE_DICT,
                                               DEFAULT_D_LEARNING_RATE_DICT)
        self.beta1 = config.get(ADAM_BETA1, DEFAULT_ADAM_BETA1)
        self.beta2 = config.get(ADAM_BETA2, DEFAULT_ADAM_BETA2)
        self.reset_opt_state_for_new_lod = config.get(
            RESET_OPT_STATE_FOR_NEW_LOD, DEFAULT_RESET_OPT_STATE_FOR_NEW_LOD)

        self.valid_grid_nrows = 10
        self.valid_grid_ncols = 10
        self.valid_grid_padding = 2
        self.min_target_single_image_size = 2**7
        self.max_png_res = 5
        self.valid_latents = self.generate_latents(self.valid_grid_nrows *
                                                   self.valid_grid_ncols)

        self.logs_path = os.path.join(TF_LOGS_DIR, self.model_name)

        self.writers_dirs = {
            res: os.path.join(self.logs_path, '%dx%d' % (2**res, 2**res))
            for res in range(self.start_resolution_log2,
                             self.resolution_log2 + 1)
        }

        self.summary_writers = {
            res: tf.summary.create_file_writer(self.writers_dirs[res])
            for res in range(self.start_resolution_log2,
                             self.resolution_log2 + 1)
        }

        self.validate_config()

        self.clear_session_for_new_model = True

        self.G_object = Generator(config)
        self.D_object = Discriminator(config)

        if self.use_Gs:
            Gs_config = config
            self.Gs_valid_latents = self.valid_latents
            self.Gs_device = '/GPU:0' if self.use_gpu_for_Gs else '/CPU:0'

            if not self.use_gpu_for_Gs:
                Gs_config[DATA_FORMAT] = NHWC_FORMAT
                # NCHW -> NHWC
                self.toNHWC_axis = [0, 2, 3, 1]
                # NCHW -> NHWC -> NCHW
                self.toNCHW_axis = [0, 3, 1, 2]
                self.Gs_valid_latents = tf.transpose(self.valid_latents,
                                                     self.toNHWC_axis)

            self.Gs_object = Generator(Gs_config)

        if mode == INFERENCE_MODE:
            self.initialize_models()
        elif mode == TRAIN_MODE:
            if single_process_training:
                self.initialize_models()
                self.create_images_generators(config)
                self.initialize_optimizers(create_all_variables=True)
            else:
                self.initialize_models(res, stage)
                self.create_images_generator(config, res, images_paths)
                self.initialize_optimizers(create_all_variables=False)
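use_Gs and smoothed_beta above implement ProGAN's exponentially averaged ("smoothed") generator. A hedged sketch of the update that typically runs after each training step (the repository's own update method is not shown):

import tensorflow as tf

def update_smoothed_generator(Gs_model, G_model, beta):
    # Gs <- beta * Gs + (1 - beta) * G, variable by variable; weights stay
    # fp32 even under mixed precision, matching smoothed_beta's dtype.
    for gs_var, g_var in zip(Gs_model.trainable_variables,
                             G_model.trainable_variables):
        gs_var.assign(beta * gs_var + (1.0 - beta) * g_var)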