def blend_models(low_res_wrapper, high_res_wrapper, resolution, level, blend=0):
    """Blend the ``Gs`` weights of two wrapped models around a pivot layer.

    Every layer reported by ``ModelBlender.extract_conv_names`` gets a blend
    factor ``y`` based on its position relative to the pivot layer identified
    by ``(resolution, level)``.  The written weight is
    ``high_res * y + low_res * (1 - y)``, so ``y == 0`` keeps the low-res
    model's weights and ``y == 1`` takes the high-res model's weights.

    Args:
        low_res_wrapper: Wrapper exposing ``Gs``, ``_G`` and ``_D``; its ``Gs``
            is cloned to form the result.
        high_res_wrapper: Wrapper exposing ``Gs``.
        resolution: Single number; the pivot is looked up as
            ``"{resolution}x{resolution}"``.
        level: Level entry of the pivot layer, matched against the second
            element of the ``(resolution, level)`` pair from the name list.
        blend: Width of the sigmoid transition, in layer positions.  A falsy
            value (including the default ``0``) selects a hard switch:
            ``y = 1`` for layers more than one position past the pivot,
            otherwise ``0``.

    Returns:
        Whatever ``StyleGanWrapper().set_network(...)`` returns for the tuple
        ``(low_res_wrapper._G, low_res_wrapper._D, blended_model)``.

    Raises:
        ValueError: If ``(resolution, level)`` is not among the extracted
            layers (raised by ``list.index``).
    """
    low_res_model = low_res_wrapper.Gs
    high_res_model = high_res_wrapper.Gs
    result_model = low_res_model.clone()

    pivot_key = (f'{resolution}x{resolution}', level)
    low_res_names = ModelBlender.extract_conv_names(low_res_model)
    # NOTE(review): the result is not used below; the call is kept in case the
    # extraction validates the high-res model -- confirm and drop if not.
    high_res_names = ModelBlender.extract_conv_names(high_res_model)

    short_names = [entry[1:3] for entry in low_res_names]
    full_names = [entry[0] for entry in low_res_names]
    mid_point_idx = short_names.index(pivot_key)
    mid_point_pos = low_res_names[mid_point_idx][3]

    ys = []
    for _name, _resolution, _level, position in low_res_names:
        x = position - mid_point_pos
        if blend:
            try:
                y = 1 / (1 + math.exp(-x / blend))
            except OverflowError:
                # exp() overflowed, so the sigmoid has saturated at 0.
                y = 0.0
        else:
            # No blend width: avoid dividing by zero and switch abruptly.
            y = 1 if x > 1 else 0
        ys.append(y)

    tfutil.set_vars(
        tfutil.run({
            result_model.vars[name]: (high_res_model.vars[name] * y +
                                      low_res_model.vars[name] * (1 - y))
            for name, y in zip(full_names, ys)
        }))

    return StyleGanWrapper().set_network(
        (low_res_wrapper._G, low_res_wrapper._D, result_model))
def copy_weights(src_net, tgt_net, vars_to_copy):
    """Copy selected variables from ``src_net`` into ``tgt_net``.

    Only names that are trainables of ``tgt_net`` and also appear in
    ``vars_to_copy`` are transferred; the lookup in ``src_net.vars`` uses the
    same name.
    """
    selected = []
    for var_name in tgt_net.trainables.keys():
        if var_name in vars_to_copy:
            selected.append(var_name)

    assignments = {}
    for var_name in selected:
        target_var = tgt_net.vars[var_name]
        assignments[target_var] = src_net.vars[var_name]

    # The mapping is passed through tfutil.run and the result is handed to
    # tfutil.set_vars in a single call.
    evaluated = tfutil.run(assignments)
    tfutil.set_vars(evaluated)
