示例#1
0
    def reward(state, dataset_handler, log_path="./workspace"):
        """Train the network encoded by `state` and return its refined reward.

        The network is trained on the handler's current training set and then
        evaluated on its current validation set. The reward is computed as
        ``accuracy * 100 - weighted_log_density - weighted_log_flops``, where
        the last two terms are read from the trainer after evaluation.

        Args:
            state: Encoded network. It is passed to the trainer as
                `encoded_network` and hashed to derive a per-network log
                directory name and variable scope.
            dataset_handler: Object exposing `current_train_set()` and
                `current_validation_set()`, each returning a
                `(features, labels)` pair.
            log_path: Base path for the trainer's log directory
                (``<log_path>/trainer-<hash>``).

        Returns:
            The refined reward, or `0.` if any step raised an exception. In
            the failure case the exception message and its full traceback
            are printed to stdout.
        """
        import traceback

        try:
            train_features, train_labels = dataset_handler.current_train_set()
            val_features, val_labels = dataset_handler.current_validation_set()

            # The hash keeps log directories and variable scopes distinct
            # between different encoded networks.
            hash_state = compute_str_hash(state_to_string(state))

            nas_trainer = EarlyStopNASTrainer(
                encoded_network=state,
                input_shape=infer_data_shape(train_features),
                n_classes=infer_n_classes(train_labels),
                batch_size=256,
                log_path="{lp}/trainer-{h}".format(lp=log_path, h=hash_state),
                mu=0.5,
                rho=0.5,
                variable_scope="cnn-{h}".format(h=hash_state))

            # NOTE(review): baseline=255 presumably assumes 8-bit image data;
            # confirm against normalize_dataset.
            train_features = normalize_dataset(dataset=train_features,
                                               baseline=255)
            train_labels = train_labels.astype(np.int32)

            nas_trainer.train(
                train_data=train_features,
                train_labels=train_labels,
                train_input_fn="default",
                n_epochs=12  # As specified by BlockQNN
            )

            val_features = normalize_dataset(dataset=val_features,
                                             baseline=255)
            val_labels = val_labels.astype(np.int32)

            res = nas_trainer.evaluate(eval_data=val_features,
                                       eval_labels=val_labels,
                                       eval_input_fn="default")

            accuracy = res['accuracy']
            # Compute the refined reward as defined
            refined_reward = accuracy*100 - nas_trainer.weighted_log_density - \
                nas_trainer.weighted_log_flops

            return refined_reward
        except Exception as ex:  # pylint: disable=broad-except
            # A failed training run must not abort the caller, so fall back to
            # a zero reward. str(ex) alone hides the exception type and the
            # failing line, so the full traceback is printed as well.
            print("Reward computation failed with exception:", ex)
            print(traceback.format_exc())
            return 0.
示例#2
0
    def test_train(self):
        """Check that default training leaves the training directory in place."""
        tf.reset_default_graph()

        # Remove any directory left over from a previous run so the final
        # existence check reflects this run only.
        log_dir = self.training_dir
        if os.path.isdir(log_dir):
            shutil.rmtree(log_dir)

        trainer_kwargs = dict(
            encoded_network=self.net_nsc,
            input_shape=infer_data_shape(self.train_data),
            n_classes=infer_n_classes(self.train_labels),
            batch_size=self.batch_size,
            log_path=log_dir,
            variable_scope="cnn",
        )
        trainer = DefaultNASTrainer(**trainer_kwargs)

        trainer.train(
            train_data=self.train_data,
            train_labels=self.train_labels,
            train_input_fn="default",
        )

        directory_exists = os.path.isdir(log_dir)
        self.assertTrue(directory_exists)
示例#3
0
    def test_evaluate(self):
        """Test evaluation with the early-stop trainer.

        Trains an `EarlyStopNASTrainer`, evaluates it, and checks that the
        training directory exists, that the evaluation result reports an
        "accuracy" entry, and that the trainer's density/FLOPs attributes
        (raw and weighted-log variants) are set to non-zero values.
        """
        tf.reset_default_graph()
        # Remove any directory left over from a previous run so the existence
        # check below reflects this run only.
        if os.path.isdir(self.training_dir):
            shutil.rmtree(self.training_dir)

        nas_trainer = EarlyStopNASTrainer(
            encoded_network=self.net_nsc,
            input_shape=infer_data_shape(self.train_data),
            n_classes=infer_n_classes(self.train_labels),
            batch_size=self.batch_size,
            log_path=self.training_dir,
            mu=0.5,
            rho=0.5,
            variable_scope="cnn")

        nas_trainer.train(train_data=self.train_data,
                          train_labels=self.train_labels,
                          train_input_fn="default")

        res = nas_trainer.evaluate(eval_data=self.eval_data,
                                   eval_labels=self.eval_labels,
                                   eval_input_fn="default")

        self.assertTrue(os.path.isdir(self.training_dir))
        # Dedicated assertions report the offending value on failure instead
        # of a bare "False is not true".
        self.assertIn("accuracy", res.keys())

        metric_names = (
            "density",
            "weighted_log_density",
            "flops",
            "weighted_log_flops",
        )
        for metric in metric_names:
            value = getattr(nas_trainer, metric)
            self.assertIsNotNone(
                value, msg="{m} was not set".format(m=metric))
            self.assertNotEqual(
                value, 0., msg="{m} is zero".format(m=metric))