Esempio n. 1
0
	def blend_models(low_res_wrapper, high_res_wrapper, resolution, level, blend=0):
		"""Blend the ``Gs`` networks of two wrappers conv layer by conv layer.

		Every conv variable listed by ``ModelBlender.extract_conv_names`` for the
		low-res model gets a weight ``y`` in [0, 1] and the output value is
		``high * y + low * (1 - y)``.

		``y`` depends on the layer's position relative to the mid-point layer
		(``f'{resolution}x{resolution}'``, ``level``):

		* non-zero ``blend``: logistic curve ``1 / (1 + exp(-x / blend))``;
		* ``blend == 0`` (the default): hard switch. Only layers more than one
		  position past the mid point come from the high-res model. Previously
		  the default divided by zero.

		Args:
			low_res_wrapper: wrapper whose ``Gs`` provides the low-res weights and
				whose ``_G``/``_D`` are reused in the result.
			high_res_wrapper: wrapper whose ``Gs`` provides the high-res weights.
			resolution: value formatted as ``'{resolution}x{resolution}'`` to
				locate the mid point.
			level: second key of the mid-point entry as produced by
				``extract_conv_names`` (assumed to be a layer name -- confirm).
			blend: width of the logistic transition in layer positions.

		Returns:
			A new ``StyleGanWrapper`` holding the blended ``Gs``.

		Raises:
			ValueError: if the mid point is not among the extracted layers.
		"""
		low_res_model = low_res_wrapper.Gs
		high_res_model = high_res_wrapper.Gs

		result_model = low_res_model.clone()

		mid_key = (f'{resolution}x{resolution}', level)

		low_res_names = ModelBlender.extract_conv_names(low_res_model)

		short_names = [x[1:3] for x in low_res_names]
		full_names = [x[0] for x in low_res_names]
		mid_point_idx = short_names.index(mid_key)
		mid_point_pos = low_res_names[mid_point_idx][3]

		ys = []
		for _name, _res, _lvl, position in low_res_names:
			x = position - mid_point_pos
			if blend:
				exponent = -x / blend
				# Numerically stable logistic: math.exp raises OverflowError for
				# arguments above ~709, which small blend widths easily reach.
				if exponent >= 0:
					z = math.exp(-exponent)
					y = z / (1 + z)
				else:
					y = 1 / (1 + math.exp(exponent))
			else:
				# Hard switch instead of dividing by zero.
				y = 1 if x > 1 else 0
			ys.append(y)

		tfutil.set_vars(
			tfutil.run({
				result_model.vars[name]: high_res_model.vars[name] * y + low_res_model.vars[name] * (1 - y)
				for name, y in zip(full_names, ys)
			})
		)

		return StyleGanWrapper().set_network((low_res_wrapper._G, low_res_wrapper._D, result_model))
Esempio n. 2
0
def copy_weights(src_net, tgt_net, vars_to_copy):
    """Copy selected trainable variables from ``src_net`` into ``tgt_net``.

    A variable is transferred when its name is trainable in ``tgt_net`` and is
    contained in ``vars_to_copy``. All values are fetched and assigned in a
    single batched ``tfutil.run`` / ``tfutil.set_vars`` round trip.
    """
    selected = (key for key in tgt_net.trainables.keys() if key in vars_to_copy)
    assignments = {tgt_net.vars[key]: src_net.vars[key] for key in selected}
    tfutil.set_vars(tfutil.run(assignments))