def copy_and_crop_or_pad_trainables(src_net, tgt_net) -> None:
    """Copy trainables from ``src_net`` to ``tgt_net``, resizing on mismatch.

    Source and target trainables are paired by iteration order (``zip``), not
    by name.  For each pair whose shapes differ:

    * names containing ``'Dense'`` are center-cropped along axis 0 when the
      source is larger there, otherwise passed to ``pad_symm_np`` with the
      full target shape;
    * all other variables are center-cropped along axes 2 and 3 where the
      source is larger, then passed to ``pad_symm_np`` with
      ``target_shape[2:]`` if the source is smaller on either of those axes.

    Resized values are written one at a time with ``tgt_net.set_var``; all
    remaining pairs are copied in one batched ``tfutil.set_vars`` call.

    Raises:
        SystemExit: With code 1 (after printing a diagnostic) when a
            non-``Dense`` pair does not have at least four axes.
    """

    def centered_span(src_len, tgt_len):
        # Start/end of a window of tgt_len centered inside src_len.
        start = abs(src_len - tgt_len) // 2
        return start, start + tgt_len

    names = list(zip(src_net.trainables.keys(), tgt_net.trainables.keys()))
    resized = set()
    pbar = ProgressBar(len(names))
    for pair in names:
        source_name, target_name = pair
        update = src_net.get_var(source_name)
        source_shape = update.shape
        target_shape = tgt_net.get_var(target_name).shape
        if source_shape != target_shape:
            if 'Dense' in source_name:
                if source_shape[0] > target_shape[0]:
                    start, end = centered_span(source_shape[0],
                                               target_shape[0])
                    update = update[start:end, :]
                else:
                    update = pad_symm_np(update, target_shape)
            else:
                try:
                    if source_shape[2] > target_shape[2]:
                        start, end = centered_span(source_shape[2],
                                                   target_shape[2])
                        update = update[:, :, start:end, :]
                    if source_shape[3] > target_shape[3]:
                        start, end = centered_span(source_shape[3],
                                                   target_shape[3])
                        update = update[:, :, :, start:end]
                except IndexError:
                    # Fewer than four axes: the positional pairing most likely
                    # matched two unrelated variables.
                    print(' Wrong var pair?', source_name, source_shape,
                          target_name, target_shape)
                    raise SystemExit(1)
                if (source_shape[2] < target_shape[2]
                        or source_shape[3] < target_shape[3]):
                    update = pad_symm_np(update, target_shape[2:])
            tgt_net.set_var(target_name, update)
            resized.add(source_name)
        pbar.upd(pair)

    # Everything that already matched in shape is copied in a single batch.
    weights_to_copy = {
        tgt_net.vars[target_name]: src_net.vars[source_name]
        for source_name, target_name in names
        if source_name not in resized
    }
    tfutil.set_vars(tfutil.run(weights_to_copy))
def copy_own_vars_from(self, src_net: "Network") -> None:
    """Copy every own variable that ``src_net`` also owns into this network.

    Only names present in both ``self.own_vars`` and ``src_net.own_vars`` are
    transferred; names missing from the source are left untouched.
    """
    shared_names = []
    for var_name in self.own_vars.keys():
        if var_name in src_net.own_vars:
            shared_names.append(var_name)

    assignments = {}
    for var_name in shared_names:
        own_var = self.vars[var_name]
        assignments[own_var] = src_net.vars[var_name]

    tfutil.set_vars(tfutil.run(assignments))
def apply_denominator(dst_net, denominator):
    """Scale every trainable of ``dst_net`` by ``1 / denominator``.

    The network is modified in place and also returned.
    """
    scale = 1.0 / denominator
    trainable_names = list(dst_net.trainables.keys())

    scaled = {}
    for var_name in trainable_names:
        variable = dst_net.vars[var_name]
        scaled[variable] = dst_net.vars[var_name] * scale

    tfutil.set_vars(tfutil.run(scaled))
    return dst_net
