def convert_official_discriminator_weights(ckpt_dir, use_custom_cuda, official_ckpt_dir='./official-pretrained'):
    """Restore official pretrained discriminator weights and re-save them as a TF2 checkpoint.

    A discriminator is built with ``load_discriminator`` (``d_params=None``,
    ``ckpt_dir=None``). Values from the latest checkpoint in
    ``official_ckpt_dir`` are loaded into it through the name mapping from
    ``variable_name_mapper_d``. The model is then saved with a
    ``tf.train.CheckpointManager`` under ``<ckpt_dir>/discriminator`` as
    checkpoint number 0.

    Args:
        ckpt_dir: Root output directory. The checkpoint is written to its
            ``discriminator`` sub-directory.
        use_custom_cuda: Forwarded to ``load_discriminator`` as ``custom_cuda``.
        official_ckpt_dir: Directory searched with ``tf.train.latest_checkpoint``
            for the official checkpoint. Defaults to ``'./official-pretrained'``.

    Raises:
        FileNotFoundError: If ``tf.train.latest_checkpoint`` finds no checkpoint
            in ``official_ckpt_dir``.
    """
    discriminator = load_discriminator(d_params=None, ckpt_dir=None, custom_cuda=use_custom_cuda)

    # Locate the official checkpoint. latest_checkpoint returns None when nothing
    # is found; fail here with a clear message instead of inside list_variables.
    official_checkpoint = tf.train.latest_checkpoint(official_ckpt_dir)
    if official_checkpoint is None:
        raise FileNotFoundError(f'No checkpoint found in {official_ckpt_dir!r}')
    # (name, shape) pairs for every variable stored in the official checkpoint
    official_vars = tf.train.list_variables(official_checkpoint)

    # Mapping: official checkpoint variable name -> discriminator variable.
    name_mapper = variable_name_mapper_d(discriminator)
    for name_d, tvar in name_mapper.items():
        print(f'{name_d}: {tvar.name}')

    # Shape check of the mapping against the official variables
    # (semantics defined by check_shape elsewhere in the project).
    check_shape(name_mapper, official_vars)

    # Restore: assign official checkpoint values to the mapped variables.
    tf.compat.v1.train.init_from_checkpoint(official_checkpoint, assignment_map=name_mapper)

    # Save the converted discriminator in TF2 checkpoint format.
    ckpt = tf.train.Checkpoint(discriminator=discriminator)
    out_dir = os.path.join(ckpt_dir, 'discriminator')
    manager = tf.train.CheckpointManager(ckpt, out_dir, max_to_keep=1)
    manager.save(checkpoint_number=0)
def initiate_models(g_params, d_params, use_custom_cuda):
    """Build the discriminator, generator and generator clone.

    Every loader is called with ``ckpt_dir=None`` and the given
    ``use_custom_cuda`` flag as ``custom_cuda``. After construction, the
    clone's weights are overwritten with the generator's weights, so both
    start identical.

    Args:
        g_params: Generator hyper-parameters, forwarded to ``load_generator``.
        d_params: Discriminator hyper-parameters, forwarded to ``load_discriminator``.
        use_custom_cuda: Forwarded to every loader as ``custom_cuda``.

    Returns:
        tuple: ``(discriminator, generator, g_clone)``.
    """
    shared_kwargs = {'ckpt_dir': None, 'custom_cuda': use_custom_cuda}

    # Construction order kept as-is: discriminator, generator, then clone.
    d_model = load_discriminator(d_params, **shared_kwargs)
    g_model = load_generator(g_params=g_params, is_g_clone=False, **shared_kwargs)
    g_copy = load_generator(g_params=g_params, is_g_clone=True, **shared_kwargs)

    # Start the clone from exactly the generator's current weights.
    g_copy.set_weights(g_model.get_weights())

    return d_model, g_model, g_copy
def convert_official_weights_together(ckpt_dir, use_custom_cuda, official_ckpt_dir='./official-pretrained'):
    """Restore official pretrained D/G/G-clone weights and save them in one TF2 checkpoint.

    Builds a discriminator, a generator and a generator clone (all with
    ``*_params=None`` and ``ckpt_dir=None``). It merges their
    official-name -> variable mappings and loads values from the latest
    checkpoint in ``official_ckpt_dir`` into all three models. Everything is
    then saved with a ``tf.train.CheckpointManager`` directly in ``ckpt_dir``
    as checkpoint number 0.

    Args:
        ckpt_dir: Output directory for the combined checkpoint.
        use_custom_cuda: Forwarded to every loader as ``custom_cuda``.
        official_ckpt_dir: Directory searched with ``tf.train.latest_checkpoint``
            for the official checkpoint. Defaults to ``'./official-pretrained'``.

    Raises:
        FileNotFoundError: If ``tf.train.latest_checkpoint`` finds no checkpoint
            in ``official_ckpt_dir``.
    """
    # Instantiate all models.
    discriminator = load_discriminator(d_params=None, ckpt_dir=None, custom_cuda=use_custom_cuda)
    generator = load_generator(g_params=None, is_g_clone=False, ckpt_dir=None, custom_cuda=use_custom_cuda)
    g_clone = load_generator(g_params=None, is_g_clone=True, ckpt_dir=None, custom_cuda=use_custom_cuda)

    # Locate the official checkpoint. latest_checkpoint returns None when nothing
    # is found; fail here with a clear message instead of inside list_variables.
    official_checkpoint = tf.train.latest_checkpoint(official_ckpt_dir)
    if official_checkpoint is None:
        raise FileNotFoundError(f'No checkpoint found in {official_ckpt_dir!r}')
    official_vars = tf.train.list_variables(official_checkpoint)
    for name, shape in official_vars:
        print(f'{name}: {shape}')

    # Merge the per-model mappings (official name -> model variable). On a key
    # collision, later mappings win; presumably the key sets are disjoint — verify.
    name_mapper_d = variable_name_mapper_d(discriminator)
    name_mapper_g1 = variable_name_mapper_g(generator, is_g_clone=False)
    name_mapper_g2 = variable_name_mapper_g(g_clone, is_g_clone=True)
    name_mapper = {**name_mapper_d, **name_mapper_g1, **name_mapper_g2}

    # Shape check of the merged mapping against the official variables.
    check_shape(name_mapper, official_vars)

    # Restore: assign official checkpoint values to every mapped variable.
    tf.compat.v1.train.init_from_checkpoint(official_checkpoint, assignment_map=name_mapper)

    # Save all three models together in TF2 checkpoint format.
    ckpt = tf.train.Checkpoint(discriminator=discriminator, generator=generator, g_clone=g_clone)
    manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=1)
    manager.save(checkpoint_number=0)