Example #1
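    # Encode the masked input, draw z from the inferred posterior, then fill the
    # masked pixels autoregressively with the decoder's predictions.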
    def sample(self, data, mgen):
        if self.num_channels == 3:
            data = np.cast[np.float32]((data - 127.5) / 127.5)
        ori_data = data.copy()
        ds = np.split(data.copy(), self.nr_gpu)
        feed_dict = {}
        feed_dict.update({model.is_training: False for model in self.models})
        feed_dict.update({model.dropout_p: 0.0 for model in self.models})
        feed_dict.update({model.x: ds[i] for i, model in enumerate(self.models)})
        feed_dict.update({model.x_bar: ds[i] for i, model in enumerate(self.models)})

        masks_np = [mgen.gen(self.batch_size//self.nr_gpu) for i in range(self.nr_gpu)]
        feed_dict.update({model.masks: masks_np[i] for i, model in enumerate(self.models)})
        feed_dict.update({model.input_masks: masks_np[i] for i, model in enumerate(self.models)})

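        # Run the models on all GPUs; the first half of ret holds z_mu, the
        # second half z_log_sigma_sq, from which z is sampled below.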
        ret = self.sess.run([model.z_mu for model in self.models]+[model.z_log_sigma_sq for model in self.models], feed_dict=feed_dict)
        z_mu = np.concatenate(ret[:len(ret)//2], axis=0)
        z_log_sigma_sq = np.concatenate(ret[len(ret)//2:], axis=0)
        z_sigma = np.sqrt(np.exp(z_log_sigma_sq))
        z = np.random.normal(loc=z_mu, scale=z_sigma)
        z = np.split(z, self.nr_gpu)
        feed_dict.update({model.z_ph:z[i] for i, model in enumerate(self.models)})
        feed_dict.update({model.use_z_ph: True for model in self.models})

        for i in range(self.nr_gpu):
            ds[i] *= broadcast_masks_np(masks_np[i], num_channels=self.num_channels)
        masked_data = np.concatenate(ds, axis=0)
        x_gen = [ds[i].copy() for i in range(self.nr_gpu)]
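        # Raster-scan over pixel positions: locations that are observed (mask > 0)
        # on every GPU are skipped; masked locations are overwritten with the
        # model's prediction x_hat while observed ones are kept.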
        for yi in range(self.img_size):
            for xi in range(self.img_size):
                if np.min(np.array([masks_np[i][:, yi, xi] for i in range(self.nr_gpu)])) > 0:
                    continue
                feed_dict.update({model.x_bar:x_gen[i] for i, model in enumerate(self.models)})
                x_hats = self.sess.run([model.x_hat for model in self.models], feed_dict=feed_dict)
                for i in range(self.nr_gpu):
                    bmask = broadcast_masks_np(masks_np[i][:, yi, xi], num_channels=self.num_channels)
                    x_gen[i][:, yi, xi, :] = x_hats[i][:, yi, xi, :] * (1.-bmask) + x_gen[i][:, yi, xi, :] * bmask
        gen_data = np.concatenate(x_gen, axis=0)
        return ori_data, masked_data, gen_data

    def _examine_posterior(self, image, mgen, masks_np=None):
        data = np.array([image for i in range(self.batch_size)])
        if self.num_channels == 3:
            data = np.cast[np.float32]((data - 127.5) / 127.5)
        ori_data = data.copy()
        ds = np.split(data.copy(), self.nr_gpu)
        feed_dict = {}
        feed_dict.update({model.is_training: False for model in self.models})
        feed_dict.update({model.dropout_p: 0.0 for model in self.models})
        feed_dict.update(
            {model.x: ds[i]
             for i, model in enumerate(self.models)})
        feed_dict.update(
            {model.x_bar: ds[i]
             for i, model in enumerate(self.models)})
        if masks_np is None:
            masks_np = [
                mgen.gen(self.batch_size // self.nr_gpu)
                for i in range(self.nr_gpu)
            ]
        if self.phase == 'pvae':
            feed_dict.update({
                model.masks: np.zeros_like(masks_np[i])
                for i, model in enumerate(self.models)
            })
        elif self.phase == 'ce':
            feed_dict.update({
                model.masks: masks_np[i]
                for i, model in enumerate(self.models)
            })
        feed_dict.update({
            model.input_masks: masks_np[i]
            for i, model in enumerate(self.models)
        })

        ret = self.sess.run([model.z_mu for model in self.models] +
                            [model.z_log_sigma_sq for model in self.models],
                            feed_dict=feed_dict)
        z_mu = np.concatenate(ret[:len(ret) // 2], axis=0)
        z_log_sigma_sq = np.concatenate(ret[len(ret) // 2:], axis=0)
        z_sigma = np.sqrt(np.exp(z_log_sigma_sq))
        z = np.random.normal(loc=z_mu, scale=z_sigma)
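        # Note: z is sampled here but not used further; only the posterior
        # statistics and the masked input are returned.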

        for i in range(self.nr_gpu):
            ds[i] *= broadcast_masks_np(masks_np[i],
                                        num_channels=self.num_channels)
        masked_data = np.concatenate(ds, axis=0)

        return z_mu, z_sigma, ori_data, masked_data, masks_np

    def sample(self, data, mgen, same_inputs=False, use_mask_at=None):
        if self.num_channels == 3:
            data = np.cast[np.float32]((data - 127.5) / 127.5)
        if same_inputs:
            for i in range(data.shape[0]):
                data[i] = data[3]
        ori_data = data.copy()
        ds = np.split(data.copy(), self.nr_gpu)
        feed_dict = {}
        feed_dict.update({model.is_training: False for model in self.models})
        feed_dict.update({model.dropout_p: 0.0 for model in self.models})
        feed_dict.update(
            {model.x: ds[i]
             for i, model in enumerate(self.models)})
        feed_dict.update(
            {model.x_bar: ds[i]
             for i, model in enumerate(self.models)})

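        # Either reuse masks from a previously saved .npz file, or generate new
        # ones and save them so the same masks can be reapplied later.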
        if use_mask_at is not None:
            masks_np = np.load(use_mask_at)['masks']
            masks_np = np.split(masks_np, self.nr_gpu)
        else:
            masks_np = [
                mgen.gen(self.batch_size // self.nr_gpu)
                for i in range(self.nr_gpu)
            ]
            np.savez(mgen.name + "_" + self.data_set,
                     masks=np.concatenate(masks_np))

        if same_inputs:
            for g in range(self.nr_gpu):
                for i in range(self.batch_size // self.nr_gpu):
                    masks_np[g][i] = masks_np[0][0]
        feed_dict.update(
            {model.masks: masks_np[i]
             for i, model in enumerate(self.models)})
        #
        for i in range(self.nr_gpu):
            ds[i] *= broadcast_masks_np(masks_np[i],
                                        num_channels=self.num_channels)
        masked_data = np.concatenate(ds, axis=0)
        x_gen = [ds[i].copy() for i in range(self.nr_gpu)]
        for yi in range(self.img_size):
            for xi in range(self.img_size):
                if np.min(
                        np.array([
                            masks_np[i][:, yi, xi] for i in range(self.nr_gpu)
                        ])) > 0:
                    continue
                feed_dict.update({
                    model.x_bar: x_gen[i]
                    for i, model in enumerate(self.models)
                })
                x_hats = self.sess.run([model.x_hat for model in self.models],
                                       feed_dict=feed_dict)
                for i in range(self.nr_gpu):
                    bmask = broadcast_masks_np(masks_np[i][:, yi, xi],
                                               num_channels=self.num_channels)
                    x_gen[i][:, yi, xi, :] = x_hats[i][:, yi, xi, :] * (
                        1. - bmask) + x_gen[i][:, yi, xi, :] * bmask
        gen_data = np.concatenate(x_gen, axis=0)
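        # For single-channel data, lift the hidden (zeroed-out) pixels to mid-gray
        # so the masked region is visible in the visualisation.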
        if self.num_channels == 1:
            masks_np = np.concatenate(masks_np, axis=0)
            masks_np = broadcast_masks_np(masks_np,
                                          num_channels=self.num_channels)
            masked_data += (1 - masks_np) * 0.5
        return ori_data, masked_data, gen_data

    def latent_traversal(self,
                         context,
                         traversal_range=[-6, 6],
                         num_traversal_step=13,
                         mgen=None):
        self.num_traversal_step = num_traversal_step
        if self.num_channels == 3:
            image = np.cast[np.float32]((context - 127.5) / 127.5)
        else:
            image = context
        num_instances = num_traversal_step * self.z_dim
        assert num_instances <= self.batch_size, "cannot feed all the instances into GPUs"
        data = np.stack([image.copy() for i in range(self.batch_size)], axis=0)
        ori_data = data.copy()
        ds = np.split(data.copy(), self.nr_gpu)

        feed_dict = {}
        feed_dict.update({model.is_training: False for model in self.models})
        feed_dict.update({model.dropout_p: 0.0 for model in self.models})
        feed_dict.update(
            {model.x: ds[i]
             for i, model in enumerate(self.models)})

        masks_np = [
            mgen.gen(self.batch_size // self.nr_gpu)
            for i in range(self.nr_gpu)
        ]
        if self.phase == 'pvae':
            feed_dict.update({
                model.masks: np.zeros_like(masks_np[i])
                for i, model in enumerate(self.models)
            })
        elif self.phase == 'ce':
            feed_dict.update({
                model.masks: masks_np[i]
                for i, model in enumerate(self.models)
            })
        feed_dict.update({
            model.input_masks: masks_np[i]
            for i, model in enumerate(self.models)
        })

        ret = self.sess.run([model.z_mu for model in self.models] +
                            [model.z_log_sigma_sq for model in self.models],
                            feed_dict=feed_dict)
        z_mu = np.concatenate(ret[:len(ret) // 2], axis=0)
        z_log_sigma_sq = np.concatenate(ret[len(ret) // 2:], axis=0)
        z_sigma = np.sqrt(np.exp(z_log_sigma_sq))
        z = z_mu  #np.random.normal(loc=z_mu, scale=z_sigma)
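        # Copy the posterior mean of the first instance to every row, then sweep
        # one latent dimension per block of num_traversal_step rows across
        # traversal_range, keeping all other dimensions fixed.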
        for i in range(z.shape[0]):
            z[i] = z[0].copy()
        for i in range(self.z_dim):
            z[i * num_traversal_step:(i + 1) * num_traversal_step,
              i] = np.linspace(start=traversal_range[0],
                               stop=traversal_range[1],
                               num=num_traversal_step)
        z = np.split(z, self.nr_gpu)

        feed_dict.update(
            {model.z_ph: z[i]
             for i, model in enumerate(self.models)})
        feed_dict.update({model.use_z_ph: True for model in self.models})

        for i in range(self.nr_gpu):
            ds[i] *= broadcast_masks_np(masks_np[i],
                                        num_channels=self.num_channels)
        masked_data = np.concatenate(ds, axis=0)
        x_gen = [ds[i].copy() for i in range(self.nr_gpu)]
        for yi in range(self.img_size):
            for xi in range(self.img_size):
                if np.min(
                        np.array([
                            masks_np[i][:, yi, xi] for i in range(self.nr_gpu)
                        ])) > 0:
                    continue
                feed_dict.update({
                    model.x_bar: x_gen[i]
                    for i, model in enumerate(self.models)
                })
                x_hats = self.sess.run([model.x_hat for model in self.models],
                                       feed_dict=feed_dict)
                for i in range(self.nr_gpu):
                    bmask = broadcast_masks_np(masks_np[i][:, yi, xi],
                                               num_channels=self.num_channels)
                    x_gen[i][:, yi, xi, :] = x_hats[i][:, yi, xi, :] * (
                        1. - bmask) + x_gen[i][:, yi, xi, :] * bmask
        gen_data = np.concatenate(x_gen, axis=0)
        return (ori_data[:num_instances], masked_data[:num_instances],
                gen_data[:num_instances])

    def controlled_sample(self, data, mgen):
        if self.num_channels == 3:
            data = np.cast[np.float32]((data - 127.5) / 127.5)
        ori_data = data.copy()
        ds = np.split(data.copy(), self.nr_gpu)
        feed_dict = {}
        feed_dict.update({model.is_training: False for model in self.models})
        feed_dict.update({model.dropout_p: 0.0 for model in self.models})
        feed_dict.update(
            {model.x: ds[i]
             for i, model in enumerate(self.models)})
        feed_dict.update(
            {model.x_bar: ds[i]
             for i, model in enumerate(self.models)})

        masks_np = [
            mgen.gen(self.batch_size // self.nr_gpu)
            for i in range(self.nr_gpu)
        ]
        if self.phase == 'pvae':
            feed_dict.update({
                model.masks: np.zeros_like(masks_np[i])
                for i, model in enumerate(self.models)
            })
        elif self.phase == 'ce':
            feed_dict.update({
                model.masks: masks_np[i]
                for i, model in enumerate(self.models)
            })
        feed_dict.update({
            model.input_masks: masks_np[i]
            for i, model in enumerate(self.models)
        })

        ret = self.sess.run([model.z_mu for model in self.models] +
                            [model.z_log_sigma_sq for model in self.models],
                            feed_dict=feed_dict)
        z_mu = np.concatenate(ret[:len(ret) // 2], axis=0)
        z_log_sigma_sq = np.concatenate(ret[len(ret) // 2:], axis=0)
        z_sigma = np.sqrt(np.exp(z_log_sigma_sq))
        z = np.random.normal(loc=z_mu, scale=z_sigma)

        # for i in range(z.shape[0]):
        #     z[i] = z[0]
        # z += np.random.normal(loc=np.zeros_like(z_mu), scale=np.ones_like(z_sigma))

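        # Share one latent code within each group of four consecutive images
        # (this indexing assumes a batch size of 32 across all GPUs).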
        for i in range(8):
            for j in range(4):
                z[i * 4 + j] = z[i * 4]

        z = np.split(z, self.nr_gpu)
        feed_dict.update(
            {model.z_ph: z[i]
             for i, model in enumerate(self.models)})
        feed_dict.update({model.use_z_ph: True for model in self.models})

        for i in range(self.nr_gpu):
            ds[i] *= broadcast_masks_np(masks_np[i],
                                        num_channels=self.num_channels)
        masked_data = np.concatenate(ds, axis=0)
        x_gen = [ds[i].copy() for i in range(self.nr_gpu)]
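        # Decode every pixel in raster order; unlike sample(), no positions are
        # held fixed, so the entire image is re-generated.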
        for yi in range(self.img_size):
            for xi in range(self.img_size):
                feed_dict.update({
                    model.x_bar: x_gen[i]
                    for i, model in enumerate(self.models)
                })
                x_hats = self.sess.run([model.x_hat for model in self.models],
                                       feed_dict=feed_dict)
                for i in range(self.nr_gpu):
                    x_gen[i][:, yi, xi, :] = x_hats[i][:, yi, xi, :]
        gen_data = np.concatenate(x_gen, axis=0)
        return ori_data, masked_data, gen_data

    def sample(self, data, mgen, same_inputs=False, use_mask_at=None):
        if self.num_channels == 3:
            data = np.cast[np.float32]((data - 127.5) / 127.5)
        if same_inputs:
            for i in range(data.shape[0]):
                data[i] = data[3]
        ori_data = data.copy()
        ds = np.split(data.copy(), self.nr_gpu)
        feed_dict = {}
        feed_dict.update({model.is_training: False for model in self.models})
        feed_dict.update({model.dropout_p: 0.0 for model in self.models})
        feed_dict.update(
            {model.x: ds[i]
             for i, model in enumerate(self.models)})
        feed_dict.update(
            {model.x_bar: ds[i]
             for i, model in enumerate(self.models)})

        if use_mask_at is not None:
            masks_np = np.load(use_mask_at)['masks']
            masks_np = np.split(masks_np, self.nr_gpu)
        else:
            masks_np = [
                mgen.gen(self.batch_size // self.nr_gpu)
                for i in range(self.nr_gpu)
            ]
            np.savez(mgen.name + "_" + self.data_set,
                     masks=np.concatenate(masks_np))

        # masks_np = [mgen.gen(self.batch_size//self.nr_gpu) for i in range(self.nr_gpu)]

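        # The masks placeholder receives an all-zero mask in the 'pvae' phase and
        # the actual mask in the 'ce' phase; input_masks always gets the real mask.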
        if self.phase == 'pvae':
            feed_dict.update({
                model.masks: np.zeros_like(masks_np[i])
                for i, model in enumerate(self.models)
            })
        elif self.phase == 'ce':
            feed_dict.update({
                model.masks: masks_np[i]
                for i, model in enumerate(self.models)
            })
        feed_dict.update({
            model.input_masks: masks_np[i]
            for i, model in enumerate(self.models)
        })

        ret = self.sess.run([model.z_mu for model in self.models] +
                            [model.z_log_sigma_sq for model in self.models],
                            feed_dict=feed_dict)
        z_mu = np.concatenate(ret[:len(ret) // 2], axis=0)
        z_log_sigma_sq = np.concatenate(ret[len(ret) // 2:], axis=0)
        z_sigma = np.sqrt(np.exp(z_log_sigma_sq))
        z = np.random.normal(loc=z_mu, scale=z_sigma)
        z = np.split(z, self.nr_gpu)
        feed_dict.update(
            {model.z_ph: z[i]
             for i, model in enumerate(self.models)})
        feed_dict.update({model.use_z_ph: True for model in self.models})

        for i in range(self.nr_gpu):
            ds[i] *= broadcast_masks_np(masks_np[i],
                                        num_channels=self.num_channels)
        masked_data = np.concatenate(ds, axis=0)
        x_gen = [ds[i].copy() for i in range(self.nr_gpu)]
        for yi in range(self.img_size):
            for xi in range(self.img_size):
                if np.min(
                        np.array([
                            masks_np[i][:, yi, xi] for i in range(self.nr_gpu)
                        ])) > 0:
                    continue
                feed_dict.update({
                    model.x_bar: x_gen[i]
                    for i, model in enumerate(self.models)
                })
                x_hats = self.sess.run([model.x_hat for model in self.models],
                                       feed_dict=feed_dict)
                for i in range(self.nr_gpu):
                    bmask = broadcast_masks_np(masks_np[i][:, yi, xi],
                                               num_channels=self.num_channels)
                    x_gen[i][:, yi, xi, :] = x_hats[i][:, yi, xi, :] * (
                        1. - bmask) + x_gen[i][:, yi, xi, :] * bmask
        gen_data = np.concatenate(x_gen, axis=0)
        if self.num_channels == 1:
            masks_np = np.concatenate(masks_np, axis=0)
            masks_np = broadcast_masks_np(masks_np,
                                          num_channels=self.num_channels)
            masked_data += (1 - masks_np) * 0.5
        return ori_data, masked_data, gen_data
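    # Hypothetical usage sketch (names are illustrative, not from the original
    # example): with `sampler` an instance of this class holding a restored
    # session and `mgen` a mask generator, in-painting a batch could look like:
    #     ori, masked, gen = sampler.sample(next(test_data), mgen)
    #     visualize_samples(np.concatenate([ori, masked, gen], axis=0), "inpaint.png")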
Example #7
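# Restore a trained checkpoint, blank out part of each test image using a mask
# from the 'bottom quarter' generator, and run a latent traversal on one image.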
with tf.Session(config=config) as sess:

    sess.run(initializer)

    ckpt_file = args.save_dir + '/params_' + args.data_set + '.ckpt'
    print('restoring parameters from', ckpt_file)
    saver.restore(sess, ckpt_file)

    sample_mgen = get_generator('bottom quarter', args.img_size)
    fill_region = sample_mgen.gen(1)[0]
    # sample_mgen = get_generator('transparent', args.img_size)
    # fill_region = get_generator('full', args.img_size).gen(1)[0]
    data = next(test_data)

    from blocks.helpers import broadcast_masks_np
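    # Zero out the pixels where fill_region is 0; only the remaining region keeps
    # its original content before the traversal is run.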
    data = data.astype(np.float32) * broadcast_masks_np(fill_region, 3)

    test_data.reset()
    # vdata = np.cast[np.float32]((data - 127.5) / 127.5)
    # visualize_samples(vdata, "/data/ziz/jxu/gpu-results/show_original.png", layout=[8,8])

    img = []
    for i in [7]:  #[5,7,8]: #[5, 7, 8, 18, 27, 44, 74, 77]:
        sample_x = latent_traversal(sess,
                                    data[i],
                                    traversal_range=[-6, 6],
                                    num_traversal_step=13,
                                    fill_region=fill_region,
                                    mgen=sample_mgen)
        view = visualize_samples(sample_x,
                                 None,