def blend_models(model_1,
                 model_2,
                 resolution_h,
                 resolution_w,
                 res_log2,
                 min_h,
                 min_w,
                 level,
                 blend_width=None,
                 verbose=False):
    """Blend the weights of two models around a pivot layer.

    Each layer listed by ``extract_conv_names`` gets a blend factor ``y``
    derived from its position relative to the pivot layer
    ``("{resolution_h}x{resolution_w}", level)``.  The stored weight is
    ``model_2 * y + model_1 * (1 - y)``: ``y == 0`` means all ``model_1``,
    ``y == 1`` means all ``model_2``.

    Args:
        model_1: Network whose clone receives the blended weights.
        model_2: Network blended in as ``y`` grows.
        resolution_h, resolution_w: Height and width naming the pivot layer.
        res_log2, min_h, min_w: Forwarded unchanged to
            ``extract_conv_names``.
        level: Level entry of the pivot layer.
        blend_width: Width of the sigmoid transition in layer positions.
            When falsy, a hard switch is used (``y = 1`` for layers more than
            one position past the pivot, otherwise ``0``).
        verbose: Print the factor chosen for each layer.

    Returns:
        The clone of ``model_1`` carrying the blended weights.

    Raises:
        AssertionError: If the two models' name lists differ over their
            common (zipped) length.
        ValueError: If the pivot layer is not found (from ``list.index``).
    """
    # TODO add small x offset for smoother blend animations
    pivot_key = (f"{resolution_h}x{resolution_w}", level)

    model_1_names = extract_conv_names(model_1, res_log2, min_h, min_w)
    model_2_names = extract_conv_names(model_2, res_log2, min_h, min_w)
    assert all((x == y for x, y in zip(model_1_names, model_2_names)))

    model_out = model_1.clone()

    short_names = [entry[1:3] for entry in model_1_names]
    full_names = [entry[0] for entry in model_1_names]
    mid_point_idx = short_names.index(pivot_key)
    mid_point_pos = model_1_names[mid_point_idx][3]

    ys = []
    # Entries are walked in the order extract_conv_names returned them.
    for name, _resolution, _level, position in model_1_names:
        x = position - mid_point_pos
        if blend_width:
            try:
                y = 1 / (1 + math.exp(-x / blend_width))
            except OverflowError:
                # exp() overflowed, so the sigmoid has saturated at 0.
                y = 0.0
        else:
            y = 1 if x > 1 else 0
        ys.append(y)
        if verbose:
            print(f"Blending {name} by {y}")

    tfutil.set_vars(
        tfutil.run({
            model_out.vars[name]: (model_2.vars[name] * y +
                                   model_1.vars[name] * (1 - y))
            for name, y in zip(full_names, ys)
        }))

    return model_out
def copy_vars(src_net, tgt_net, D=False):
    """Copy shared trainables from ``src_net`` to ``tgt_net``.

    Variables whose shapes match are copied as they are.  Variables whose
    shapes differ are first passed through ``add_channel(..., D=D)``; per the
    original note this is meant for the rgb-to-rgba case only.
    """
    source_names = src_net.trainables.keys()
    shared = [n for n in tgt_net.trainables.keys() if n in source_names]

    weights_to_copy = {}
    for var_name in shared:
        same_shape = (
            tgt_net.vars[var_name].shape == src_net.vars[var_name].shape)
        if same_shape:
            value = src_net.vars[var_name]
        else:
            # fixing rgb-to-rgba only !!
            value = add_channel(src_net.vars[var_name], D=D)
        target_var = tgt_net.vars[var_name]
        weights_to_copy[target_var] = value

    tfutil.set_vars(tfutil.run(weights_to_copy))
def copy_compatible_trainables_from(dst_net, src_net) -> None:
    """Copy into ``dst_net`` each trainable that ``src_net`` has with an equal shape.

    A line is printed for every trainable of ``dst_net`` saying whether it is
    restored, missing from the source, or skipped for a shape mismatch.
    """
    restorable = []
    for var_name in dst_net.trainables.keys():
        if var_name not in src_net.trainables:
            print("Not restoring (not present): {}".format(var_name))
            continue

        if dst_net.trainables[var_name].shape != src_net.trainables[
                var_name].shape:
            print("Not restoring (different shape): {}".format(var_name))
            continue

        # Explicit equality test retained (instead of a bare fall-through) so
        # the shape objects' own ``==`` still decides, as in the original.
        if dst_net.trainables[var_name].shape == src_net.trainables[
                var_name].shape:
            print("Restoring: {}".format(var_name))
            restorable.append(var_name)

    assignments = {}
    for var_name in restorable:
        target_var = dst_net.vars[var_name]
        assignments[target_var] = src_net.vars[var_name]
    tfutil.set_vars(tfutil.run(assignments))
def slow_blend_from_saved_weights(net, alpha=0.9, verbose=False):
    """Blend two saved weight sets into ``net``, one variable at a time.

    For every key of ``net.vars`` containing ``"G_synthesis"`` the value
    ``(1 - alpha) * weightsfirstnet[key] + alpha * weights[key]`` is written
    to the matching variable.  ``weightsfirstnet`` and ``weights`` are
    module-level mappings that are not passed in.  A failure on one key is
    printed and the loop continues with the next key.

    Returns:
        The same ``net`` object that was passed in.
    """
    if verbose:
        print("first net.vars", net.vars)

    for tensor_key in net.vars:
        if "G_synthesis" not in tensor_key:
            continue
        # Original note: for StyleGAN2 the ToRGB layers and both the plain
        # and "_up" variants seem to be needed, hence no further filtering.
        try:
            if verbose:
                print(tensor_key)
            from_first = weightsfirstnet[tensor_key]
            from_second = weights[tensor_key]
            mixed = (1.0 - alpha) * from_first + (alpha) * from_second
            mixed = np.copy(mixed)
            target_var = net.find_var(tensor_key)
            tfutil.set_vars({target_var: mixed})
        except Exception as e:
            print("--failed on tensor", tensor_key, "with:", e)

    return net