Esempio n. 3
0
def copy_and_crop_or_pad_trainables(src_net, tgt_net) -> None:
    """Copy trainables from ``src_net`` to ``tgt_net``, pairing them by position.

    Variables are matched in the iteration order of the two ``trainables``
    dicts, not by name. When a pair's shapes differ, the source value is
    adapted and written with ``tgt_net.set_var``:

    * names containing ``'Dense'``: center-crop axis 0 when the source is
      larger, otherwise pad with ``pad_symm_np`` to the full target shape;
    * all other names (assumed 4-D conv kernels -- confirm): center-crop
      axes 2 and 3 where the source is larger, then pad with ``pad_symm_np``
      to ``target_shape[2:]`` if either of those axes is smaller.

    Pairs with equal shapes are copied in one batched ``tfutil`` call at the
    end. If a non-Dense pair lacks axes 2/3, the pair is reported and the
    process exits with status 1.
    """
    pairs = list(zip(src_net.trainables.keys(), tgt_net.trainables.keys()))

    reshaped = set()  # source names already written individually via set_var
    pbar = ProgressBar(len(pairs))
    for pair in pairs:
        source_name, target_name = pair
        update = src_net.get_var(source_name)
        source_shape = update.shape
        target_shape = tgt_net.get_var(target_name).shape
        if source_shape != target_shape:
            if 'Dense' in source_name:
                if source_shape[0] > target_shape[0]:
                    start = (source_shape[0] - target_shape[0]) // 2
                    update = update[start:start + target_shape[0], :]
                else:
                    update = pad_symm_np(update, target_shape)
            else:
                try:
                    if source_shape[2] > target_shape[2]:
                        start = (source_shape[2] - target_shape[2]) // 2
                        update = update[:, :, start:start + target_shape[2], :]
                    if source_shape[3] > target_shape[3]:
                        start = (source_shape[3] - target_shape[3]) // 2
                        update = update[:, :, :, start:start + target_shape[3]]
                except IndexError:
                    # Fewer than 4 dims: the positional pairing presumably
                    # matched unrelated variables, so stop instead of guessing.
                    print(' Wrong var pair?', source_name, source_shape,
                          target_name, target_shape)
                    # Same exit status as before, without the site-only
                    # ``exit()`` builtin (which also closes stdin).
                    raise SystemExit(1)

                if (source_shape[2] < target_shape[2]
                        or source_shape[3] < target_shape[3]):
                    update = pad_symm_np(update, target_shape[2:])

            tgt_net.set_var(target_name, update)
            reshaped.add(source_name)
        pbar.upd(pair)

    weights_to_copy = {
        tgt_net.vars[target_name]: src_net.vars[source_name]
        for source_name, target_name in pairs
        if source_name not in reshaped
    }
    tfutil.set_vars(tfutil.run(weights_to_copy))
Esempio n. 4
0
 def copy_own_vars_from(self, src_net: "Network") -> None:
     """Copy the values of own variables shared with ``src_net``, excluding sub-networks.

     Variables are matched by local name. Names that ``src_net`` does not own
     are skipped. Values are fetched and assigned in one batched call.
     """
     shared = [key for key in self.own_vars.keys() if key in src_net.own_vars]
     values = tfutil.run({self.vars[key]: src_net.vars[key] for key in shared})
     tfutil.set_vars(values)
Esempio n. 5
0
def apply_denominator(dst_net, denominator):
    """Scale every trainable variable of ``dst_net`` by ``1 / denominator`` in place.

    The reciprocal is computed once, before any variable is touched. For a
    Python number, a zero denominator therefore raises ``ZeroDivisionError``
    with the network unchanged.

    Returns:
        ``dst_net``, for chaining.
    """
    scale = 1.0 / denominator
    scaled = {
        dst_net.vars[key]: dst_net.vars[key] * scale
        for key in dst_net.trainables.keys()
    }
    tfutil.set_vars(tfutil.run(scaled))
    return dst_net
