Пример #1
0
    def visualize_weights(self, layer, imsize, layout):
        """
        Show the weights of one layer as a set of images.
        :param layer: selects the entry of self.Ws whose weights are shown
        :param imsize: the image size, forwarded to util.disp_imdata
        :param layout: number of rows and columns for each page
        :return: none
        """

        # the transpose of the stored weight values is what gets displayed
        weights = self.Ws[layer].get_value()
        util.disp_imdata(weights.T, imsize, layout)

        # non-blocking show
        plt.show(block=False)
Пример #2
0
    def show_images(self, split):
        """
        Show the images of one data split, using a [6, 10] layout.
        :param split: string, name of the attribute of self holding the split
        :raises ValueError: if self has no such attribute, or it is None
        """

        subset = getattr(self, split, None)

        # a missing attribute and an attribute set to None are treated alike
        if subset is None:
            raise ValueError('Invalid data split')

        layout = [6, 10]
        util.disp_imdata(subset.x, self.image_size, layout)
        plt.show()
Пример #3
0
    def show_images(self, split):
        """
        Show the colour images of one data split, using a [6, 10] layout.
        :param split: string, name of the attribute of self holding the split
        :raises ValueError: if self has no such attribute, or it is None
        """

        subset = getattr(self, split, None)

        # a missing attribute and an attribute set to None are treated alike
        if subset is None:
            raise ValueError('Invalid data split')

        # combine the separately stored r, g and b arrays along a new axis 2
        channels = [subset.r, subset.g, subset.b]
        images = np.stack(channels, axis=2)

        layout = [6, 10]
        util.disp_imdata(images, self.image_size, layout)
        plt.show()
Пример #4
0
    def show_images(self, split):
        """
        Show the images of one data split, using a [6, 10] layout. Each image
        gets one extra pixel appended before it is displayed.
        :param split: string, name of the attribute of self holding the split
        :raises ValueError: if self has no such attribute, or it is None
        """

        subset = getattr(self, split, None)

        # a missing attribute and an attribute set to None are treated alike
        if subset is None:
            raise ValueError('Invalid data split')

        # the appended pixel is minus the sum of the stored ones, so every
        # completed row sums to zero
        pixels = subset.x
        extra_column = -np.sum(pixels, axis=1)
        completed = np.hstack([pixels, extra_column[:, np.newaxis]])

        layout = [6, 10]
        util.disp_imdata(completed, self.image_size, layout)
        plt.show()
Пример #5
0
def _sort_samples_by_logprob(samples, lp_samples):
    """
    Order samples from highest to lowest log probability, discarding NaNs.
    :param samples: array of generated samples, indexed along the first axis
    :param lp_samples: array with the log probability of each sample
    :return: the samples whose log probability is not NaN, most probable first
    """

    # The same mask has to be applied to both arrays: sorting only the
    # filtered log probabilities and then indexing the unfiltered samples
    # would pick the wrong rows as soon as a single NaN is removed.
    valid = np.logical_not(np.isnan(lp_samples))
    order = np.argsort(lp_samples[valid])

    return samples[valid][order][::-1]


def _samples_to_images(samples):
    """
    Convert generated samples to the form passed to util.disp_imdata.
    Reads the module-level data_name and data.
    :param samples: array of generated samples, indexed along the first axis
    :return: array of images
    :raises ValueError: if data_name is not 'mnist', 'bsds300' or 'cifar10'
    """

    if data_name == 'mnist':
        # NOTE(review): presumably undoes a logit preprocessing that uses
        # data.alpha -- confirm against the dataset loader
        return (util.logistic(samples) - data.alpha) / (1 - 2 * data.alpha)

    if data_name == 'bsds300':
        # append one more pixel, equal to minus the sum of the others
        last_pixel = -np.sum(samples, axis=1)
        return np.hstack([samples, last_pixel[:, np.newaxis]])

    if data_name == 'cifar10':
        samples = (util.logistic(samples) - data.alpha) / (1 - 2 * data.alpha)

        # split the columns into three equal parts and stack them as
        # r, g, b along a new axis 2
        D = int(data.n_dims / 3)
        r = samples[:, :D]
        g = samples[:, D:2 * D]
        b = samples[:, 2 * D:]
        return np.stack([r, g, b], axis=2)

    raise ValueError('non-image dataset')


def evaluate(model, split, n_samples=None):
    """
    Evaluate a trained model. Prints the average log probability on the given
    split; for conditional models also prints the marginal log probability
    and the classification accuracy. Every reported interval is twice the
    standard error of the mean. Optionally displays generated samples.
    :param model: the model to evaluate. Can be any made, maf, or real nvp
    :param split: string, the data split to evaluate on. Must be 'trn', 'val' or 'tst'
    :param n_samples: number of samples to generate from the model, or None for no samples
    :raises ValueError: if the split doesn't exist, or if samples are requested
        for a dataset that isn't mnist, bsds300 or cifar10
    """

    assert is_data_loaded(), 'Dataset hasn\'t been loaded'

    # choose which data split to evaluate on
    data_split = getattr(data, split, None)
    if data_split is None:
        raise ValueError('Invalid data split')

    if is_conditional(model):

        # scipy.misc.logsumexp no longer exists in current scipy releases;
        # fall back to it only where scipy.special doesn't provide one
        try:
            from scipy.special import logsumexp
        except ImportError:
            from scipy.misc import logsumexp

        # calculate log probability
        logprobs = model.eval([data_split.y, data_split.x])
        print('logprob(x|y) = {0:.2f} +/- {1:.2f}'.format(
            logprobs.mean(), 2 * logprobs.std() / np.sqrt(data_split.N)))

        # classify the split: evaluate every datapoint under each one-hot
        # label and predict the label with the highest log probability
        logprobs = np.empty([data_split.N, data.n_labels])
        for i in range(data.n_labels):
            y = np.zeros([data_split.N, data.n_labels])
            y[:, i] = 1
            logprobs[:, i] = model.eval([y, data_split.x])
        predict_label = np.argmax(logprobs, axis=1)
        accuracy = (predict_label == data_split.labels).astype(float)

        # marginalize over labels, weighting every label equally
        logprobs = logsumexp(logprobs, axis=1) - np.log(logprobs.shape[1])
        print('logprob(x) = {0:.2f} +/- {1:.2f}'.format(
            logprobs.mean(), 2 * logprobs.std() / np.sqrt(data_split.N)))
        print('classification accuracy = {0:.2%} +/- {1:.2%}'.format(
            accuracy.mean(), 2 * accuracy.std() / np.sqrt(data_split.N)))

        # generate data conditioned on label
        if n_samples is not None:
            for i in range(data.n_labels):

                # generate samples and sort according to log prob
                y = np.zeros(data.n_labels)
                y[i] = 1
                samples = model.gen(y, n_samples)
                lp_samples = model.eval([np.tile(y, [n_samples, 1]), samples])
                samples = _sort_samples_by_logprob(samples, lp_samples)

                util.disp_imdata(
                    _samples_to_images(samples), data.image_size, [5, 8])

    else:

        # calculate average log probability
        logprobs = model.eval(data_split.x)
        print('logprob(x) = {0:.2f} +/- {1:.2f}'.format(
            logprobs.mean(), 2 * logprobs.std() / np.sqrt(data_split.N)))

        # generate data
        if n_samples is not None:

            # generate samples and sort according to log prob
            samples = model.gen(n_samples)
            lp_samples = model.eval(samples)
            samples = _sort_samples_by_logprob(samples, lp_samples)

            util.disp_imdata(
                _samples_to_images(samples), data.image_size, [5, 8])

    plt.show()