コード例 #1
0
 def create_model_layers(self):
     """Create all generator layers shared by the per-resolution models.

     Sets, in this order:
       * ``self.up_layers``: res -> ``Upscale2d`` (factor 2).
       * ``self.G_wsum_layers``: lod -> ``Weighted_sum``.
       * ``self.G_input_layer``: latent ``Input`` of shape ``self.z_dim``.
       * ``self.G_latents_normalizer``: ``Pixel_Norm`` for the latents.
       * ``self.G_blocks``: res -> ``self.G_block(res)``.
       * ``self.toRGB_layers``: lod -> ``self.to_rgb(res)``.

     ``res`` ranges over ``2 .. self.resolution_log2`` and ``lod`` over
     ``self.min_lod .. self.max_lod``. Layers are created in a fixed order
     (Keras may derive default layer names from creation order), and each
     dict is attached to ``self`` only after it is fully built.
     """
     resolutions = range(2, self.resolution_log2 + 1)

     upscalers = {}
     for res in resolutions:
         side = 2**res
         upscalers[res] = Upscale2d(factor=2,
                                    dtype=self.policy,
                                    data_format=self.data_format,
                                    name='Upscale2D_%dx%d' % (side, side))
     self.up_layers = upscalers

     wsums = {}
     for lod in range(self.min_lod, self.max_lod + 1):
         wsums[lod] = Weighted_sum(dtype=self.policy,
                                   name=f'G_{WSUM_NAME}_{LOD_NAME}{lod}')
     self.G_wsum_layers = wsums

     self.G_input_layer = Input(shape=self.z_dim,
                                name='Latent_vector',
                                dtype=self.compute_dtype)
     self.G_latents_normalizer = Pixel_Norm(dtype=self.policy,
                                            data_format=self.data_format,
                                            name='Latents_normalizer')

     blocks = {}
     for res in resolutions:
         blocks[res] = self.G_block(res)
     self.G_blocks = blocks

     to_rgb_layers = {}
     for res in resolutions:
         lod = level_of_details(res, self.resolution_log2)
         to_rgb_layers[lod] = self.to_rgb(res)
     self.toRGB_layers = to_rgb_layers
コード例 #2
0
 def create_model_layers(self):
     """Create all discriminator layers shared by the per-resolution models.

     Sets, in this order:
       * ``self.down_layers``: res -> ``Downscale2d`` (factor 2).
       * ``self.D_wsum_layers``: lod -> ``Weighted_sum``.
       * ``self.D_input_layers``: res -> image ``Input`` of shape
         ``self.D_input_shape(res)``.
       * ``self.D_blocks``: res -> ``self.D_block(res)``.
       * ``self.fromRGB_layers``: lod -> ``self.from_rgb(res)``.

     ``res`` ranges over ``2 .. self.resolution_log2`` and ``lod`` over
     ``self.min_lod .. self.max_lod``. Layers are created in a fixed order
     (Keras may derive default layer names from creation order), and each
     dict is attached to ``self`` only after it is fully built.
     """
     resolutions = range(2, self.resolution_log2 + 1)

     downscalers = {}
     for res in resolutions:
         side = 2**res
         downscalers[res] = Downscale2d(factor=2,
                                        dtype=self.policy,
                                        data_format=self.data_format,
                                        name='Downscale2D_%dx%d' % (side, side))
     self.down_layers = downscalers

     wsums = {}
     for lod in range(self.min_lod, self.max_lod + 1):
         wsums[lod] = Weighted_sum(dtype=self.policy,
                                   name=f'D_{WSUM_NAME}_{LOD_NAME}{lod}')
     self.D_wsum_layers = wsums

     image_inputs = {}
     for res in resolutions:
         side = 2**res
         image_inputs[res] = Input(shape=self.D_input_shape(res),
                                   dtype=self.compute_dtype,
                                   name='Images_%dx%d' % (side, side))
     self.D_input_layers = image_inputs

     blocks = {}
     for res in resolutions:
         blocks[res] = self.D_block(res)
     self.D_blocks = blocks

     from_rgb_layers = {}
     for res in resolutions:
         lod = level_of_details(res, self.resolution_log2)
         from_rgb_layers[lod] = self.from_rgb(res)
     self.fromRGB_layers = from_rgb_layers