Esempio n. 6
0
def blend_models(model_1,
                 model_2,
                 resolution_h,
                 resolution_w,
                 res_log2,
                 min_h,
                 min_w,
                 level,
                 blend_width=None,
                 verbose=False):
    """Return a copy of ``model_1`` whose conv layers are blended towards ``model_2``.

    For every conv variable listed by ``extract_conv_names``, the output value
    is ``model_2 * y + model_1 * (1 - y)``: ``y = 0`` keeps ``model_1`` and
    ``y = 1`` takes ``model_2``. ``y`` depends on the layer's position relative
    to the mid-point layer (``f"{resolution_h}x{resolution_w}"``, ``level``):

    * truthy ``blend_width``: logistic curve ``1 / (1 + exp(-x / blend_width))``;
    * otherwise: hard switch, where only layers more than one position past
      the mid point take ``model_2``.

    Args:
        model_1, model_2: networks whose extracted conv names must agree.
        resolution_h, resolution_w: resolution of the mid-point layer.
        res_log2, min_h, min_w: forwarded to ``extract_conv_names``.
        level: second key of the mid-point entry as produced by
            ``extract_conv_names``.
        blend_width: width of the logistic transition; ``None``/0 => hard switch.
        verbose: print the blend weight of every layer.

    Returns:
        The blended clone of ``model_1``.

    Raises:
        AssertionError: if the extracted layer lists disagree (compared
            pairwise up to the shorter list).
        ValueError: if the mid point is not among the extracted layers.
    """
    # TODO add small x offset for smoother blend animations
    mid_key = (f"{resolution_h}x{resolution_w}", level)

    model_1_names = extract_conv_names(model_1, res_log2, min_h, min_w)
    model_2_names = extract_conv_names(model_2, res_log2, min_h, min_w)

    assert all((x == y for x, y in zip(model_1_names, model_2_names)))

    model_out = model_1.clone()

    short_names = [x[1:3] for x in model_1_names]
    full_names = [x[0] for x in model_1_names]
    mid_point_idx = short_names.index(mid_key)
    mid_point_pos = model_1_names[mid_point_idx][3]

    ys = []
    # Loop variables are renamed so they do not shadow the parameters.
    for name, _res, _lvl, position in model_1_names:
        # low to high (res)
        x = position - mid_point_pos
        if blend_width:
            exponent = -x / blend_width
            # Numerically stable logistic: math.exp raises OverflowError for
            # arguments above ~709, which small blend widths easily reach.
            if exponent >= 0:
                z = math.exp(-exponent)
                y = z / (1 + z)
            else:
                y = 1 / (1 + math.exp(exponent))
        else:
            y = 1 if x > 1 else 0

        ys.append(y)
        if verbose:
            print(f"Blending {name} by {y}")

    tfutil.set_vars(
        tfutil.run({
            model_out.vars[name]:
            (model_2.vars[name] * y + model_1.vars[name] * (1 - y))
            for name, y in zip(full_names, ys)
        }))

    return model_out
Esempio n. 7
0
def copy_vars(src_net, tgt_net, D=False):
    """Copy trainables that exist (by name) in both networks from ``src_net`` to ``tgt_net``.

    A value whose shape differs from the target variable is first passed
    through ``add_channel(value, D=D)``. This is intended only for the
    RGB -> RGBA conversion.
    """
    def _converted(key):
        # Source value, adapted to the target shape when it differs.
        value = src_net.vars[key]
        if tgt_net.vars[key].shape == value.shape:
            return value
        # fixing rgb-to-rgba only !!
        return add_channel(value, D=D)

    shared = [
        key for key in tgt_net.trainables.keys()
        if key in src_net.trainables.keys()
    ]
    converted = {key: _converted(key) for key in shared}
    tfutil.set_vars(
        tfutil.run({tgt_net.vars[key]: converted[key] for key in shared}))
def copy_compatible_trainables_from(dst_net, src_net) -> None:
    """Copy every trainable of ``src_net`` whose name and shape match one in ``dst_net``.

    Sub-networks are included. Each trainable of ``dst_net`` is reported as
    restored, missing in ``src_net``, or shape-mismatched. The compatible ones
    are then copied in one batched ``tfutil`` call.
    """
    restorable = []
    for key in dst_net.trainables.keys():
        if key not in src_net.trainables:
            print("Not restoring (not present):     {}".format(key))
            continue
        if dst_net.trainables[key].shape != src_net.trainables[key].shape:
            print("Not restoring (different shape): {}".format(key))
            continue
        print("Restoring: {}".format(key))
        restorable.append(key)

    tfutil.set_vars(
        tfutil.run({dst_net.vars[key]: src_net.vars[key] for key in restorable}))
