Example #1
def interactive_interp(sess, dcgan, config, sampling='uniform'):
    while True:
        z_samples = dcgan.z_sampler(config)
        has_labels = False
        try:
            if dcgan.has_labels:
                has_labels = True
                label = int(raw_input('Class label for first sample: '))
                sample_labels = np.eye(dcgan.y_dim)[np.full(dcgan.batch_size, label)]
        except Exception:
            pass
        gauss_filter = gauss_kernel_fixed(config.gauss_sigma, config.gauss_trunc)
        if has_labels:
            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_samples, dcgan.y: sample_labels,
                                                         dcgan.gauss_kernel: gauss_filter})
        else:
            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_samples, dcgan.gauss_kernel: gauss_filter})
        samples = np.array([((sample + 1) / 2 * 255).astype(np.uint8) for sample in samples])
        grid_size = int(math.ceil(math.sqrt(dcgan.batch_size)))
        scipy.misc.imshow(merge(samples, (grid_size, grid_size)))
        start = int(raw_input('First sample number: '))
        if has_labels:
            label2 = raw_input('Class label for second sample [same]: ')
            if label2 == '':
                label2 = label
                same = True
            else:
                label2 = int(label2)
                same = False
            sample_labels2 = np.eye(dcgan.y_dim)[np.full(dcgan.batch_size, label2)]
            if same:
                samples2 = samples
            else:
                samples2 = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_samples, dcgan.y: sample_labels2,
                                    dcgan.gauss_kernel: gauss_filter})
            scipy.misc.imshow(merge(samples2, (grid_size, grid_size)))
        stop = int(raw_input('Second sample number: '))
        n_steps = raw_input('Number of steps [62]: ')
        if n_steps == '':
            n_steps = 62
        else:
            n_steps = int(n_steps)
        if has_labels:
            series = interpolate(sess, dcgan, z_start=z_samples[start - 1], z_stop=z_samples[stop - 1],
                                 n_steps=n_steps, y_start=label, y_stop=label2, transform=True)
        else:
            series = interpolate(sess, dcgan, z_start=z_samples[start-1], z_stop=z_samples[stop-1],
                                 n_steps=n_steps, transform=True)
        scipy.misc.imshow(merge(series, (int(math.ceil((n_steps + 2) / 8.0)), 8)))
        c = raw_input('Continue? [y/n]')
        if c != 'y':
            break
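The helper gauss_kernel_fixed used above is defined elsewhere in the repository. For readers of these snippets in isolation, here is a minimal sketch of what it plausibly computes, assuming it returns a normalized 2-D Gaussian kernel with half-width trunc (the (kernel_size - 1) // 2 passed in Example #2) and that sigma == 0 degenerates to an identity kernel so the blur-fade in Example #3 can reach zero blur:

import numpy as np

def gauss_kernel_fixed(sigma, trunc):
    """Assumed helper: (2*trunc+1) x (2*trunc+1) Gaussian kernel, normalized to sum 1.
    sigma == 0 yields a delta (identity) kernel, i.e. no blur."""
    size = 2 * trunc + 1
    if sigma == 0:
        kernel = np.zeros((size, size))
        kernel[trunc, trunc] = 1.0
        return kernel
    x = np.arange(size, dtype=np.float64) - trunc
    g = np.exp(-x ** 2 / (2.0 * sigma ** 2))
    kernel = np.outer(g, g)          # separable 2-D Gaussian
    return kernel / kernel.sum()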
Example #2
def interpolate(sess, dcgan, z_start, z_stop, n_steps=62, y_start=None, y_stop=None, transform=True):
    """Interpolates between two samples in z-space
    
    Input parameters:
    sess: TF session
    dcgan: DCGAN object for sampling
    z_start: z-vector of the first sample
    z_stop: z-vector of the second sample
    n_steps: number of intermediate samples to produce
    transform: if True, the pixel values will be transformed to their normal image range [True]
    y_start: label for first sample (numerical)
    y_stop: label for second sample (numerical)
    
    Returns an array of n_steps + 2 samples."""
    y_dim = 0
    if y_start is not None:
        y_dim = dcgan.y_dim
        if y_stop is None:
            y_stop = y_start
        if y_start != y_stop:
            z_start = np.concatenate((z_start, np.eye(y_dim)[y_start]))
            z_stop = np.concatenate((z_stop, np.eye(y_dim)[y_stop]))
    # limit to batch size for simplicity
    if n_steps > (dcgan.batch_size - 2):
        n_steps = dcgan.batch_size - 2
    # sample along big circle for all distributions
    steps = np.linspace(0, 1, n_steps + 2)
    z_samples = [slerp(step, z_start, z_stop) for step in steps]
    gauss_filter = gauss_kernel_fixed(dcgan.gauss_sigma, (dcgan.kernel_size - 1) // 2)
    if n_steps != (dcgan.batch_size - 2):
        z_samples += [np.zeros(dcgan.z_dim + y_dim) for i in range(dcgan.batch_size - n_steps - 2)]
    if y_dim > 0:
        if y_start != y_stop:
            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: np.array(z_samples)[:, :dcgan.z_dim],
                                                         dcgan.y: np.array(z_samples)[:, dcgan.z_dim:],
                                                         dcgan.gauss_kernel: gauss_filter},)[:n_steps + 2]
        else:
            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: np.array(z_samples),
                                                         dcgan.y: np.eye(y_dim)
                                                         [np.full(dcgan.batch_size, y_start)],
                                                         dcgan.gauss_kernel: gauss_filter})[:n_steps + 2]
    else:
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: np.array(z_samples),
                                                     dcgan.gauss_kernel: gauss_filter})[:n_steps + 2]
    if transform:
        samples = np.array([((sample + 1) / 2 * 255).astype(np.uint8) for sample in samples])
    return samples
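interpolate relies on a slerp helper that is also not shown in these snippets. A sketch under the assumption that it performs spherical linear interpolation (as proposed in White, "Sampling Generative Networks"), which is what keeps the interpolants on the "big circle" mentioned in the comment above:

import numpy as np

def slerp(val, low, high):
    """Assumed helper: spherical interpolation between low and high, val in [0, 1]."""
    omega = np.arccos(np.clip(
        np.dot(low / np.linalg.norm(low), high / np.linalg.norm(high)), -1.0, 1.0))
    so = np.sin(omega)
    if so == 0:
        # degenerate case: vectors are (anti)parallel, fall back to linear interpolation
        return (1.0 - val) * low + val * high
    return np.sin((1.0 - val) * omega) / so * low + np.sin(val * omega) / so * high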
Example #3
    def train(self, config):
        """Train DCGAN"""

        train_gen, data_len = hdf5_images.load(batch_size=self.batch_size,
                                               data_file=self.data_dir,
                                               resolution=self.output_height,
                                               label_name=self.label_path)

        def inf_train_gen():
            while True:
                for _images, _labels in train_gen():
                    yield _images, _labels

        gen = inf_train_gen()

        d_optim = tf.train.AdamOptimizer(self.lr, beta1=config.beta1) \
            .minimize(self.d_loss, var_list=self.d_vars)
        g_optim = tf.train.AdamOptimizer(self.lr, beta1=config.beta1) \
            .minimize(self.g_loss, var_list=self.g_vars)

        try:
            tf.global_variables_initializer().run()
        except AttributeError:
            # fallback for TF versions predating global_variables_initializer
            tf.initialize_all_variables().run()
        self.g_sum = merge_summary(
            [self.G_sum, self.d_loss_fake_sum,
             self.g_loss_sum])  #self.z_sum, self.d__sum (hist)
        self.d_sum = merge_summary(
            [self.lr_sum, self.z_sum, self.d_loss_real_sum,
             self.d_loss_sum])  #self.d_sum (hist)
        # label weight histogram
        # if self.has_labels:
        #     self.g_sum = merge_summary([self.g_sum, self.w_y_sum])

        # initialize summary writer: for each run create (enumerated) sub-directory
        if os.path.exists('./logs/' + self.dataset_name):
            # number of existing immediate child directories of log folder
            run_var = len(next(os.walk('./logs/' + self.dataset_name))[1]) + 1
        else:
            run_var = 1
        self.writer = SummaryWriter(
            '%s/%s' % ('./logs/' + self.dataset_name, run_var),
            self.sess.graph)

        sample_z = self.z_sampler(config)
        sample_inputs = next(gen)
        sample_images = sample_inputs[0]
        sample_images = sample_images.transpose((0, 2, 3, 1))
        if self.has_labels:
            # use one of the following two lines to get either random samples or only samples from the first 8 classes
            # sample_y = np.eye(self.y_dim)[sample_inputs[1]]
            sample_y = np.eye(self.y_dim)[[1, 2, 3, 4, 5, 6, 7, 8] * 8]

        counter = 1
        start_time = time.time()
        could_load, checkpoint_counter = self.load(self.checkpoint_dir)
        if could_load:
            counter = checkpoint_counter
            print(" [*] Load SUCCESS")
        else:
            print(" [!] Load failed...")

        batch_idxs = data_len // self.batch_size
        for epoch in xrange(config.epoch):
            if self.blur_fade:
                # sigma_used = self.gauss_sigma * (1 - (epoch / config.epoch))
                # tapered version
                full_blur_ep = 0  # int(config.epoch * 0.2)
                no_blur_ep = 6  # int(config.epoch * 0.2)
                fade_ep = config.epoch - full_blur_ep - no_blur_ep
                if epoch < full_blur_ep:
                    sigma_used = self.gauss_sigma
                elif epoch < full_blur_ep + fade_ep:
                    sigma_used = self.gauss_sigma * (1 - (
                        (epoch + 1 - full_blur_ep) / float(fade_ep)))
                else:
                    sigma_used = 0
            else:
                sigma_used = self.gauss_sigma
            kernel_used = gauss_kernel_fixed(sigma_used,
                                             (self.kernel_size - 1) // 2)
            for idx in xrange(0, batch_idxs):
                batch_z = self.z_sampler(config)
                # when using random labels (commented-out line below), the fake label distribution
                # might not match the data and training can fail!
                # batch_y = tf.one_hot(np.random.random_integers(0, self.y_dim - 1, config.batch_size), self.y_dim)
                # better use the real images and labels:
                batch_images, batch_labels_num = next(gen)
                # if images are stored in BHWC format, the following line should be commented out
                batch_images = batch_images.transpose((0, 2, 3, 1))  # BCHW -> BHWC
                if self.has_labels:
                    batch_labels = np.eye(self.y_dim)[batch_labels_num]
                if config.blur_input is not None:
                    batch_images = gaussian_filter(
                        batch_images,
                        [0, config.blur_input, config.blur_input, 0])

                # Update D network
                _, summary_str = self.sess.run(
                    [d_optim, self.d_sum],
                    feed_dict={
                        self.inputs: batch_images,
                        self.z: batch_z,
                        self.lr: config.learning_rate,
                        self.gauss_kernel: kernel_used
                    } if not self.has_labels else {
                        self.inputs: batch_images,
                        self.y: batch_labels,
                        self.lr: config.learning_rate,
                        self.z: batch_z,
                        self.gauss_kernel: kernel_used
                    })
                self.writer.add_summary(summary_str, counter)

                # Update G network the specified number of times
                for i in range(self.num_g_updates):
                    _, summary_str = self.sess.run(
                        [g_optim, self.g_sum],
                        feed_dict={
                            self.inputs: batch_images,
                            self.lr: config.learning_rate,
                            self.z: batch_z,
                            self.gauss_kernel: kernel_used
                        } if not self.has_labels else {
                            self.inputs: batch_images,
                            self.y: batch_labels,
                            self.lr: config.learning_rate,
                            self.z: batch_z,
                            self.gauss_kernel: kernel_used
                        })
                    self.writer.add_summary(summary_str, counter)

                # Get losses for current batch
                errD_fake = self.d_loss_fake.eval(
                    {
                        self.z: batch_z,
                        self.gauss_kernel: kernel_used
                    } if not self.has_labels else {
                        self.z: batch_z,
                        self.y: batch_labels,
                        self.gauss_kernel: kernel_used
                    })
                errD_real = self.d_loss_real.eval(
                    {
                        self.inputs: batch_images,
                        self.gauss_kernel: kernel_used
                    } if not self.has_labels else {
                        self.inputs: batch_images,
                        self.y: batch_labels,
                        self.gauss_kernel: kernel_used
                    })
                errG = self.g_loss.eval({
                    self.z: batch_z,
                    self.gauss_kernel: kernel_used
                } if not self.has_labels else {
                    self.z: batch_z,
                    self.y: batch_labels,
                    self.gauss_kernel: kernel_used
                })

                # add sigma to summary
                sigma_sum = tf.Summary(value=[
                    tf.Summary.Value(tag='gauss_sigma',
                                     simple_value=sigma_used)
                ])
                self.writer.add_summary(sigma_sum, counter)

                counter += 1
                print(
                    "Epoch: [%2d/%2d] [%4d/%4d] [%4d] time: %4.4f, d_loss: %.8f, g_loss: %.8f, sigma: %4f"
                    % (epoch, config.epoch, idx, batch_idxs, counter,
                       time.time() - start_time, errD_fake + errD_real, errG,
                       sigma_used))
                # print losses for constant sample and save image with corresponding output
                if (counter % 500) == 1:
                    try:
                        samples, d_loss, g_loss = self.sess.run(
                            [self.sampler, self.d_loss, self.g_loss],
                            feed_dict={
                                self.z: sample_z,
                                self.inputs: sample_images,
                                self.y: sample_y,
                                self.gauss_kernel: kernel_used
                            } if self.has_labels else {
                                self.z: sample_z,
                                self.inputs: sample_images,
                                self.gauss_kernel: kernel_used
                            })
                        manifold_h = int(np.ceil(np.sqrt(samples.shape[0])))
                        manifold_w = int(np.floor(np.sqrt(samples.shape[0])))
                        save_images(
                            samples, [manifold_h, manifold_w],
                            './{}/train_{:02d}_{:04d}.png'.format(
                                config.sample_dir, epoch, idx))
                        print("[Sample] d_loss: %.8f, g_loss: %.8f" %
                              (d_loss, g_loss))
                    except Exception as e:
                        print("Error saving sample image:")
                        print(e)
                # save checkpoint
                if np.mod(counter, 3000) == 2:
                    self.save(config.checkpoint_dir, counter)
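The tapered blur-fade at the top of the epoch loop is easier to read as a pure function. The following sketch mirrors the logic above (the constants 0 and 6 are the hard-coded full_blur_ep / no_blur_ep values from the snippet):

def blur_sigma_schedule(epoch, total_epochs, gauss_sigma, full_blur_ep=0, no_blur_ep=6):
    """Full blur for the first full_blur_ep epochs, a linear fade over the
    middle epochs, then no blur for the last no_blur_ep epochs."""
    fade_ep = total_epochs - full_blur_ep - no_blur_ep
    if epoch < full_blur_ep:
        return gauss_sigma
    elif epoch < full_blur_ep + fade_ep:
        return gauss_sigma * (1 - (epoch + 1 - full_blur_ep) / float(fade_ep))
    return 0.0

For example, with total_epochs=26 and gauss_sigma=3.0, sigma fades linearly from 2.85 at epoch 0 to 0 at epoch 19, and the last six epochs train unblurred.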
Example #4
def visualize(sess, dcgan, config, option):
    image_frame_dim = int(math.ceil(config.batch_size**.5))
    # produce sample uniformly with nearest neighbour
    # option 0: additionally sort according to distance
    if (option == 1) or (option == 0):
        n_images = 20
        has_labels = False
        try:
            if dcgan.has_labels:
                # generate one image for each cluster / category
                has_labels = True
                if option == 0:
                    n_images = dcgan.y_dim
        except Exception:
            pass
        # sample DCGAN from uniform distribution in z
        print('sampling...')
        z_samples = dcgan.z_sampler(config)
        if has_labels:
            y_samples = np.eye(dcgan.y_dim)[np.random.choice(
                dcgan.y_dim, [n_images, config.batch_size])]
            samples = np.array([
                sess.run(dcgan.sampler, {
                    dcgan.z: batch,
                    dcgan.y: batch_y
                }) for batch, batch_y in zip(z_samples, y_samples)
            ])
        else:
            samples = np.array([
                sess.run(dcgan.sampler, feed_dict={dcgan.z: batch})
                for batch in z_samples
            ])
        # transform back to normal image value range and reshape to one array instead of batches
        print('transforming...')
        samples = np.array([((sample + 1) / 2 * 255).astype(np.uint8) for sample in samples]) \
            .reshape((samples.shape[0] * samples.shape[1],) + samples.shape[2:])
        # load and rescale training data to same size as samples
        print('loading and transforming orig data...')
        orig_data, _ = fh.load_icon_data(config.data_dir)
        orig_data = np.array([
            scipy.misc.imresize(icon,
                                (config.output_height, config.output_height))
            for icon in orig_data
        ])
        # get nearest neighbour indices from training set
        if option == 1:
            print('getting nearest neighbours...')
            nearest_idxs = metrics.nearest_icons(samples, orig_data)
        else:
            print('getting nearest neighbours...')
            nearest_idxs, distances = metrics.nearest_icons(samples,
                                                            orig_data,
                                                            get_dist=True)
            print('sorting...')
            # normalize distance over whole image content to prevent predominantly white images having low distance
            norms = np.sqrt(np.sum(np.square(samples.astype(np.float64)), axis=(1, 2, 3)))
            distances = np.array(
                [distance / n for distance, n in zip(distances, norms)])
            sorting = np.argsort(distances)
            samples = samples[sorting]
            nearest_idxs = np.array(nearest_idxs)[sorting]
        bs = config.batch_size
        for idx in xrange(n_images):
            print(" [*] %d" % idx)
            combined = []
            # combine samples and nearest neighbours for each batch and save as png
            for sample, orig in zip(
                    samples[idx * bs:(idx + 1) * bs],
                    orig_data[nearest_idxs[idx * bs:(idx + 1) * bs]]):
                combined += [sample, orig]
            scipy.misc.imsave(
                os.path.join(config.sample_dir,
                             'test_uniform_nearest_%s.png' % (idx)),
                merge(np.array(combined),
                      [image_frame_dim, image_frame_dim * 2]))
    # sample with uniform distribution
    if option == 2:
        n_images = 20
        has_labels = False
        try:
            if dcgan.has_labels:
                # generate one image for each cluster / category
                n_images = dcgan.y_dim
                has_labels = True
        except Exception:
            pass
        for idx in xrange(n_images):
            print(" [*] %d" % idx)
            z_sample = dcgan.z_sampler(config)
            # create gaussian convolution kernel as defined in run parameters
            kernel = gauss_kernel_fixed(config.gauss_sigma, config.gauss_trunc)
            if has_labels:
                # y = np.random.choice(dcgan.y_dim, config.batch_size)
                # y_one_hot = np.zeros((config.batch_size, dcgan.y_dim))
                # y_one_hot[np.arange(config.batch_size), y] = 1
                y_one_hot = np.eye(dcgan.y_dim)[np.full(
                    config.batch_size, idx)]
                samples = sess.run(dcgan.sampler,
                                   feed_dict={
                                       dcgan.z: z_sample,
                                       dcgan.y: y_one_hot,
                                       dcgan.gauss_kernel: kernel
                                   })
            else:
                samples = sess.run(dcgan.sampler,
                                   feed_dict={
                                       dcgan.z: z_sample,
                                       dcgan.gauss_kernel: kernel
                                   })

            save_images(
                samples, [image_frame_dim, image_frame_dim],
                os.path.join(config.sample_dir, 'test_uniform_%s.png' % (idx)))
    # sample with normal distribution
    if option == 3:
        for idx in xrange(100):
            print(" [*] %d" % idx)
            z_sample = np.random.normal(size=(config.batch_size, dcgan.z_dim))
            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            save_images(samples, [image_frame_dim, image_frame_dim],
                        './samples/test_normal_%s.png' % (idx))
    # single sample with uniform distribution
    if option == 4:
        z_sample = np.random.uniform(-0.5,
                                     0.5,
                                     size=(config.batch_size, dcgan.z_dim))
        samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
        save_images(
            samples, [image_frame_dim, image_frame_dim],
            os.path.join(config.sample_dir,
                         'test_%s.png' % strftime("%Y%m%d%H%M%S", gmtime())))
    # vary single z-component only
    if option == 5:
        values = np.arange(0, 1, 1. / config.batch_size)
        for idx in xrange(100):
            print(" [*] %d" % idx)
            z_sample = np.zeros([config.batch_size, dcgan.z_dim])
            for kdx, z in enumerate(z_sample):
                z[idx] = values[kdx]
            samples = sess.run(dcgan.sampler, feed_dict={dcgan.z: z_sample})
            save_images(
                samples, [image_frame_dim, image_frame_dim],
                os.path.join(config.sample_dir, 'test_arange_%s.png' % (idx)))
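merge, used throughout these examples to tile a batch into a single preview image, is likewise defined elsewhere. A plausible minimal implementation, assuming batches in HWC layout and a (rows, cols) grid size as passed above:

import numpy as np

def merge(images, size):
    """Assumed helper: tile a batch of HWC images into one (rows, cols) grid image."""
    rows, cols = size
    h, w = images.shape[1], images.shape[2]
    grid = np.zeros((rows * h, cols * w, images.shape[3]), dtype=images.dtype)
    for idx, image in enumerate(images):
        i = idx % cols   # column index
        j = idx // cols  # row index
        grid[j * h:(j + 1) * h, i * w:(i + 1) * w, :] = image
    return grid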