コード例 #3
0
    def initialize_Gs_model(self, model_res=None, mode=None):
        """Build the smoothed generator (Gs) and copy the current G weights into it.

        Args:
            model_res: log2 resolution to build. If ``None``, Gs is built with
                ``initialize_G_model()``. The weights of every G block (res
                ``2 .. self.resolution_log2``) and every toRGB layer (lod
                ``self.min_lod .. self.max_lod``) are then copied. Otherwise
                Gs is built with ``initialize_G_model(model_res=..., mode=...)``,
                and blocks ``2 .. model_res`` plus the toRGB layer of
                ``model_res`` are copied.
            mode: forwarded together with ``model_res``. In ``FADE_IN_MODE``
                the toRGB layer of the next lod (one resolution lower) is
                copied as well.
        """
        start_time = time.time()
        print('Initializing smoothed generator...')

        full_model = model_res is None
        if full_model:
            self.Gs_object.initialize_G_model()
            last_res = self.resolution_log2
        else:
            self.Gs_object.initialize_G_model(model_res=model_res, mode=mode)
            last_res = model_res

        src, dst = self.G_object, self.Gs_object

        for res in range(2, last_res + 1):
            dst.G_blocks[res].set_weights(src.G_blocks[res].get_weights())

        if full_model:
            lods = range(self.min_lod, self.max_lod + 1)
        else:
            first_lod = level_of_details(model_res, self.resolution_log2)
            # Fade-in also uses the toRGB of lod + 1 (the lower resolution)
            n_lods = 2 if mode == FADE_IN_MODE else 1
            lods = range(first_lod, first_lod + n_lods)

        for lod in lods:
            dst.toRGB_layers[lod].set_weights(
                src.toRGB_layers[lod].get_weights())

        total_time = time.time() - start_time
        logging.info(
            f'Smoothed generator initialized in {total_time:.3f} seconds!')
コード例 #4
0
 def from_rgb(self, res):
     """Create the fromRGB block for images at resolution ``res``.

     The block is a 1x1 convolution producing ``self.D_n_filters(res - 1)``
     feature maps, followed by ``self.act()``. When ``self.use_bias`` is set,
     the layer list is passed through ``self.apply_bias``.

     Returns:
         ``tf.keras.Sequential`` named ``From<RGB_NAME>_<LOD_NAME><lod>``.
     """
     lod = level_of_details(res, self.resolution_log2)
     n_fmaps = self.D_n_filters(res - 1)
     layers = [self.conv2d(fmaps=n_fmaps, kernel_size=1), self.act()]
     # NOTE(review): apply_bias receives [conv, act]. If it appends the bias
     # at the end of the list, the bias would follow the activation; confirm.
     if self.use_bias:
         layers = self.apply_bias(layers)
     return tf.keras.Sequential(layers, name=f'From{RGB_NAME}_{LOD_NAME}{lod}')
コード例 #5
0
    def create_G_model(self, res, mode=STABILIZATION_MODE):
        """Assemble the generator model for resolution ``res``.

        Args:
            res: log2 of the output resolution (2 -> 4x4).
            mode: ``STABILIZATION_MODE`` or ``FADE_IN_MODE``. For ``res == 2``
                there is no lower resolution to fade in from, so both modes
                build the plain (stabilization) graph.

        Returns:
            ``tf.keras.Model`` from ``self.G_input_layer`` to the images
            output, named ``G_model_<LOD_NAME><lod>``.

        Raises:
            AssertionError: if ``mode`` is not one of the supported modes.
        """
        # f-string: concatenating a non-str mode (e.g. None) used to raise a
        # TypeError while building the message instead of the AssertionError.
        assert mode in [FADE_IN_MODE, STABILIZATION_MODE
                        ], f'Mode {mode} is not supported'

        lod = level_of_details(res, self.resolution_log2)

        details_layers = [self.G_blocks[i] for i in range(2, res + 1)]
        if self.normalize_latents:
            details_layers = [self.G_latents_normalizer] + details_layers

        if res > 2 and mode == FADE_IN_MODE:  # 8x8 and up, fading in
            up_layer_name = 'Upscale2D_%dx%d_%s' % (2**res, 2**res,
                                                    FADE_IN_MODE)
            up_layer = Upscale2d(factor=2,
                                 dtype=self.policy,
                                 data_format=self.data_format,
                                 name=up_layer_name)
            # lod + 1 is the toRGB of the previous (lower) resolution
            toRGB_layer1 = self.toRGB_layers[lod + 1]
            toRGB_layer2 = self.toRGB_layers[lod]

            # Shared trunk: every block except the newest one
            images = self.G_input_layer
            for layer in details_layers[:-1]:
                images = layer(images)

            # Old branch: lower-resolution RGB output, upscaled by 2
            images1 = images
            for layer in [toRGB_layer1, up_layer]:
                images1 = layer(images1)

            # New branch: newest block followed by its own toRGB
            images2 = images
            for layer in [details_layers[-1], toRGB_layer2]:
                images2 = layer(images2)

            images_out = self.G_wsum_layers[lod]([images1, images2])
        else:  # 4x4, or stabilization at any resolution
            images_out = self.G_input_layer
            for layer in details_layers + [self.toRGB_layers[lod]]:
                images_out = layer(images_out)

        G_model = tf.keras.Model(inputs=self.G_input_layer,
                                 outputs=tf.identity(images_out,
                                                     name='Images_out'),
                                 name=f'G_model_{LOD_NAME}{lod}')
        return G_model