Esempio n. 9
0
def slow_blend_from_saved_weights(net, alpha=0.9, verbose=False,
                                  first_weights=None, second_weights=None):
    """Overwrite every ``G_synthesis`` variable of ``net`` with a blend of two weight sets.

    For each variable name containing ``"G_synthesis"`` the new value is
    ``(1 - alpha) * first[name] + alpha * second[name]``. It is written with
    one ``tfutil.set_vars`` call per variable, hence "slow".

    Args:
        net: network whose variables are overwritten in place.
        alpha: weight of the second set (0 => first only, 1 => second only).
        verbose: print ``net.vars`` and each processed key.
        first_weights, second_weights: mappings from variable name to value.
            When omitted, the module-level globals ``weightsfirstnet`` and
            ``weights`` are used, as before.

    Returns:
        ``net``.

    A failure on an individual tensor (missing key, missing global, failed
    assignment, ...) is printed and that tensor is skipped. The remaining
    tensors are still processed.
    """
    if verbose:
        print("first net.vars", net.vars)

    for tensor_key in net.vars:
        # So far it seems StyleGAN2 needs ToRGB as well as the layers both with
        # and without _up, so every G_synthesis variable is blended.
        if "G_synthesis" not in tensor_key:
            continue
        try:
            if verbose:
                print(tensor_key)

            # The globals are resolved lazily, only when no explicit weights
            # are given, so a missing global is reported per tensor as before.
            first_net_weights = (weightsfirstnet if first_weights is None
                                 else first_weights)[tensor_key]
            second_net_weights = (weights if second_weights is None
                                  else second_weights)[tensor_key]

            blended_weights = ((1.0 - alpha) * first_net_weights
                               + alpha * second_net_weights)
            blended_weights = np.copy(blended_weights)

            v = net.find_var(tensor_key)
            tfutil.set_vars({v: blended_weights})
        except Exception as e:  # best effort: report and move on to the next tensor
            print("--failed on tensor", tensor_key, "with:", e)

    return net
Esempio n. 10
0
def main():
    """Blend two StyleGAN generators loaded from pickles and save the result.

    Settings come from the module-level ``a``: ``pkl1``, ``pkl2``, ``res``,
    ``level``, ``blend_width``, ``verbose`` and ``out_dir``.

    Conv layers of ``Gs_lo`` (``a.pkl1``) are mixed towards ``Gs_hi``
    (``a.pkl2``) around the layer (``"{res}x{res}"``, ``level``). The output
    value is ``hi * y + lo * (1 - y)``, where ``y`` is a logistic curve of
    width ``a.blend_width``, or a hard switch when it is ``None``.

    The blended network is pickled to
    ``<out_dir>/<pkl1 prefix>-<pkl2 prefix>-<res>-<level>.pkl``. A preview
    ``.jpg`` with 4 samples side by side is saved next to it.
    """
    os.makedirs(a.out_dir, exist_ok=True)

    tflib.init_tf()
    with tf.Session(), tf.device('/gpu:0'):
        Gs_lo = load_pkl(a.pkl1)
        Gs_hi = load_pkl(a.pkl2)

        # TODO add small x offset for smoother blend animations
        mid_key = ("{}x{}".format(a.res, a.res), a.level)

        model_1_names = extract_conv_names(Gs_lo)
        model_2_names = extract_conv_names(Gs_hi)
        assert all((x == y for x, y in zip(model_1_names, model_2_names)))

        Gs_out = Gs_lo.clone()

        short_names = [x[1:3] for x in model_1_names]
        full_names = [x[0] for x in model_1_names]
        mid_point_idx = short_names.index(mid_key)
        mid_point_pos = model_1_names[mid_point_idx][3]

        ys = []
        for name, _res, _lvl, position in model_1_names:
            x = position - mid_point_pos
            if a.blend_width is not None:
                exponent = -x / a.blend_width
                # Numerically stable logistic: math.exp raises OverflowError
                # for arguments above ~709 (small blend widths).
                if exponent >= 0:
                    z = math.exp(-exponent)
                    y = z / (1 + z)
                else:
                    y = 1 / (1 + math.exp(exponent))
            else:
                y = 1 if x > 1 else 0
            ys.append(y)
            if a.verbose is True:
                print("Blending {} by {}".format(name, y))

        tfutil.set_vars(tfutil.run({
            Gs_out.vars[name]: (Gs_hi.vars[name] * y + Gs_lo.vars[name] * (1 - y))
            for name, y in zip(full_names, ys)}))

        out_name = os.path.join(a.out_dir, '%s-%s-%d-%d' % (
            basename(a.pkl1).split('-')[0], basename(a.pkl2).split('-')[0],
            a.res, a.level))
        save_pkl(Gs_out, '%s.pkl' % out_name)

        rnd = np.random.RandomState(696)
        grid_latents = rnd.randn(4, *Gs_lo.input_shape[1:])
        grid_fakes = Gs_out.run(grid_latents, [None], is_validation=True, minibatch_size=1)
        # NCHW -> NHWC, then place the samples side by side.
        grid_fakes = np.hstack(np.transpose(grid_fakes, [0, 2, 3, 1]))
        # Generator output can overshoot [-1, 1]. Clip before the uint8 cast so
        # out-of-range pixels saturate instead of wrapping around.
        pixels = np.clip((grid_fakes + 1) * 127.5, 0, 255).astype(np.uint8)
        imsave('%s.jpg' % out_name, pixels)