def main():
    """Blend two pickled generators and save the result plus a preview image.

    Reads its settings from the module-level ``a`` object (``out_dir``,
    ``pkl1``, ``pkl2``, ``res``, ``level``, ``blend_width``, ``verbose``).
    Layers are weighted as ``Gs_hi * y + Gs_lo * (1 - y)``, where ``y``
    depends on the layer's position relative to the pivot layer
    ``("{res}x{res}", level)``.  A missing or zero ``blend_width`` selects a
    hard switch instead of a sigmoid transition.

    Writes ``<out_name>.pkl`` with the blended network and ``<out_name>.jpg``
    with four generated samples stacked side by side.
    """
    os.makedirs(a.out_dir, exist_ok=True)

    tflib.init_tf()
    with tf.Session(), tf.device('/gpu:0'):
        Gs_lo = load_pkl(a.pkl1)
        Gs_hi = load_pkl(a.pkl2)

        # TODO add small x offset for smoother blend animations
        pivot_key = ("{}x{}".format(a.res, a.res), a.level)

        model_1_names = extract_conv_names(Gs_lo)
        model_2_names = extract_conv_names(Gs_hi)
        assert all((x == y for x, y in zip(model_1_names, model_2_names)))

        Gs_out = Gs_lo.clone()

        short_names = [entry[1:3] for entry in model_1_names]
        full_names = [entry[0] for entry in model_1_names]
        mid_point_idx = short_names.index(pivot_key)
        mid_point_pos = model_1_names[mid_point_idx][3]

        ys = []
        for name, _resolution, _level, position in model_1_names:
            x = position - mid_point_pos
            if a.blend_width:
                try:
                    y = 1 / (1 + math.exp(-x / a.blend_width))
                except OverflowError:
                    # exp() overflowed, so the sigmoid has saturated at 0.
                    y = 0.0
            else:
                # None or zero width: hard switch, no division by zero.
                y = 1 if x > 1 else 0
            ys.append(y)
            if a.verbose is True:
                print("Blending {} by {}".format(name, y))

        tfutil.set_vars(
            tfutil.run({
                Gs_out.vars[name]: (Gs_hi.vars[name] * y +
                                    Gs_lo.vars[name] * (1 - y))
                for name, y in zip(full_names, ys)
            }))

        # Output name: first '-'-separated token of each pkl's basename,
        # followed by the resolution and level.
        out_name = os.path.join(
            a.out_dir,
            '%s-%s-%d-%d' % (basename(a.pkl1).split('-')[0],
                             basename(a.pkl2).split('-')[0], a.res, a.level))
        save_pkl(Gs_out, '%s.pkl' % out_name)

        # Fixed seed so the preview latents are the same on every run.
        rnd = np.random.RandomState(696)
        grid_latents = rnd.randn(4, *Gs_lo.input_shape[1:])
        grid_fakes = Gs_out.run(grid_latents, [None],
                                is_validation=True,
                                minibatch_size=1)
        # Move axis 1 to the end, then join the four samples horizontally.
        grid_fakes = np.hstack(np.transpose(grid_fakes, [0, 2, 3, 1]))
        # NOTE(review): the scaling assumes outputs lie in [-1, 1] -- confirm.
        imsave('%s.jpg' % out_name, ((grid_fakes + 1) * 127.5).astype(np.uint8))
def add_networks(dst_net, src_net):
    """Add ``src_net``'s trainables onto the matching trainables of ``dst_net``.

    Only names present in both networks with equal shapes are summed; the
    others are reported on stdout and left unchanged.  ``dst_net`` is modified
    in place and returned.
    """
    summable = []
    for var_name in dst_net.trainables.keys():
        if var_name not in src_net.trainables:
            print('Not restoring (not present): {}'.format(var_name))
            continue

        if dst_net.trainables[var_name].shape != src_net.trainables[
                var_name].shape:
            print('Not restoring (different shape): {}'.format(var_name))

        # Independent equality test, mirroring the original's separate `if`.
        if dst_net.trainables[var_name].shape == src_net.trainables[
                var_name].shape:
            summable.append(var_name)

    sums = {}
    for var_name in summable:
        target_var = dst_net.vars[var_name]
        sums[target_var] = dst_net.vars[var_name] + src_net.vars[var_name]

    tfutil.set_vars(tfutil.run(sums))
    return dst_net