コード例 #6
0
    def create_D_model(self, res, mode=STABILIZATION_MODE):
        """Assemble the discriminator model for resolution ``res``.

        Args:
            res: log2 of the input image resolution (2 -> 4x4).
            mode: ``STABILIZATION_MODE`` or ``FADE_IN_MODE``. For ``res == 2``
                there is no lower resolution to fade in from, so both modes
                build the plain (stabilization) graph.

        Returns:
            ``tf.keras.Model`` from ``self.D_input_layers[res]`` to the scores
            output, named ``D_model_<LOD_NAME><lod>``.

        Raises:
            AssertionError: if ``mode`` is not one of the supported modes.
        """
        # f-string: concatenating a non-str mode (e.g. None) used to raise a
        # TypeError while building the message instead of the AssertionError.
        assert mode in [FADE_IN_MODE, STABILIZATION_MODE
                        ], f'Mode {mode} is not supported'

        lod = level_of_details(res, self.resolution_log2)

        # D blocks ordered from the current resolution down to 4x4
        details_layers = [self.D_blocks[i] for i in range(res, 2 - 1, -1)]
        D_input_layer = self.D_input_layers[res]

        if res > 2 and mode == FADE_IN_MODE:  # 8x8 and up, fading in
            # lod + 1 is the fromRGB of the previous (lower) resolution
            fromRGB_layer1 = self.fromRGB_layers[lod + 1]
            down_layer_name = 'Downscale2D_%dx%d_%s' % (2**res, 2**res,
                                                        FADE_IN_MODE)
            down_layer = Downscale2d(factor=2,
                                     dtype=self.policy,
                                     data_format=self.data_format,
                                     name=down_layer_name)
            # Old branch: downscale input, then the lower-resolution fromRGB
            x1 = D_input_layer
            for layer in [down_layer, fromRGB_layer1]:
                x1 = layer(x1)

            # New branch: this resolution's fromRGB and its newest block
            fromRGB_layer2 = self.fromRGB_layers[lod]
            x2 = D_input_layer
            for layer in [fromRGB_layer2, details_layers[0]]:
                x2 = layer(x2)

            x = self.D_wsum_layers[lod]([x1, x2])
            for layer in details_layers[1:]:
                x = layer(x)
        else:  # 4x4, or stabilization at any resolution
            x = D_input_layer
            for layer in [self.fromRGB_layers[lod]] + details_layers:
                x = layer(x)

        D_model = tf.keras.Model(inputs=D_input_layer,
                                 outputs=tf.identity(x, name='Scores'),
                                 name=f'D_model_{LOD_NAME}{lod}')
        return D_model