Esempio n. 11
0
def add_networks(dst_net, src_net):
    """Add ``src_net``'s trainables element-wise into ``dst_net``, in place.

    Only variables present in both networks with identical shapes are summed.
    The others are reported and left unchanged.

    Returns:
        ``dst_net``.
    """
    summable = []
    for key in dst_net.trainables.keys():
        if key not in src_net.trainables:
            print('Not restoring (not present):     {}'.format(key))
        elif dst_net.trainables[key].shape != src_net.trainables[key].shape:
            print('Not restoring (different shape): {}'.format(key))
        else:
            summable.append(key)

    sums = {
        dst_net.vars[key]: dst_net.vars[key] + src_net.vars[key]
        for key in summable
    }
    tfutil.set_vars(tfutil.run(sums))
    return dst_net
def copy_and_crop_trainables_from(target_net, source_net) -> None:
    """Copy trainables from ``source_net`` into ``target_net``, center-cropping oversized ones.

    Variables are paired by position in the two ``trainables`` dicts, not by
    name. When a pair's shapes differ, the source value is center-cropped and
    written with ``target_net.set_var``:

    * names containing ``'Dense'``: along axis 0;
    * all other names: along whichever of axes 2 and 3 differ (assumed 4-D
      conv kernels -- confirm).

    Pairs with equal shapes are copied in one batched ``tfutil`` call at the
    end.

    Raises:
        ValueError: if an axis to be cropped is smaller in the source than in
            the target. Cropping cannot grow an array; previously a too-short
            slice was produced and the failure surfaced only on assignment.
    """
    def _center_crop(value, axis, size, name):
        # Return the central ``size`` entries of ``value`` along ``axis``.
        have = value.shape[axis]
        if have < size:
            raise ValueError(
                "Cannot crop {} from {} to {} along axis {}".format(
                    name, have, size, axis))
        start = (have - size) // 2
        index = [slice(None)] * value.ndim
        index[axis] = slice(start, start + size)
        return value[tuple(index)]

    pairs = list(zip(source_net.trainables.keys(), target_net.trainables.keys()))

    cropped = set()  # source names already written individually via set_var
    for source_name, target_name in pairs:
        update = source_net.get_var(source_name)
        source_shape = update.shape
        target_shape = target_net.get_var(target_name).shape
        if source_shape == target_shape:
            continue
        if 'Dense' in source_name:
            update = _center_crop(update, 0, target_shape[0], source_name)
        else:
            for axis in (2, 3):
                if source_shape[axis] != target_shape[axis]:
                    update = _center_crop(update, axis, target_shape[axis],
                                          source_name)
        target_net.set_var(target_name, update)
        cropped.add(source_name)

    weights_to_copy = {
        target_net.vars[target_name]: source_net.vars[source_name]
        for source_name, target_name in pairs
        if source_name not in cropped
    }
    tfutil.set_vars(tfutil.run(weights_to_copy))
