Example 1
    def reconstruct_images(self,
                           save_file,
                           sess,
                           x,
                           block_shape,
                           show_original_images=True,
                           batch_size=20,
                           dec_output_2_img_func=None,
                           **kwargs):
        if batch_size < 0:
            x1 = self.reconstruct(sess, x, **kwargs)
        else:
            x1 = []
            for batch_ids in iterate_data(len(x), batch_size, shuffle=False):
                x1.append(self.reconstruct(sess, x[batch_ids], **kwargs))

            x1 = np.concatenate(x1, axis=0)

        if dec_output_2_img_func is not None:
            x1 = dec_output_2_img_func(x1)
            x = dec_output_2_img_func(x)

        x1 = np.reshape(x1, to_list(block_shape) + self.x_shape)
        x = np.reshape(x, to_list(block_shape) + self.x_shape)

        if show_original_images:
            save_img_blocks_col_by_col(save_file, [x, x1])
        else:
            save_img_block(save_file, x1)
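A minimal usage sketch for reconstruct_images (not part of the original example): 'model', 'sess', and 'x_test' are assumed to be an already-built model, an active tf.Session, and a batch of test images of shape (N,) + model.x_shape; the uint8 converter is a stand-in for whatever helper the project actually uses.

import numpy as np

# Stand-in converter from [0, 1] floats to uint8 pixels (hypothetical helper).
def to_uint8(a):
    return (np.clip(a, 0.0, 1.0) * 255.0).astype(np.uint8)

# 'model', 'sess' and 'x_test' are assumed to exist already.
model.reconstruct_images("recx_test.png",
                         sess,
                         x_test,
                         block_shape=[1, len(x_test)],  # product must equal len(x_test)
                         show_original_images=True,
                         batch_size=20,
                         dec_output_2_img_func=to_uint8)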
Example 2
    def generate_images(self,
                        save_file,
                        sess,
                        z,
                        block_shape,
                        batch_size=20,
                        dec_output_2_img_func=None,
                        **kwargs):
        if batch_size < 0:
            x1_gen = self.decode(sess, z, **kwargs)
        else:
            x1_gen = []
            for batch_ids in iterate_data(len(z), batch_size, shuffle=False):
                x1_gen.append(self.decode(sess, z[batch_ids], **kwargs))

            x1_gen = np.concatenate(x1_gen, axis=0)

        if dec_output_2_img_func is not None:
            x1_gen = dec_output_2_img_func(x1_gen)

        x1_gen = np.reshape(x1_gen, to_list(block_shape) + self.x_shape)
        save_img_block(save_file, x1_gen)
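A hedged usage sketch for generate_images: latent codes are drawn from a standard normal prior, assuming 'model' and 'sess' exist and that model.z_shape is a flat [z_dim]; the converter is again a stand-in.

import numpy as np

rs = np.random.RandomState(0)
# 4 x 5 grid of latent codes sampled from the standard normal prior.
z = rs.randn(20, model.z_shape[0]).astype(np.float32)

model.generate_images("gen.png",
                      sess,
                      z,
                      block_shape=[4, 5],  # 4 * 5 == len(z)
                      batch_size=10,
                      dec_output_2_img_func=lambda a: (np.clip(a, 0, 1) * 255).astype(np.uint8))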
Example 3
def main(args):
    # =====================================
    # Load config
    # =====================================
    with open(join(args.output_dir, 'config.json')) as f:
        config = json.load(f)
    args.__dict__.update(config)

    # =====================================
    # Dataset
    # =====================================
    celebA_loader = TFCelebALoader(root_dir=args.celebA_root_dir)

    img_height, img_width = args.celebA_resize_size, args.celebA_resize_size
    celebA_loader.build_transformation_flow_tf(
        *celebA_loader.get_transform_fns("1Konny",
                                         resize_size=args.celebA_resize_size))

    # =====================================
    # Instantiate model
    # =====================================
    if args.activation == "relu":
        activation = tf.nn.relu
    elif args.activation == "leaky_relu":
        activation = tf.nn.leaky_relu
    else:
        raise ValueError("Do not support '{}' activation!".format(
            args.activation))

    if args.enc_dec_model == "1Konny":
        assert args.z_dim == 65, "For 1Konny, z_dim must be 65. Found {}!".format(
            args.z_dim)

        encoder = Encoder_1Konny(args.z_dim,
                                 stochastic=True,
                                 activation=activation)
        decoder = Decoder_1Konny([img_height, img_width, 3],
                                 activation=activation,
                                 output_activation=tf.nn.sigmoid)
        disc_z = DiscriminatorZ_1Konny(num_outputs=2)
    else:
        raise ValueError("Do not support encoder/decoder model '{}'!".format(
            args.enc_dec_model))

    model = FactorVAE([img_height, img_width, 3],
                      args.z_dim,
                      encoder=encoder,
                      decoder=decoder,
                      discriminator_z=disc_z,
                      rec_x_mode=args.rec_x_mode,
                      use_gp0_z_tc=True,
                      gp0_z_tc_mode=args.gp0_z_tc_mode)

    loss_coeff_dict = {
        'rec_x': args.rec_x_coeff,
        'kld_loss': args.kld_loss_coeff,
        'tc_loss': args.tc_loss_coeff,
        'gp0_z_tc': args.gp0_z_tc_coeff,
    }

    model.build(loss_coeff_dict)
    SimpleParamPrinter.print_all_params_tf_slim()

    # =====================================
    # Load model
    # =====================================
    config_proto = tf.ConfigProto(allow_soft_placement=True)
    config_proto.gpu_options.allow_growth = True
    config_proto.gpu_options.per_process_gpu_memory_fraction = 0.9
    sess = tf.Session(config=config_proto)

    model_dir = make_dir_if_not_exist(join(args.output_dir, "model_tf"))
    train_helper = SimpleTrainHelper(log_dir=None, save_dir=model_dir)

    # Load model
    train_helper.load(sess, load_step=args.load_step)

    # =====================================
    # Experiments
    # =====================================

    # Reconstruct
    # ======================================= #
    seed = 341
    rs = np.random.RandomState(seed)
    ids = rs.choice(celebA_loader.num_test_data, size=15)

    x = celebA_loader.sample_images_from_dataset(sess, 'test', ids)

    save_dir = make_dir_if_not_exist(join(args.save_dir, args.run))

    img_file = join(save_dir, 'x_test.png')
    save_img_block(img_file, binary_float_to_uint8(np.expand_dims(x, axis=0)))

    img_file = join(save_dir, 'recx_test_1.png')
    model.reconstruct_images(img_file,
                             sess,
                             x,
                             block_shape=[1, len(ids)],
                             batch_size=-1,
                             show_original_images=False,
                             dec_output_2_img_func=binary_float_to_uint8)

    img_file = join(save_dir, 'recx_test_2.png')
    model.reconstruct_images(img_file,
                             sess,
                             x,
                             block_shape=[1, len(ids)],
                             batch_size=-1,
                             show_original_images=True,
                             dec_output_2_img_func=binary_float_to_uint8)
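The same script could also exercise generate_images from Example 2; a sketch continuing inside main() above, assuming the names defined there ('model', 'sess', 'rs', 'save_dir', 'args', 'join', 'binary_float_to_uint8') are in scope, with latent codes drawn from the prior.

    # Not part of the original script: decode random latent codes into a row of images.
    z = rs.randn(15, args.z_dim).astype(np.float32)
    img_file = join(save_dir, 'genx_test.png')
    model.generate_images(img_file,
                          sess,
                          z,
                          block_shape=[1, 15],
                          batch_size=-1,
                          dec_output_2_img_func=binary_float_to_uint8)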
Example 4
    def interpolate_images(self,
                           save_file,
                           sess,
                           x1,
                           x2,
                           num_itpl_points,
                           batch_on_row=True,
                           batch_size=20,
                           dec_output_2_img_func=None,
                           enc_kwargs={},
                           dec_kwargs={}):
        if batch_size < 0:
            z1 = self.encode(sess, x1, **enc_kwargs)
            z2 = self.encode(sess, x2, **enc_kwargs)
        else:
            z1, z2 = [], []
            for batch_ids in iterate_data(len(x1), batch_size, shuffle=False):
                z1.append(self.encode(sess, x1[batch_ids], **enc_kwargs))
                z2.append(self.encode(sess, x2[batch_ids], **enc_kwargs))

            z1 = np.concatenate(z1, axis=0)
            z2 = np.concatenate(z2, axis=0)

        z1_flat = np.ravel(z1)
        z2_flat = np.ravel(z2)

        zs_itpl = []
        for i in range(1, num_itpl_points + 1):
            zi_flat = z1_flat + (i * 1.0 /
                                 (num_itpl_points + 1)) * (z2_flat - z1_flat)
            zs_itpl.append(zi_flat)

        # (num_itpl_points, len(x1) * z_dim)
        zs_itpl = np.stack(zs_itpl, axis=0)
        # (num_itpl_points * len(x1),) + z_shape
        zs_itpl = np.reshape(zs_itpl,
                             [num_itpl_points * x1.shape[0]] + self.z_shape)

        if batch_size < 0:
            xs_itpl = self.decode(sess, zs_itpl, **dec_kwargs)
        else:
            xs_itpl = []
            for batch_ids in iterate_data(len(zs_itpl),
                                          batch_size,
                                          shuffle=False):
                xs_itpl.append(
                    self.decode(sess, zs_itpl[batch_ids], **dec_kwargs))

            xs_itpl = np.concatenate(xs_itpl, axis=0)

        # (num_itpl_points, len(x1)) + x_shape
        xs_itpl = np.reshape(xs_itpl,
                             [num_itpl_points, x1.shape[0]] + self.x_shape)
        # (num_itpl_points + 2, len(x1)) + x_shape after prepending x1 and appending x2
        xs_itpl = np.concatenate(
            [np.expand_dims(x1, axis=0), xs_itpl,
             np.expand_dims(x2, axis=0)],
            axis=0)

        if batch_on_row:
            xs_itpl = np.transpose(xs_itpl, [1, 0] +
                                   list(range(2,
                                              len(self.x_shape) + 2)))

        if dec_output_2_img_func is not None:
            xs_itpl = dec_output_2_img_func(xs_itpl)

        save_img_block(save_file, xs_itpl)
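A hedged usage sketch for interpolate_images, assuming 'model', 'sess', and two equally sized image batches 'xa' and 'xb' of shape (N,) + model.x_shape; the converter is a stand-in.

import numpy as np

model.interpolate_images("itpl.png",
                         sess,
                         xa,
                         xb,
                         num_itpl_points=8,
                         batch_on_row=True,  # one interpolation path per row
                         batch_size=20,
                         dec_output_2_img_func=lambda a: (np.clip(a, 0, 1) * 255).astype(np.uint8))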
Example 5
    def plot_Z_itpl_bw_2Xs(self,
                           save_file_prefix,
                           sess,
                           imgs_1,
                           imgs_2,
                           img_names_1,
                           img_names_2,
                           features,
                           num_itpl_points=6,
                           yx_types=('feature', 'itpl_point'),
                           dec_output_2_img_func=None,
                           img_ext='png',
                           batch_size=-1):

        # img_ext
        # ---------------------------------------- #
        assert img_ext in ('png', 'jpg'), "'img_ext' must be 'png' or 'jpg'!"
        # ---------------------------------------- #

        # coordinate
        # ---------------------------------------- #
        # For this kind of interpolation, the results have 3 axes:
        # (num_inputs, num_features, num_itpl_points)

        # With the default ('feature', 'itpl_point') layout, we save 'num_inputs'
        # block images, each of shape (num_features, num_itpl_points).
        possible_coord_types = [('input', 'itpl_point'),
                                ('feature', 'itpl_point'),
                                ('itpl_point', 'input'),
                                ('itpl_point', 'feature'),
                                (None, 'itpl_point'),
                                ('itpl_point', None)]

        if isinstance(yx_types, tuple):
            assert len(yx_types) == 2, "'yx_types' must be a 2-tuple or " \
                                       "a list of 2-tuples representing the yx coordinate types!"
            yx_types = [yx_types]

        assert isinstance(yx_types, list), "'yx_types' must be a 2-tuple or " \
                                           "a list of 2-tuples representing the yx coordinate types!"

        assert all([yx_type in possible_coord_types for yx_type in yx_types]), \
            "Only support the following coordinate types: {}".format(possible_coord_types)
        # ---------------------------------------- #

        # num_images
        # ---------------------------------------- #
        assert isinstance(imgs_1, np.ndarray) and imgs_1.ndim == 4, \
            "'imgs_1' must be a 4D numpy array of format (num_images, height, width, channels)!"

        assert isinstance(imgs_2, np.ndarray) and imgs_2.ndim == 4, \
            "'imgs_2' must be a 4D numpy array of format (num_images, height, width, channels)!"

        assert len(imgs_1) == len(imgs_2), \
            "Number of images in 'imgs_1' and 'imgs_2' must be equal!"
        num_inputs = len(imgs_1)
        # ---------------------------------------- #

        # num_features
        # ---------------------------------------- #
        z_dim = int(np.prod(self.z_shape))
        if features == 'all':
            features = [i for i in range(z_dim)]

        if isinstance(features, int):
            assert 0 <= features < z_dim, "'features' must be an integer or " \
                "a list/tuple of integers in the range [0, {}]".format(z_dim - 1)
            features = [features]

        assert isinstance(features, (list, tuple)), "'features' must be an integer or " \
            "a list/tuple of integers in the range [0, {}]".format(z_dim - 1)

        num_features = len(features)
        # ---------------------------------------- #

        # (num_images, z_dim)
        z1 = np.reshape(self.encode(sess, imgs_1), [num_inputs, z_dim])
        z2 = np.reshape(self.encode(sess, imgs_2), [num_inputs, z_dim])

        # list of (num_inputs * num_features * num_itpl_points) arrays, each of shape (z_dim,)
        z_samples = []

        for n in range(len(imgs_1)):
            for feature in features:
                # (num_itpl_points, )
                itpl_points = np.linspace(z1[n, feature],
                                          z2[n, feature],
                                          num=num_itpl_points,
                                          endpoint=True)

                for itpl_point in itpl_points:
                    z_copy = np.array(z1[n], dtype=z1.dtype, copy=True)
                    z_copy[feature] = itpl_point

                    z_samples.append(z_copy)

        # (num_inputs * num_features * num_itpl_points, z_dim)
        z_samples = np.stack(z_samples, axis=0)

        if batch_size < 0:
            z_samples = np.reshape(
                z_samples,
                [num_inputs * num_features * num_itpl_points] + self.z_shape)
            x_samples = self.decode(sess, z_samples)
        else:
            x_samples = []
            for batch_ids in iterate_data(len(z_samples),
                                          batch_size,
                                          shuffle=False):
                x_samples.append(
                    self.decode(
                        sess,
                        np.reshape(z_samples[batch_ids],
                                   [len(batch_ids)] + self.z_shape)))
            x_samples = np.concatenate(x_samples, axis=0)

        # (num_images, num_features, num_itpl_points) + x_shape
        x_samples = np.reshape(x_samples,
                               [num_inputs, num_features, num_itpl_points] +
                               self.x_shape)
        if dec_output_2_img_func is not None:
            x_samples = dec_output_2_img_func(x_samples)

        for yx_type in yx_types:
            if yx_type == ('feature', 'itpl_point'):
                x_itpl = x_samples
                assert img_names_1, "'img_names_1' must be provided!"
                assert img_names_2, "'img_names_2' must be provided!"
                save_file_postfixes = [
                    "-img[{}-{}].{}".format(img_names_1[i], img_names_2[i],
                                            img_ext)
                    for i in range(len(x_itpl))
                ]

            elif yx_type == ('itpl_point', 'feature'):
                x_itpl = np.transpose(x_samples, [0, 2, 1] +
                                      list(range(3, 3 + len(self.x_shape))))
                assert img_names_1, "'img_names_1' must be provided!"
                assert img_names_2, "'img_names_2' must be provided!"
                save_file_postfixes = [
                    "-img[{}-{}].{}".format(img_names_1[i], img_names_2[i],
                                            img_ext)
                    for i in range(len(x_itpl))
                ]

            elif yx_type == ('input', 'itpl_point'):
                x_itpl = np.transpose(x_samples, [1, 0, 2] +
                                      list(range(3, 3 + len(self.x_shape))))
                save_file_postfixes = [
                    "-feat[{}].{}".format(feature, img_ext)
                    for feature in features
                ]

            elif yx_type == ('itpl_point', 'input'):
                x_itpl = np.transpose(x_samples, [1, 2, 0] +
                                      list(range(3, 3 + len(self.x_shape))))
                save_file_postfixes = [
                    "-feat[{}].{}".format(feature, img_ext)
                    for feature in features
                ]

            elif yx_type == (None, 'itpl_point'):
                # (num_inputs * num_features, 1, num_itpl_points) + x_shape
                x_itpl = np.reshape(
                    x_samples,
                    [num_inputs * num_features, 1, num_itpl_points] +
                    self.x_shape)
                save_file_postfixes = [
                    "-img[{}-{}]_feat[{}].{}".format(img_name_1, img_name_2,
                                                     feature, img_ext)
                    for img_name_1, img_name_2 in zip(img_names_1, img_names_2)
                    for feature in features
                ]

            elif yx_type == ('itpl_point', None):
                # (num_inputs * num_features, num_itpl_points, 1) + x_shape
                x_itpl = np.reshape(
                    x_samples,
                    [num_inputs * num_features, num_itpl_points, 1] +
                    self.x_shape)
                save_file_postfixes = [
                    "-img[{}-{}]_feat[{}].{}".format(img_name_1, img_name_2,
                                                     feature, img_ext)
                    for img_name_1, img_name_2 in zip(img_names_1, img_names_2)
                    for feature in features
                ]

            else:
                raise ValueError(
                    "Only support the following coordinate types: {}".format(
                        possible_coord_types))

            for i in range(len(x_itpl)):
                save_file = save_file_prefix + save_file_postfixes[i]
                save_img_block(save_file, x_itpl[i])
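A hedged usage sketch for plot_Z_itpl_bw_2Xs, assuming 'model', 'sess', two 4D image batches 'imgs_a' and 'imgs_b' of format (num_images, height, width, channels) with matching name lists, and that the listed feature indices are valid for the model's z_dim.

import numpy as np

model.plot_Z_itpl_bw_2Xs("itpl_feat",
                         sess,
                         imgs_a,
                         imgs_b,
                         img_names_1=["a0", "a1"],
                         img_names_2=["b0", "b1"],
                         features=[3, 7, 12],  # latent components to interpolate
                         num_itpl_points=6,
                         yx_types=[('feature', 'itpl_point')],
                         dec_output_2_img_func=lambda a: (np.clip(a, 0, 1) * 255).astype(np.uint8),
                         img_ext='png',
                         batch_size=-1)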
Example 6
    def cond_all_latents_traverse_v2(
            self,
            save_file,
            sess,
            x,
            z_comps=None,
            z_comp_labels=None,
            span=2,
            points_1_side=6,
            # substitute with original x and highlight
            hl_x=True,
            hl_color="red",
            hl_width=1,
            font_size=12,
            title="",
            title_font_scale=1.5,
            subplot_adjust={},
            size_inches=None,
            batch_size=20,
            dec_output_2_img_func=None,
            enc_kwargs={},
            dec_kwargs={}):

        assert np.shape(x) == tuple(
            self.x_shape), "'x' must be a single instance!"
        # (1, x_dim)
        x_ = np.expand_dims(x, axis=0)

        # Compute z
        # ----------------------------- #
        # (1, z_dim)
        z = self.encode(sess, x_, **enc_kwargs)
        assert z.shape[0] == 1

        # (z_dim, )
        z_dim = int(np.prod(self.z_shape))
        z = np.reshape(z, [z_dim])
        # ----------------------------- #

        if z_comps is None:
            z_comps = list(range(z_dim))

        z_meshgrid = []
        inserted_ids = []
        s = span
        p = points_1_side

        for i, comp in enumerate(z_comps):
            # (2 * points_1_side + 1, ) values spanning [z[comp] - span, z[comp] + span]
            itpl_vals = [(z[comp] - s) + 1.0 * j * s / p for j in range(p)]
            itpl_vals += [z[comp]]
            itpl_vals += [z[comp] + 1.0 * j * s / p for j in range(1, p + 1)]

            for val in itpl_vals:
                z_copy = np.array(z, dtype=z.dtype, copy=True)
                z_copy[comp] = val
                z_meshgrid.append(z_copy)

            inserted_ids.append((i, points_1_side))

        # Compute z meshgrid
        # ----------------------------- #
        num_rows = len(z_comps)
        num_cols = 2 * points_1_side + 1
        assert len(z_meshgrid) == num_rows * num_cols

        z_meshgrid = np.reshape(z_meshgrid,
                                [num_rows * num_cols] + self.z_shape)
        # ----------------------------- #

        # Reconstruct x meshgrid
        # ----------------------------- #
        if batch_size < 0:
            x_meshgrid = self.decode(sess, z_meshgrid, **dec_kwargs)
        else:
            x_meshgrid = []
            for batch_ids in iterate_data(len(z_meshgrid),
                                          batch_size,
                                          shuffle=False):
                x_meshgrid.append(
                    self.decode(sess, z_meshgrid[batch_ids], **dec_kwargs))

            x_meshgrid = np.concatenate(x_meshgrid, axis=0)

        x_meshgrid = np.reshape(x_meshgrid,
                                [num_rows, num_cols] + self.x_shape)

        if hl_x:
            for row_idx, col_idx in inserted_ids:
                x_meshgrid[row_idx, col_idx] = x

        if dec_output_2_img_func is not None:
            x_meshgrid = dec_output_2_img_func(x_meshgrid)

        if z_comp_labels is not None:
            assert len(z_comp_labels) == len(z_comps), \
                "Length of 'z_comp_labels' must be equal to the number of z components " \
                "you want to draw. Found {} and {}, respectively!".format(len(z_comp_labels), len(z_comps))

            if hl_x:
                save_img_block_highlighted_with_ticklabels(
                    save_file,
                    x_meshgrid,
                    hl_blocks=inserted_ids,
                    hl_color=hl_color,
                    hl_width=hl_width,
                    x_tick_labels=None,
                    y_tick_labels=z_comp_labels,
                    font_size=font_size,
                    title=title,
                    title_font_scale=title_font_scale,
                    subplot_adjust=subplot_adjust,
                    size_inches=size_inches)
            else:
                save_img_block_with_ticklabels(
                    save_file,
                    x_meshgrid,
                    x_tick_labels=None,
                    y_tick_labels=z_comp_labels,
                    font_size=font_size,
                    title=title,
                    title_font_scale=title_font_scale,
                    subplot_adjust=subplot_adjust,
                    size_inches=size_inches)
        else:
            if hl_x:
                save_img_block_highlighted(save_file,
                                           x_meshgrid,
                                           hl_blocks=inserted_ids,
                                           hl_color=hl_color,
                                           hl_width=hl_width)
            else:
                save_img_block(save_file, x_meshgrid)
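A hedged usage sketch for cond_all_latents_traverse_v2, assuming 'model', 'sess', and a single image 'x0' with shape model.x_shape; the converter is a stand-in.

import numpy as np

model.cond_all_latents_traverse_v2("traverse.png",
                                   sess,
                                   x0,
                                   z_comps=[0, 1, 2, 3],  # one traversal row per latent component
                                   z_comp_labels=["z0", "z1", "z2", "z3"],
                                   span=2,
                                   points_1_side=6,
                                   hl_x=True,  # highlight the original x in each row
                                   batch_size=20,
                                   dec_output_2_img_func=lambda a: (np.clip(a, 0, 1) * 255).astype(np.uint8))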
Example 7
    def rand_2_latents_traverse(self,
                                save_file,
                                sess,
                                default_z,
                                z_comp1,
                                start1,
                                stop1,
                                num_points1,
                                z_comp2,
                                start2,
                                stop2,
                                num_points2,
                                batch_size=20,
                                dec_output_2_img_func=None,
                                **kwargs):
        """
        default_z: A single latent code to serve as default
        z_comp1: z component 1
        z_limits1: 2-tuple specifying the low-high value of z_comp1
        num_points1: Number of points
        z_comp2:
        z_limits2:
        num_points2:
        """
        assert num_points1 >= 2, "'num_points1' must be >=2. Found {}!".format(
            num_points1)
        assert num_points2 >= 2, "'num_points2' must be >=2. Found {}!".format(
            num_points2)

        z_range1 = [
            start1 + (stop1 - start1) * i * 1.0 / (num_points1 - 1)
            for i in range(num_points1)
        ]
        z_range2 = [
            start2 + (stop2 - start2) * i * 1.0 / (num_points2 - 1)
            for i in range(num_points2)
        ]

        num_rows = len(z_range1)
        num_cols = len(z_range2)

        assert np.shape(default_z) == tuple(
            self.z_shape), "'default_z' must be a single instance!"

        default_z = np.reshape(default_z, [int(np.prod(self.z_shape))])
        z_meshgrid = np.tile(np.expand_dims(default_z, axis=0),
                             [num_rows * num_cols, 1])

        for m in range(num_rows):
            for n in range(num_cols):
                z_meshgrid[m * num_cols + n, z_comp1] = z_range1[m]
                z_meshgrid[m * num_cols + n, z_comp2] = z_range2[n]

        # Reconstruct x meshgrid
        # ----------------------------- #
        if batch_size < 0:
            x_meshgrid = self.decode(sess, z_meshgrid, **kwargs)
        else:
            x_meshgrid = []
            for batch_ids in iterate_data(len(z_meshgrid),
                                          batch_size,
                                          shuffle=False):
                x_meshgrid.append(
                    self.decode(sess, z_meshgrid[batch_ids], **kwargs))

            x_meshgrid = np.concatenate(x_meshgrid, axis=0)

        x_meshgrid = np.reshape(x_meshgrid,
                                [num_rows, num_cols] + self.x_shape)

        if dec_output_2_img_func is not None:
            x_meshgrid = dec_output_2_img_func(x_meshgrid)

        save_img_block(save_file, x_meshgrid)
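A hedged usage sketch for rand_2_latents_traverse, assuming 'model' and 'sess'; the default latent code is drawn from the standard normal prior, and the converter is a stand-in.

import numpy as np

default_z = np.random.RandomState(0).randn(*model.z_shape).astype(np.float32)

model.rand_2_latents_traverse("traverse_2d.png",
                              sess,
                              default_z,
                              z_comp1=3, start1=-2.0, stop1=2.0, num_points1=7,
                              z_comp2=8, start2=-2.0, stop2=2.0, num_points2=7,
                              batch_size=20,
                              dec_output_2_img_func=lambda a: (np.clip(a, 0, 1) * 255).astype(np.uint8))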