コード例 #7
0
    def __init__(self, config):
        """Read generator hyper-parameters from ``config`` and build its layers.

        Args:
            config: mapping of settings. ``TARGET_RESOLUTION``,
                ``LATENT_SIZE`` and ``BATCH_SIZES`` are read with ``[]``;
                the other keys fall back to their ``DEFAULT_*`` values
                (``G_WEIGHTS_INIT_MODE`` falls back to ``None``).

        Raises:
            AssertionError: if the target/start resolution is not a power of
                two >= 4, or the activation is not in ``ACTIVATION_FUNS_DICT``.
        """

        self.target_resolution = config[TARGET_RESOLUTION]
        self.resolution_log2 = int(np.log2(self.target_resolution))
        assert self.target_resolution == 2**self.resolution_log2 and self.target_resolution >= 4

        self.start_resolution = config.get(START_RESOLUTION,
                                           DEFAULT_START_RESOLUTION)
        self.start_resolution_log2 = int(np.log2(self.start_resolution))
        assert self.start_resolution == 2**self.start_resolution_log2 and self.start_resolution >= 4

        self.min_lod = level_of_details(self.resolution_log2,
                                        self.resolution_log2)
        self.max_lod = level_of_details(self.start_resolution_log2,
                                        self.resolution_log2)

        self.data_format = config.get(DATA_FORMAT, DEFAULT_DATA_FORMAT)
        validate_data_format(self.data_format)

        self.latent_size = config[LATENT_SIZE]
        self.normalize_latents = config.get(NORMALIZE_LATENTS,
                                            DEFAULT_NORMALIZE_LATENTS)
        self.use_bias = config.get(USE_BIAS, DEFAULT_USE_BIAS)
        self.use_wscale = config.get(USE_WSCALE, DEFAULT_USE_WSCALE)
        self.use_pixelnorm = config.get(USE_PIXELNORM, DEFAULT_USE_PIXELNORM)
        self.override_projecting_gain = config.get(
            OVERRIDE_G_PROJECTING_GAIN, DEFAULT_OVERRIDE_G_PROJECTING_GAIN)
        self.projecting_gain_correction = 4. if self.override_projecting_gain else 1.
        self.truncate_weights = config.get(TRUNCATE_WEIGHTS,
                                           DEFAULT_TRUNCATE_WEIGHTS)
        self.G_fused_scale = config.get(G_FUSED_SCALE, DEFAULT_G_FUSED_SCALE)
        self.G_kernel_size = config.get(G_KERNEL_SIZE, DEFAULT_G_KERNEL_SIZE)
        self.G_fmap_base = config.get(G_FMAP_BASE, DEFAULT_FMAP_BASE)
        self.G_fmap_decay = config.get(G_FMAP_DECAY, DEFAULT_FMAP_DECAY)
        self.G_fmap_max = config.get(G_FMAP_MAX, DEFAULT_FMAP_MAX)
        self.batch_sizes = config[BATCH_SIZES]

        self.G_act_name = config.get(G_ACTIVATION, DEFAULT_G_ACTIVATION)
        # Validate, then index directly. The former `if/else: assert False`
        # left self.G_act unset and carried on when asserts are stripped
        # (python -O); the direct lookup now fails loudly with KeyError.
        assert self.G_act_name in ACTIVATION_FUNS_DICT, \
            f"Generator activation '{self.G_act_name}' not supported. See ACTIVATION_FUNS_DICT"
        self.G_act = ACTIVATION_FUNS_DICT[self.G_act_name]

        self.use_mixed_precision = config.get(USE_MIXED_PRECISION,
                                              DEFAULT_USE_MIXED_PRECISION)
        if self.use_mixed_precision:
            self.policy = mixed_precision.Policy('mixed_float16')
            self.act_dtype = 'float32' if self.G_act_name in FP32_ACTIVATIONS else self.policy
            self.compute_dtype = self.policy.compute_dtype
        else:
            self.policy = 'float32'
            self.act_dtype = 'float32'
            self.compute_dtype = 'float32'

        # An explicit init mode overrides the activation-derived gain
        self.weights_init_mode = config.get(G_WEIGHTS_INIT_MODE, None)
        if self.weights_init_mode is None:
            self.gain = GAIN_ACTIVATION_FUNS_DICT[self.G_act_name]
        else:
            self.gain = GAIN_INIT_MODE_DICT[self.weights_init_mode]

        # This constant is taken from the original implementation
        self.projecting_mult = 4
        if self.data_format == NCHW_FORMAT:
            self.z_dim = (self.latent_size, 1, 1)
            self.projecting_target_shape = (self.G_n_filters(2 - 1),
                                            self.projecting_mult,
                                            self.projecting_mult)
        elif self.data_format == NHWC_FORMAT:
            self.z_dim = (1, 1, self.latent_size)
            self.projecting_target_shape = (self.projecting_mult,
                                            self.projecting_mult,
                                            self.G_n_filters(2 - 1))

        self.create_model_layers()