Esempio n. 13
0
def copy_and_fill_trainables(src_net, tgt_net) -> None:  # model => conditional
    """Copy trainables present in both networks (matched by name), adapting shapes.

    For a variable whose shape differs (ranks must match; this is asserted):

    * if any target dimension is smaller, the source is sliced down to the
      target in its first two dimensions (only the first for 1-D values).
      Larger trailing dimensions are not handled.
    * otherwise, some target dimension is larger and the source is tiled
      ``tgt_dim // src_dim`` times per axis. This assumes each target
      dimension is a whole multiple of the source one -- TODO confirm, or the
      tiled shape will not match.

    Adapted values are written individually with ``tgt_net.set_var``. The rest
    are copied in one batched ``tfutil`` call at the end.
    """
    train_vars = [
        name for name in src_net.trainables.keys()
        if name in tgt_net.trainables.keys()
    ]
    reshaped = set()  # names already written individually via set_var
    pbar = ProgressBar(len(train_vars))
    for name in train_vars:
        x = src_net.get_var(name)
        y = tgt_net.get_var(name)
        src_shape = x.shape
        tgt_shape = y.shape
        if src_shape != tgt_shape:
            assert len(src_shape) == len(
                tgt_shape), "Different shapes: %s %s" % (str(src_shape),
                                                         str(tgt_shape))
            if np.less(
                    tgt_shape,
                    src_shape).any():  # kill labels: [1024,512] => [512,512]
                try:
                    update = x[:tgt_shape[0], :tgt_shape[1],
                               ...]  # !!! corrects only first two dims
                except IndexError:
                    # 1-D value: only the first dimension can be sliced.
                    update = x[:tgt_shape[0]]
            elif np.greater(
                    tgt_shape,
                    src_shape).any():  # add labels: [512,512] => [1024,512]
                tile_count = [
                    tgt_shape[i] // src_shape[i] for i in range(len(src_shape))
                ]
                if a.verbose is True:
                    print(name, tile_count, src_shape, '=>', tgt_shape,
                          '\n\n')  # G_mapping/Dense0, D/Output
                update = np.tile(x, tile_count)
            tgt_net.set_var(name, update)
            reshaped.add(name)
        pbar.upd(name)
    weights_to_copy = {
        tgt_net.vars[name]: src_net.vars[name]
        for name in train_vars if name not in reshaped
    }
    tfutil.set_vars(tfutil.run(weights_to_copy))
Esempio n. 14
0
def blend_layers(Net_lo, Net_hi, type='G'):
    """Return a copy of ``Net_lo`` whose conv layers are blended towards ``Net_hi``.

    Settings come from the module-level ``a``. The mid point is the layer
    (``"{a.res}x{a.res}"``, ``a.level``), and ``a.blend_width`` sets the width
    of the logistic transition (``None`` => hard switch: only layers more than
    one position past the mid point take ``Net_hi``). ``type`` is forwarded
    to ``extract_conv_names``.

    Each output variable is ``Net_hi * y + Net_lo * (1 - y)``.

    Raises:
        ValueError: if the mid point is not among the extracted layers.
    """
    print(' blending', type)
    mid_key = ("{}x{}".format(a.res, a.res), a.level)

    model_1_names = extract_conv_names(Net_lo, type)
    model_2_names = extract_conv_names(Net_hi, type)
    assert all((x == y for x, y in zip(model_1_names, model_2_names)))

    Net_out = Net_lo.clone()

    short_names = [x[1:3] for x in model_1_names]
    full_names = [x[0] for x in model_1_names]
    mid_point_idx = short_names.index(mid_key)
    mid_point_pos = model_1_names[mid_point_idx][3]
    print(' boundary ::', mid_point_idx, mid_point_pos,
          model_1_names[mid_point_idx])

    ys = []
    for name, _res, _lvl, position in model_1_names:
        # add small x offset for smoother blend animations ?
        x = position - mid_point_pos
        if a.blend_width is not None:
            exponent = -x / a.blend_width
            # Numerically stable logistic: math.exp raises OverflowError for
            # arguments above ~709, which small blend widths easily reach.
            if exponent >= 0:
                z = math.exp(-exponent)
                y = z / (1 + z)
            else:
                y = 1 / (1 + math.exp(exponent))
        else:
            y = 1 if x > 1 else 0
        ys.append(y)
        if a.verbose and y > 0:
            print(" .. {} *{}".format(name, y))

    tfutil.set_vars(
        tfutil.run({
            Net_out.vars[name]:
            (Net_hi.vars[name] * y + Net_lo.vars[name] * (1 - y))
            for name, y in zip(full_names, ys)
        }))
    return Net_out