def copy_and_crop_trainables_from(target_net, source_net) -> None:
    """Copy trainables from ``source_net`` to ``target_net``, cropping on mismatch.

    Trainables are paired by iteration order rather than by name.  When a
    pair's shapes differ, a window of the target's length is taken from the
    source: along axis 0 for names containing ``'Dense'``, otherwise along
    axes 2 and 3 wherever the lengths differ.  The window starts at half the
    absolute length difference.  Cropped values are written individually with
    ``set_var``; every other pair is copied in one batched call.
    """

    def window(src_len, tgt_len):
        # Slice of length tgt_len offset by half the length difference.
        offset = abs(src_len - tgt_len) // 2
        return slice(offset, offset + tgt_len)

    pairs = list(zip(source_net.trainables.keys(),
                     target_net.trainables.keys()))
    cropped = []
    for source_name, target_name in pairs:
        source_value = source_net.get_var(source_name)
        target_value = target_net.get_var(target_name)
        source_shape = source_value.shape
        target_shape = target_value.shape
        if source_shape != target_shape:
            update = source_value
            if 'Dense' in source_name:
                update = update[window(source_shape[0], target_shape[0]), :]
            else:
                if source_shape[2] != target_shape[2]:
                    span = window(source_shape[2], target_shape[2])
                    update = update[:, :, span, :]
                if source_shape[3] != target_shape[3]:
                    span = window(source_shape[3], target_shape[3])
                    update = update[:, :, :, span]
            target_net.set_var(target_name, update)
            cropped.append(source_name)

    weights_to_copy = {}
    for source_name, target_name in pairs:
        if source_name in cropped:
            continue
        target_var = target_net.vars[target_name]
        weights_to_copy[target_var] = source_net.vars[source_name]
    tfutil.set_vars(tfutil.run(weights_to_copy))
def copy_and_fill_trainables(src_net, tgt_net) -> None:  # model => conditional
    """Copy shared trainables from ``src_net`` to ``tgt_net``, truncating or tiling.

    Only names present in both networks are considered.  When the shapes of a
    pair differ (they must have the same rank):

    * if the target is smaller on any axis, the source is truncated to the
      target's extent on the first two axes (first axis only for rank-1
      variables), e.g. ``[1024, 512] => [512, 512]``;
    * otherwise, if the target is larger on any axis, the source is repeated
      with ``np.tile`` using ``tgt_shape[i] // src_shape[i]`` per axis, e.g.
      ``[512, 512] => [1024, 512]``.

    Resized values are written individually with ``set_var``; the remaining
    variables are copied in one batched call.  Reads ``a.verbose`` from module
    scope to decide whether to print tiling details.

    Raises:
        AssertionError: If a mismatching pair has different ranks.
    """
    train_vars = [
        name for name in src_net.trainables.keys()
        if name in tgt_net.trainables.keys()
    ]
    resized = set()
    pbar = ProgressBar(len(train_vars))
    for name in train_vars:
        x = src_net.get_var(name)
        y = tgt_net.get_var(name)
        src_shape = x.shape
        tgt_shape = y.shape
        if src_shape != tgt_shape:
            assert len(src_shape) == len(
                tgt_shape), "Different shapes: %s %s" % (str(src_shape),
                                                         str(tgt_shape))
            if np.less(tgt_shape, src_shape).any():
                # kill labels: [1024,512] => [512,512]
                if len(tgt_shape) > 1:
                    # !!! corrects only first two dims
                    update = x[:tgt_shape[0], :tgt_shape[1], ...]
                else:
                    # Rank-1 variable: there is no second axis to truncate.
                    update = x[:tgt_shape[0]]
            elif np.greater(tgt_shape, src_shape).any():
                # add labels: [512,512] => [1024,512]
                tile_count = [
                    tgt_shape[i] // src_shape[i]
                    for i in range(len(src_shape))
                ]
                if a.verbose is True:
                    # G_mapping/Dense0, D/Output
                    print(name, tile_count, src_shape, '=>', tgt_shape, '\n\n')
                update = np.tile(x, tile_count)
            tgt_net.set_var(name, update)
            resized.add(name)
        pbar.upd(name)

    weights_to_copy = {
        tgt_net.vars[name]: src_net.vars[name]
        for name in train_vars
        if name not in resized
    }
    tfutil.set_vars(tfutil.run(weights_to_copy))
