# Code example #1
    def _measure_disentanglement(self, autoencoder):
        """Train a linear classifier on the autoencoder's codes and record its accuracy.

        Encodes the classifier train/test/validation JSON datasets through
        ``autoencoder`` via ``Evaluator.inputs_and_labels``, closes the
        autoencoder's session, then fits a ``LinearClassifier`` with early
        stopping on validation cost.  The final test accuracy is logged,
        written to ``self.accuracy_output_file`` and appended to
        ``self.classifier_accuracy_all`` as ``(beta, accuracy)``.

        :param autoencoder: trained autoencoder exposing ``get_beta()``,
            ``get_code_dimension()`` and ``close_session()``; also consumed
            by ``Evaluator.inputs_and_labels`` to produce classifier inputs.
        """
        beta = autoencoder.get_beta()
        train_classifier_inputs, train_classifier_labels = \
            Evaluator.inputs_and_labels(autoencoder, self.classifier_train_json)
        logger.info('Beta = {0} | Classifier training data processed.'.format(beta))
        test_classifier_inputs, test_classifier_labels = \
            Evaluator.inputs_and_labels(autoencoder, self.classifier_test_json)
        logger.info('Beta = {0} | Classifier test data processed.'.format(beta))
        valid_classifier_inputs, valid_classifier_labels = \
            Evaluator.inputs_and_labels(autoencoder, self.classifier_validation_json)
        logger.info('Beta = {0} | Classifier validation data processed.'.format(beta))

        # The autoencoder's codes are all extracted; free its session before
        # training the classifier.
        autoencoder.close_session()

        # NOTE(review): called after close_session() — assumes the code
        # dimension is a stored attribute, not a session query.  Confirm.
        CLASSIFIER_INPUT_DIMENSION = autoencoder.get_code_dimension()

        logger.info('Beta = {0} | Create classifier.'.format(beta))
        classifier = LinearClassifier(CLASSIFIER_INPUT_DIMENSION, CLASSIFIER_POSSIBLE_CLASSES_COUNT)

        epoch = 0
        best_validation_score = np.inf
        consecutive_decreases = 0  # epochs since validation cost last improved
        LEN_TRAIN = len(train_classifier_inputs)
        while True:
            # One full-batch training step per epoch on a freshly shuffled
            # copy of the training set.
            random_permutation = np.random.permutation(LEN_TRAIN)
            classifier.partial_fit(train_classifier_inputs[random_permutation],
                                   train_classifier_labels[random_permutation])

            # Early stopping: give up after PATIENCE + 1 consecutive epochs
            # without a new best validation cost.  NOTE(review): the
            # classifier keeps its *last* weights on exit, not the
            # best-scoring ones — no checkpoint/restore API is visible here.
            validation_score = classifier.get_cost(valid_classifier_inputs, valid_classifier_labels)
            logger.info('Beta = {0} | Classifier epoch {1} validation cost: {2}'.format(
                beta, epoch, validation_score))
            if validation_score < best_validation_score:
                logger.info('Beta = {0} | *** OPTIMAL SO FAR ***'.format(beta))
                best_validation_score = validation_score
                consecutive_decreases = 0
            else:
                consecutive_decreases += 1
                if consecutive_decreases > PATIENCE:
                    break

            epoch += 1

        logger.info('Beta = {0} | Classifier training completed.'.format(beta))

        accuracy = classifier.accuracy(test_classifier_inputs, test_classifier_labels)
        logger.info('Beta = {0} | Classifier accuracy: {1}'.format(beta, accuracy))
        # NOTE(review): mode 'w' truncates the file on every call, so only the
        # most recent beta's line survives on disk even though
        # classifier_accuracy_all accumulates; 'a' may be intended — confirm.
        with open(self.accuracy_output_file, 'w') as f:
            f.write('Beta = {0} | Classifier accuracy: {1}\n'.format(beta, accuracy))
        self.classifier_accuracy_all.append((beta, accuracy))
        classifier.close_session()