Esempio n. 15
0
    def __setstate__(self, state: dict) -> None:
        """Pickle import: rebuild this network from its pickled state dict.

        The steps below run in this order:

        1. Pass ``state`` through every registered import handler.
        2. Restore the basic fields.
        3. ``exec`` the pickled build-module source into a fresh temporary module.
        4. Look up the build function in that module.
        5. Build the graph and assign the pickled variable values.

        Args:
            state: pickled state. ``version`` must be 2 or 3 (asserted).
                Required keys are ``name``, ``static_kwargs``,
                ``build_module_src``, ``build_func_name`` and ``variables``
                (an iterable of ``(name, value)`` pairs); ``components`` is
                optional.

        Security: ``build_module_src`` comes from the pickle and is run with
        ``exec``. Loading a network therefore executes arbitrary code -- only
        unpickle networks from trusted sources.
        """
        # pylint: disable=attribute-defined-outside-init
        tfutil.assert_tf_initialized()
        self._init_fields()

        # Execute custom import handlers.
        # Each handler returns a (possibly rewritten) state dict; presumably
        # used to migrate older pickle formats -- the handlers are defined elsewhere.
        for handler in _import_handlers:
            state = handler(state)

        # Set basic fields.
        assert state["version"] in [2, 3]
        self.name = state["name"]
        self.static_kwargs = util.EasyDict(state["static_kwargs"])
        self.components = util.EasyDict(state.get("components", {}))
        self._build_module_src = state["build_module_src"]
        self._build_func_name = state["build_func_name"]

        # Create temporary module from the imported source code.
        # The uuid suffix keeps the module name unique. The module is registered
        # in sys.modules and its source is recorded in _import_module_src.
        module_name = "_tflib_network_import_" + uuid.uuid4().hex
        module = types.ModuleType(module_name)
        sys.modules[module_name] = module
        _import_module_src[module] = self._build_module_src
        # SECURITY: executes source code taken from the pickle.
        exec(self._build_module_src, module.__dict__)  # pylint: disable=exec-used

        # Locate network build function in the temporary module.
        self._build_func = util.get_obj_from_module(module,
                                                    self._build_func_name)
        assert callable(self._build_func)

        # Init TensorFlow graph.
        self._init_graph()
        # Reset own variables first, then overwrite them with the pickled values.
        # Variables missing from the pickle presumably keep their reset values.
        self.reset_own_vars()
        tfutil.set_vars(
            {self.find_var(name): value
             for name, value in state["variables"]})
Esempio n. 16
0
 def set_var(self, var_or_local_name: Union[TfExpression, str],
             new_value: Union[int, float, np.ndarray]) -> None:
     """Assign ``new_value`` to a single variable, resolved through ``find_var``.

     Note: This method is very inefficient -- prefer to use tflib.set_vars() whenever possible."""
     target = self.find_var(var_or_local_name)
     tfutil.set_vars({target: new_value})