コード例 #8
0
    def __init__(self, config):
        """Read discriminator hyper-parameters from ``config`` and build its layers.

        Args:
            config: mapping of settings. ``TARGET_RESOLUTION``,
                ``MBSTD_GROUP_SIZE`` and ``BATCH_SIZES`` are read with ``[]``;
                the other keys fall back to their ``DEFAULT_*`` values
                (``D_WEIGHTS_INIT_MODE`` falls back to ``None``).

        Raises:
            AssertionError: if the target/start resolution is not a power of
                two >= 4, or the activation is not in ``ACTIVATION_FUNS_DICT``.
        """

        self.target_resolution = config[TARGET_RESOLUTION]
        self.resolution_log2 = int(np.log2(self.target_resolution))
        assert self.target_resolution == 2**self.resolution_log2 and self.target_resolution >= 4

        self.start_resolution = config.get(START_RESOLUTION,
                                           DEFAULT_START_RESOLUTION)
        self.start_resolution_log2 = int(np.log2(self.start_resolution))
        assert self.start_resolution == 2**self.start_resolution_log2 and self.start_resolution >= 4

        self.min_lod = level_of_details(self.resolution_log2,
                                        self.resolution_log2)
        self.max_lod = level_of_details(self.start_resolution_log2,
                                        self.resolution_log2)

        self.data_format = config.get(DATA_FORMAT, DEFAULT_DATA_FORMAT)
        validate_data_format(self.data_format)

        self.use_bias = config.get(USE_BIAS, DEFAULT_USE_BIAS)
        self.use_wscale = config.get(USE_WSCALE, DEFAULT_USE_WSCALE)
        self.truncate_weights = config.get(TRUNCATE_WEIGHTS,
                                           DEFAULT_TRUNCATE_WEIGHTS)
        self.mbstd_group_size = config[MBSTD_GROUP_SIZE]
        self.D_fused_scale = config.get(D_FUSED_SCALE, DEFAULT_D_FUSED_SCALE)
        self.D_kernel_size = config.get(D_KERNEL_SIZE, DEFAULT_D_KERNEL_SIZE)
        self.D_fmap_base = config.get(D_FMAP_BASE, DEFAULT_FMAP_BASE)
        self.D_fmap_decay = config.get(D_FMAP_DECAY, DEFAULT_FMAP_DECAY)
        self.D_fmap_max = config.get(D_FMAP_MAX, DEFAULT_FMAP_MAX)
        self.batch_sizes = config[BATCH_SIZES]

        self.D_act_name = config.get(D_ACTIVATION, DEFAULT_D_ACTIVATION)
        # Validate, then index directly. The former `if/else: assert False`
        # left self.D_act unset and carried on when asserts are stripped
        # (python -O); the direct lookup now fails loudly with KeyError.
        assert self.D_act_name in ACTIVATION_FUNS_DICT, \
            f"Discriminator activation '{self.D_act_name}' not supported. See ACTIVATION_FUNS_DICT"
        self.D_act = ACTIVATION_FUNS_DICT[self.D_act_name]

        self.use_mixed_precision = config.get(USE_MIXED_PRECISION,
                                              DEFAULT_USE_MIXED_PRECISION)
        if self.use_mixed_precision:
            self.policy = mixed_precision.Policy('mixed_float16')
            self.act_dtype = 'float32' if self.D_act_name in FP32_ACTIVATIONS else self.policy
            self.compute_dtype = self.policy.compute_dtype
        else:
            self.policy = 'float32'
            self.act_dtype = 'float32'
            self.compute_dtype = 'float32'

        # An explicit init mode overrides the activation-derived gain
        self.weights_init_mode = config.get(D_WEIGHTS_INIT_MODE, None)
        if self.weights_init_mode is None:
            self.gain = GAIN_ACTIVATION_FUNS_DICT[self.D_act_name]
        else:
            self.gain = GAIN_INIT_MODE_DICT[self.weights_init_mode]

        # Might be useful to override number of units in projecting layer
        # in case latent size is not 512 to make models have almost the same number
        # of trainable params
        self.projecting_nf = config.get(D_PROJECTING_NF,
                                        self.D_n_filters(2 - 2))

        self.create_model_layers()
