Example #1
    def _load_mnist3(self, opts):
        """Load data from MNIST files.

        """
        logging.debug('Loading 3-digit MNIST')
        data_dir = _data_dir(opts)
        # pylint: disable=invalid-name
        # Let us use all the bad variable names!
        tr_X = None
        tr_Y = None
        te_X = None
        te_Y = None

        with utils.o_gfile((data_dir, 'train-images-idx3-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            # Skip the 16-byte IDX image header (magic number + dims)
            tr_X = loaded[16:].reshape((60000, 28, 28, 1)).astype(np.float64)

        with utils.o_gfile((data_dir, 'train-labels-idx1-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            # Skip the 8-byte IDX label header
            tr_Y = loaded[8:].reshape(60000).astype(np.int64)

        with utils.o_gfile((data_dir, 't10k-images-idx3-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            te_X = loaded[16:].reshape((10000, 28, 28, 1)).astype(np.float64)

        with utils.o_gfile((data_dir, 't10k-labels-idx1-ubyte'), 'rb') as fd:
            loaded = np.frombuffer(fd.read(), dtype=np.uint8)
            te_Y = loaded[8:].reshape(10000).astype(np.int64)

        X = np.concatenate((tr_X, te_X), axis=0)
        y = np.concatenate((tr_Y, te_Y), axis=0)

        num = opts['mnist3_dataset_size']
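        # Each row of ids indexes 3 random source digits merged into one sample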
        ids = np.random.choice(len(X), (num, 3), replace=True)
        if opts['mnist3_to_channels']:
            # Stack the 3 digits into 3 channels; the label is the
            # 3-digit number they spell (000-999)
            X3 = np.zeros((num, 28, 28, 3))
            y3 = np.zeros(num)
            for idx, _id in enumerate(ids):
                X3[idx, :, :, 0] = np.squeeze(X[_id[0]], axis=2)
                X3[idx, :, :, 1] = np.squeeze(X[_id[1]], axis=2)
                X3[idx, :, :, 2] = np.squeeze(X[_id[2]], axis=2)
                y3[idx] = y[_id[0]] * 100 + y[_id[1]] * 10 + y[_id[2]]
            self.data_shape = (28, 28, 3)
        else:
            # Concatenate the 3 digits side by side along the width
            X3 = np.zeros((num, 28, 3 * 28, 1))
            y3 = np.zeros(num)
            for idx, _id in enumerate(ids):
                X3[idx, :, 0:28, 0] = np.squeeze(X[_id[0]], axis=2)
                X3[idx, :, 28:56, 0] = np.squeeze(X[_id[1]], axis=2)
                X3[idx, :, 56:84, 0] = np.squeeze(X[_id[2]], axis=2)
                y3[idx] = y[_id[0]] * 100 + y[_id[1]] * 10 + y[_id[2]]
            self.data_shape = (28, 28 * 3, 1)

        self.data = Data(opts, X3 / 255.)
        y3 = y3.astype(int)
        self.labels = y3
        self.num_points = num

        logging.debug('Training set JS=%.4f' % utils.js_div_uniform(y3))
        logging.debug('Loading Done.')
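
Both examples lean on utils.js_div_uniform, which is not shown here. A minimal sketch of what it plausibly computes, assuming it is the Jensen-Shannon divergence between the empirical histogram of the 3-digit labels and the uniform distribution over the 1000 possible modes (the function body below is an assumption for illustration, not the project's actual implementation):

    import numpy as np

    def js_div_uniform(labels, num_modes=1000):
        # Empirical distribution of the observed mode labels (assumed).
        phat = np.bincount(labels, minlength=num_modes).astype(np.float64)
        phat /= phat.sum()
        punif = np.ones(num_modes) / num_modes
        mixture = (phat + punif) / 2.

        def kl(p, q):
            # KL(p || q), skipping zero-probability bins of p.
            mask = p > 0
            return np.sum(p[mask] * np.log(p[mask] / q[mask]))

        # JS(P, U) = (KL(P || M) + KL(U || M)) / 2 with M the mixture.
        return 0.5 * kl(phat, mixture) + 0.5 * kl(punif, mixture)

A uniform label distribution gives a value near 0; the more the mass collapses onto a few modes, the larger the divergence.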
Example #2
    def _evaluate_mnist3(self, opts, step, real_points,
                         fake_points, validation_fake_points, prefix=''):
        """ The model is covering as many modes and as uniformly as possible.
        Classify every picture in fake_points with a pre-trained MNIST
        classifier and compute the resulting distribution over the modes. It
        should be as close as possible to the uniform. Measure this distance
        with KL divergence. Here modes refer to labels.
        """

        assert len(fake_points) > 0, 'No fake digits to evaluate'
        num_fake = len(fake_points)

        # Classifying points with pre-trained model.
        # Pre-trained classifier assumes inputs are in [0, 1.]
        # There may be many points, so we will sess.run
        # in small chunks.

        if opts['input_normalize_sym']:
            # Rescaling data back to [0, 1.]
            if real_points is not None:
                real_points = real_points / 2. + 0.5
            fake_points = fake_points / 2. + 0.5
            if validation_fake_points is not None:
                validation_fake_points = validation_fake_points / 2. + 0.5

        with tf.Graph().as_default() as g:
            model_file = os.path.join(opts['trained_model_path'],
                                      opts['mnist_trained_model_file'])
            saver = tf.train.import_meta_graph(model_file + '.meta')
            with tf.Session() as sess:
                saver.restore(sess, model_file)
                input_ph = tf.get_collection('X_')
                assert len(input_ph) > 0, 'Failed to load pre-trained model'
                # Input placeholder
                input_ph = input_ph[0]
                dropout_keep_prob_ph = tf.get_collection('keep_prob')
                assert len(dropout_keep_prob_ph) > 0, 'Failed to load pre-trained model'
                dropout_keep_prob_ph = dropout_keep_prob_ph[0]
                trained_net = tf.get_collection('prediction')
                assert len(trained_net) > 0, 'Failed to load pre-trained model'
                # Predicted digit
                trained_net = trained_net[0]
                logits = tf.get_collection('y_hat')
                assert len(logits) > 0, 'Failed to load pre-trained model'
                # Resulting 10 logits
                logits = logits[0]
                # Probability of the most likely digit for each input
                prob_max = tf.reduce_max(tf.nn.softmax(logits), axis=1)

                batch_size = opts['tf_run_batch_size']
                batches_num = int(np.ceil((num_fake + 0.) / batch_size))
                result = []
                result_probs = []
                result_is_confident = []
                thresh = opts['digit_classification_threshold']
                for idx in range(batches_num):
                    end_idx = min(num_fake, (idx + 1) * batch_size)
                    batch_fake = fake_points[idx * batch_size:end_idx]
                    # Split each composite image back into its 3 digits
                    if opts['mnist3_to_channels']:
                        input1, input2, input3 = np.split(batch_fake, 3, axis=3)
                    else:
                        input1, input2, input3 = np.split(batch_fake, 3, axis=2)
                    _res1, prob1 = sess.run(
                        [trained_net, prob_max],
                        feed_dict={input_ph: input1,
                                   dropout_keep_prob_ph: 1.})
                    _res2, prob2 = sess.run(
                        [trained_net, prob_max],
                        feed_dict={input_ph: input2,
                                   dropout_keep_prob_ph: 1.})
                    _res3, prob3 = sess.run(
                        [trained_net, prob_max],
                        feed_dict={input_ph: input3,
                                   dropout_keep_prob_ph: 1.})
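                    # Recompose the 3 per-digit predictions into one mode id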
                    result.append(100 * _res1 + 10 * _res2 + _res3)
                    result_probs.append(
                        np.column_stack((prob1, prob2, prob3)))
                    result_is_confident.append(
                        (prob1 > thresh) * (prob2 > thresh) * (prob3 > thresh))
                result = np.hstack(result)
                result_probs = np.vstack(result_probs)
                result_is_confident = np.hstack(result_is_confident)
                assert len(result) == num_fake
                assert len(result_probs) == num_fake

        # Normalizing back
        if opts['input_normalize_sym']:
            # Rescaling data back to [0, 1.]
            if real_points is not None:
                real_points = 2. * (real_points - 0.5)
            fake_points = 2. * (fake_points - 0.5)
            if validation_fake_points is not None:
                validation_fake_points = 2. * (validation_fake_points - 0.5)

        digits = result.astype(int)
        logging.debug(
            'Ratio of confident predictions: %.4f' %\
            np.mean(result_is_confident))
        # Plot one fake image per detected mode
        gathered = []
        points_to_plot = []
        for (idx, dig) in enumerate(digits):
            if dig not in gathered and result_is_confident[idx]:
                gathered.append(dig)
                p = result_probs[idx]
                points_to_plot.append(fake_points[idx])
                logging.debug('Mode %03d covered with prob %.3f, %.3f, %.3f' %\
                              (dig, p[0], p[1], p[2]))
        # Confidence of made predictions
        conf = np.mean(result_probs)
        if len(points_to_plot) > 0:
            self._make_plots_pics(
                opts, step, None, np.array(points_to_plot), None, 'modes_')
        if np.sum(result_is_confident) == 0:
            C_actual = 0.
            C = 0.
            JS = 2.
        else:
            # Compute the actual coverage
            C_actual = len(np.unique(digits[result_is_confident])) / 1000.
            # Compute the JS with uniform
            JS = utils.js_div_uniform(digits)
            # Compute Pdata(Pmodel > t) where Pmodel(Pmodel > t) = 0.95
            # np.percentile(a, q) returns t s.t. np.mean(a <= t) ~= q / 100
            phat = np.bincount(digits[result_is_confident], minlength=1000)
            phat = (phat + 0.) / np.sum(phat)
            threshold = np.percentile(phat, 5)
            ratio_not_covered = np.mean(phat <= threshold)
            C = 1. - ratio_not_covered

        logging.info(
            'Evaluating: JS=%.3f, C=%.3f, C_actual=%.3f, Confidence=%.4f' %\
            (JS, C, C_actual, conf))
        return (JS, C, C_actual, conf)
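
To see what the coverage number C measures, the same percentile computation can be run on toy data outside the class (the digit values below are invented purely for illustration):

    import numpy as np

    # Pretend the classifier confidently decoded these mode ids.
    digits = np.array([7, 7, 7, 123, 123, 555, 901, 901, 901, 901])
    phat = np.bincount(digits, minlength=1000).astype(np.float64)
    phat /= phat.sum()
    # With 996 of the 1000 bins empty, the 5th percentile is 0, so only
    # the empty modes fall at or below the threshold.
    threshold = np.percentile(phat, 5)
    C = 1. - np.mean(phat <= threshold)        # 0.004
    C_actual = len(np.unique(digits)) / 1000.  # 4 modes hit -> 0.004
    print(C, C_actual)

When the samples cover all modes roughly evenly, the 5th-percentile threshold cuts off only the least-populated 5% of the bins and C lands near 0.95; here both C and C_actual collapse to 0.004 because only 4 of the 1000 modes are covered.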