def blend_layers(Net_lo, Net_hi, type='G'):
    """Blend the layers of two networks around a pivot layer.

    Reads ``a.res``, ``a.level``, ``a.blend_width`` and ``a.verbose`` from the
    module-level ``a`` object.  Each layer listed by
    ``extract_conv_names(net, type)`` is weighted as
    ``Net_hi * y + Net_lo * (1 - y)``, where ``y`` depends on the layer's
    position relative to the pivot ``("{res}x{res}", level)``.  With a
    ``blend_width`` the factor follows a sigmoid; with ``None`` or zero it is
    a hard switch (``1`` for layers more than one position past the pivot,
    otherwise ``0``).

    Args:
        Net_lo: Network whose clone receives the blended weights.
        Net_hi: Network blended in as ``y`` grows.
        type: Passed through to ``extract_conv_names`` and printed.

    Returns:
        The clone of ``Net_lo`` carrying the blended weights.

    Raises:
        AssertionError: If the two name lists differ over their common
            (zipped) length.
        ValueError: If the pivot layer is not found (from ``list.index``).
    """
    print(' blending', type)
    pivot_key = ("{}x{}".format(a.res, a.res), a.level)

    model_1_names = extract_conv_names(Net_lo, type)
    model_2_names = extract_conv_names(Net_hi, type)
    assert all((x == y for x, y in zip(model_1_names, model_2_names)))

    Net_out = Net_lo.clone()

    short_names = [entry[1:3] for entry in model_1_names]
    full_names = [entry[0] for entry in model_1_names]
    mid_point_idx = short_names.index(pivot_key)
    mid_point_pos = model_1_names[mid_point_idx][3]
    print(' boundary ::', mid_point_idx, mid_point_pos,
          model_1_names[mid_point_idx])

    ys = []
    for name, _resolution, _level, position in model_1_names:
        # add small x offset for smoother blend animations ?
        x = position - mid_point_pos
        if a.blend_width:
            try:
                y = 1 / (1 + math.exp(-x / a.blend_width))
            except OverflowError:
                # exp() overflowed, so the sigmoid has saturated at 0.
                y = 0.0
        else:
            # None or zero width: hard switch, no division by zero.
            y = 1 if x > 1 else 0
        ys.append(y)
        if a.verbose and y > 0:
            print(" .. {} *{}".format(name, y))

    tfutil.set_vars(
        tfutil.run({
            Net_out.vars[name]: (Net_hi.vars[name] * y +
                                 Net_lo.vars[name] * (1 - y))
            for name, y in zip(full_names, ys)
        }))

    return Net_out
def __setstate__(self, state: dict) -> None: """Pickle import.""" # pylint: disable=attribute-defined-outside-init tfutil.assert_tf_initialized() self._init_fields() # Execute custom import handlers. for handler in _import_handlers: state = handler(state) # Set basic fields. assert state["version"] in [2, 3] self.name = state["name"] self.static_kwargs = util.EasyDict(state["static_kwargs"]) self.components = util.EasyDict(state.get("components", {})) self._build_module_src = state["build_module_src"] self._build_func_name = state["build_func_name"] # Create temporary module from the imported source code. module_name = "_tflib_network_import_" + uuid.uuid4().hex module = types.ModuleType(module_name) sys.modules[module_name] = module _import_module_src[module] = self._build_module_src exec(self._build_module_src, module.__dict__) # pylint: disable=exec-used # Locate network build function in the temporary module. self._build_func = util.get_obj_from_module(module, self._build_func_name) assert callable(self._build_func) # Init TensorFlow graph. self._init_graph() self.reset_own_vars() tfutil.set_vars( {self.find_var(name): value for name, value in state["variables"]})
def set_var(self, var_or_local_name: Union[TfExpression, str],
            new_value: Union[int, float, np.ndarray]) -> None:
    """Assign ``new_value`` to a single variable of this network.

    The variable is resolved with ``self.find_var`` and written through
    ``tfutil.set_vars``.

    Note: setting variables one by one is very inefficient -- prefer a single
    tflib.set_vars() call covering all variables whenever possible.
    """
    target_var = self.find_var(var_or_local_name)
    assignment = {target_var: new_value}
    tfutil.set_vars(assignment)