コード例 #9
0
 def to_rgb(self, res):
     """Create the toRGB block for generator features at resolution ``res``.

     The block is a single 1x1 convolution with 3 output maps and
     ``gain=1.``. When ``self.use_bias`` is set, the layer list is passed
     through ``self.apply_bias``.

     Returns:
         ``tf.keras.Sequential`` named ``To<RGB_NAME>_<LOD_NAME><lod>``.
     """
     lod = level_of_details(res, self.resolution_log2)
     layers = [self.conv2d(fmaps=3, kernel_size=1, gain=1.)]
     if self.use_bias:
         layers = self.apply_bias(layers)
     return tf.keras.Sequential(layers, name=f'To{RGB_NAME}_{LOD_NAME}{lod}')
コード例 #10
0
    def __init__(self,
                 config,
                 mode=DEFAULT_MODE,
                 images_paths=None,
                 res=None,
                 stage=None,
                 single_process_training=False):
        """Configure training/inference from ``config`` and build G, D (and Gs).

        Args:
            config: mapping of settings. This method does not modify it; the
                smoothed generator gets its own shallow copy for its overrides.
            mode: ``INFERENCE_MODE`` or ``TRAIN_MODE``. Selects which of
                models / images generators / optimizers are initialized.
            images_paths: forwarded to ``create_images_generator`` when
                training with ``single_process_training=False``.
            res, stage: forwarded to ``initialize_models`` (and ``res`` to
                ``create_images_generator``) in that same case.
            single_process_training: if True (in ``TRAIN_MODE``), initialize
                all models, all images generators and all optimizer variables.
        """

        self.target_resolution = config[TARGET_RESOLUTION]
        self.resolution_log2 = int(np.log2(self.target_resolution))
        assert self.target_resolution == 2**self.resolution_log2 and self.target_resolution >= 4

        self.start_resolution = config.get(START_RESOLUTION,
                                           DEFAULT_START_RESOLUTION)
        self.start_resolution_log2 = int(np.log2(self.start_resolution))
        assert self.start_resolution == 2**self.start_resolution_log2 and self.start_resolution >= 4

        self.min_lod = level_of_details(self.resolution_log2,
                                        self.resolution_log2)
        self.max_lod = level_of_details(self.start_resolution_log2,
                                        self.resolution_log2)

        self.data_format = config.get(DATA_FORMAT, DEFAULT_DATA_FORMAT)
        validate_data_format(self.data_format)

        self.latent_size = config[LATENT_SIZE]
        if self.data_format == NCHW_FORMAT:
            self.z_dim = (self.latent_size, 1, 1)
        elif self.data_format == NHWC_FORMAT:
            self.z_dim = (1, 1, self.latent_size)

        self.model_name = config[MODEL_NAME]
        self.storage_path = config.get(STORAGE_PATH, DEFAULT_STORAGE_PATH)
        self.max_models_to_keep = config.get(MAX_MODELS_TO_KEEP,
                                             DEFAULT_MAX_MODELS_TO_KEEP)
        self.summary_every = config.get(SUMMARY_EVERY, DEFAULT_SUMMARY_EVERY)
        self.save_model_every = config.get(SAVE_MODEL_EVERY,
                                           DEFAULT_SAVE_MODEL_EVERY)
        self.save_images_every = config.get(SAVE_IMAGES_EVERY,
                                            DEFAULT_SAVE_IMAGES_EVERY)
        self.batch_sizes = config[BATCH_SIZES]
        self.fade_in_images = config[FADE_IN_IMAGES]
        self.stabilization_images = config[STABILIZATION_IMAGES]
        self.use_mixed_precision = config.get(USE_MIXED_PRECISION,
                                              DEFAULT_USE_MIXED_PRECISION)
        self.compute_dtype = 'float16' if self.use_mixed_precision else 'float32'
        self.use_Gs = config.get(USE_G_SMOOTHING, DEFAULT_USE_G_SMOOTHING)
        # Even in mixed precision training weights are stored in fp32
        self.smoothed_beta = tf.constant(config.get(G_SMOOTHING_BETA,
                                                    DEFAULT_G_SMOOTHING_BETA),
                                         dtype='float32')
        self.use_gpu_for_Gs = config.get(USE_GPU_FOR_GS,
                                         DEFAULT_USE_GPU_FOR_GS)
        self.shuffle_dataset = config.get(SHUFFLE_DATASET,
                                          DEFAULT_SHUFFLE_DATASET)
        self.dataset_n_parallel_calls = config.get(
            DATASET_N_PARALLEL_CALLS, DEFAULT_DATASET_N_PARALLEL_CALLS)
        self.dataset_n_prefetched_batches = config.get(
            DATASET_N_PREFETCHED_BATCHES, DEFAULT_DATASET_N_PREFETCHED_BATCHES)
        self.dataset_n_max_images = config.get(DATASET_N_MAX_IMAGES,
                                               DEFAULT_DATASET_N_MAX_IMAGES)
        self.dataset_max_cache_res = config.get(DATASET_MAX_CACHE_RES,
                                                DEFAULT_DATASET_MAX_CACHE_RES)
        self.mirror_augment = config.get(MIRROR_AUGMENT,
                                         DEFAULT_MIRROR_AUGMENT)

        self.G_learning_rate = config.get(G_LEARNING_RATE,
                                          DEFAULT_G_LEARNING_RATE)
        self.D_learning_rate = config.get(D_LEARNING_RATE,
                                          DEFAULT_D_LEARNING_RATE)
        self.G_learning_rate_dict = config.get(G_LEARNING_RATE_DICT,
                                               DEFAULT_G_LEARNING_RATE_DICT)
        self.D_learning_rate_dict = config.get(D_LEARNING_RATE_DICT,
                                               DEFAULT_D_LEARNING_RATE_DICT)
        self.beta1 = config.get(ADAM_BETA1, DEFAULT_ADAM_BETA1)
        self.beta2 = config.get(ADAM_BETA2, DEFAULT_ADAM_BETA2)
        self.reset_opt_state_for_new_lod = config.get(
            RESET_OPT_STATE_FOR_NEW_LOD, DEFAULT_RESET_OPT_STATE_FOR_NEW_LOD)

        self.valid_grid_nrows = 10
        self.valid_grid_ncols = 10
        self.valid_grid_padding = 2
        self.min_target_single_image_size = 2**7
        self.max_png_res = 5
        self.valid_latents = self.generate_latents(self.valid_grid_nrows *
                                                   self.valid_grid_ncols)

        self.logs_path = os.path.join(TF_LOGS_DIR, self.model_name)

        # One summary directory / writer per trained resolution
        self.writers_dirs = {
            res: os.path.join(self.logs_path, '%dx%d' % (2**res, 2**res))
            for res in range(self.start_resolution_log2, self.resolution_log2 +
                             1)
        }

        self.summary_writers = {
            res: tf.summary.create_file_writer(self.writers_dirs[res])
            for res in range(self.start_resolution_log2, self.resolution_log2 +
                             1)
        }

        self.validate_config()

        self.clear_session_for_new_model = True

        self.G_object = Generator(config)
        self.D_object = Discriminator(config)

        if self.use_Gs:
            # Shallow copy: the NHWC override below must apply to Gs only.
            # Assigning `config` directly aliased the caller's dict, so the
            # override leaked into it and into the images generator(s)
            # created from `config` further down.
            Gs_config = dict(config)
            self.Gs_valid_latents = self.valid_latents
            self.Gs_device = '/GPU:0' if self.use_gpu_for_Gs else '/CPU:0'

            if not self.use_gpu_for_Gs:
                # Gs placed on CPU is built with NHWC layout
                Gs_config[DATA_FORMAT] = NHWC_FORMAT
                # NCHW -> NHWC
                self.toNHWC_axis = [0, 2, 3, 1]
                # NHWC -> NCHW
                self.toNCHW_axis = [0, 3, 1, 2]
                # NOTE(review): this transpose assumes valid_latents are NCHW.
                # If self.data_format is already NHWC it permutes the wrong
                # axes; confirm.
                self.Gs_valid_latents = tf.transpose(self.valid_latents,
                                                     self.toNHWC_axis)

            self.Gs_object = Generator(Gs_config)

        if mode == INFERENCE_MODE:
            self.initialize_models()
        elif mode == TRAIN_MODE:
            if single_process_training:
                self.initialize_models()
                self.create_images_generators(config)
                self.initialize_optimizers(create_all_variables=True)
            else:
                self.initialize_models(res, stage)
                self.create_images_generator(config, res, images_paths)
                self.initialize_optimizers(create_all_variables=False)
コード例 #11
0
    def initialize_D_optimizer(self, create_all_variables):
        """Create the discriminator Adam optimizer and, optionally, its slots.

        When mixed precision is enabled, the optimizer is wrapped in a dynamic
        ``mixed_precision.LossScaleOptimizer``.

        Args:
            create_all_variables: if False, only the optimizer object is
                created. If True, one update is also applied through the
                fade-in D model at the target resolution, and then through
                every fromRGB layer from ``self.start_resolution_log2`` up to
                the target. This makes the optimizer create its state
                variables for all of them up front. Gradients go through
                ``mult_by_zero`` first; per the in-code comment, the weights
                are not meant to change.
        """
        self.D_optimizer = tf.keras.optimizers.Adam(
            learning_rate=self.D_learning_rate,
            beta_1=self.beta1,
            beta_2=self.beta2,
            epsilon=1e-8,
            name='D_Adam')
        if self.use_mixed_precision:
            self.D_optimizer = mixed_precision.LossScaleOptimizer(
                self.D_optimizer, 'dynamic')

        if not create_all_variables:
            return

        # First step: create optimizer states for all internal and final output layers
        # Dummy summary controls for D_loss_fn (summary writing disabled)
        step = tf.Variable(0, trainable=False, dtype=tf.int64)
        write_summary = tf.Variable(False, trainable=False, dtype=tf.bool)

        # Target resolution; its batch size is reused for every layer below
        res = self.resolution_log2
        batch_size = self.batch_sizes[str(res)]

        # Fade-in graph at the target resolution touches the target-res
        # blocks plus the fromRGB layers of lod and lod + 1
        G_model = self.G_object.create_G_model(res, mode=FADE_IN_MODE)
        D_model = self.D_object.create_D_model(res, mode=FADE_IN_MODE)

        latents = self.generate_latents(batch_size)
        D_input_shape = self.D_object.D_input_shape(res)
        # Random stand-in for a batch of real images
        images = tf.random.normal(shape=(batch_size, ) + D_input_shape,
                                  mean=0.,
                                  stddev=0.05,
                                  dtype=self.compute_dtype)

        D_vars = D_model.trainable_variables
        # Only D variables are watched; G is run just to produce fake images
        with tf.GradientTape(watch_accessed_variables=False) as D_tape:
            D_tape.watch(D_vars)

            fake_images = G_model(latents)
            fake_scores = fp32(D_model(fake_images))
            real_scores = fp32(D_model(images))

            D_loss = D_loss_fn(D_model,
                               optimizer=self.D_optimizer,
                               mixed_precision=self.use_mixed_precision,
                               real_scores=real_scores,
                               real_images=images,
                               fake_scores=fake_scores,
                               fake_images=fake_images,
                               write_summary=write_summary,
                               step=step)
            D_loss = scale_loss(D_loss, self.D_optimizer,
                                self.use_mixed_precision)
            print('D loss computed')

        # No need to update weights!
        D_grads = mult_by_zero(D_tape.gradient(D_loss, D_vars))
        D_grads = unscale_grads(D_grads, self.D_optimizer,
                                self.use_mixed_precision)
        print('D gradients obtained')

        # NOTE(review): each apply_gradients call presumably advances the
        # optimizer's iteration counter (used by Adam bias correction); confirm
        # that these warm-up steps are acceptable.
        self.D_optimizer.apply_gradients(zip(D_grads, D_vars))
        print('D gradients applied')

        print('Creating slots for intermediate output layers...')
        # Note: `res` is rebound here as the loop variable
        for res in range(self.start_resolution_log2, self.resolution_log2 + 1):
            lod = level_of_details(res, self.resolution_log2)

            # fromRGB layers
            from_layer = self.D_object.fromRGB_layers[lod]
            # Layer's own input shape with the batch dimension replaced
            from_input_shape = (batch_size, ) + from_layer.input_shape[1:]
            from_inputs = tf.random.normal(shape=from_input_shape,
                                           mean=0.,
                                           stddev=0.05,
                                           dtype=self.compute_dtype)

            # Arbitrary scalar loss just to obtain gradients for this layer
            with tf.GradientTape() as tape:
                from_outputs = from_layer(from_inputs)
                loss = tf.reduce_mean(tf.square(from_outputs))
                loss = scale_loss(loss, self.D_optimizer,
                                  self.use_mixed_precision)

            D_vars = from_layer.trainable_variables
            # No need to update weights!
            D_grads = mult_by_zero(tape.gradient(loss, D_vars))
            D_grads = unscale_grads(D_grads, self.D_optimizer,
                                    self.use_mixed_precision)
            self.D_optimizer.apply_gradients(zip(D_grads, D_vars))

        print('D optimizer slots created!')