def reward(state, dataset_handler, log_path="./workspace"):
    """Train a network encoded by `state` and return its refined reward.

    The reward follows the BlockQNN formulation: validation accuracy (as a
    percentage) penalized by the trainer's weighted log-density and weighted
    log-FLOPs terms.

    Args:
        state: Encoded network (NSC representation) to train and evaluate.
        dataset_handler: Provider exposing `current_train_set()` and
            `current_validation_set()`, each returning (features, labels).
        log_path: Base directory for per-network trainer logs; a
            hash-suffixed subdirectory is created under it.

    Returns:
        float: `accuracy*100 - weighted_log_density - weighted_log_flops`,
        or `0.` if any step of training/evaluation fails.
    """
    try:
        train_features, train_labels = dataset_handler.current_train_set()
        val_features, val_labels = dataset_handler.current_validation_set()

        # Hash the state so each candidate network gets its own log dir
        # and TF variable scope, avoiding collisions across evaluations.
        hash_state = compute_str_hash(state_to_string(state))
        nas_trainer = EarlyStopNASTrainer(
            encoded_network=state,
            input_shape=infer_data_shape(train_features),
            n_classes=infer_n_classes(train_labels),
            batch_size=256,
            log_path="{lp}/trainer-{h}".format(lp=log_path, h=hash_state),
            mu=0.5,
            rho=0.5,
            variable_scope="cnn-{h}".format(h=hash_state))

        # Pixel data is scaled to [0, 1]; labels must be int32 for TF.
        train_features = normalize_dataset(dataset=train_features,
                                           baseline=255)
        train_labels = train_labels.astype(np.int32)

        nas_trainer.train(
            train_data=train_features,
            train_labels=train_labels,
            train_input_fn="default",
            n_epochs=12  # As specified by BlockQNN
        )

        val_features = normalize_dataset(dataset=val_features, baseline=255)
        val_labels = val_labels.astype(np.int32)

        res = nas_trainer.evaluate(eval_data=val_features,
                                   eval_labels=val_labels,
                                   eval_input_fn="default")
        accuracy = res['accuracy']

        # Refined reward as defined by BlockQNN: accuracy penalized by
        # network density and FLOPs (both in weighted-log form).
        refined_reward = accuracy*100 - nas_trainer.weighted_log_density - \
            nas_trainer.weighted_log_flops

        return refined_reward
    except Exception as ex:  # pylint: disable=broad-except
        # Best-effort: a failed candidate scores 0 instead of aborting the
        # search. Print the full traceback so failures are diagnosable.
        import traceback
        print("Reward computation failed with exception:", ex)
        traceback.print_exc()
        return 0.
def test_train(self):
    """Verify that default training runs and creates the log directory."""
    tf.reset_default_graph()

    # Start from a clean slate: remove any logs left by a previous run.
    if os.path.isdir(self.training_dir):
        shutil.rmtree(self.training_dir)

    trainer = DefaultNASTrainer(
        encoded_network=self.net_nsc,
        input_shape=infer_data_shape(self.train_data),
        n_classes=infer_n_classes(self.train_labels),
        batch_size=self.batch_size,
        log_path=self.training_dir,
        variable_scope="cnn",
    )
    trainer.train(
        train_data=self.train_data,
        train_labels=self.train_labels,
        train_input_fn="default",
    )

    # Training should have (re)created the log directory as a side effect.
    self.assertTrue(os.path.isdir(self.training_dir))
def test_evaluate(self):
    """Test evaluation with the early-stop trainer and its derived metrics.

    Trains an EarlyStopNASTrainer, evaluates it, and checks that the
    accuracy metric is reported and that the density/FLOPs statistics
    (raw and weighted-log forms) are populated and non-zero.
    """
    tf.reset_default_graph()

    # Remove stale logs so directory-creation is actually exercised.
    if os.path.isdir(self.training_dir):
        shutil.rmtree(self.training_dir)

    nas_trainer = EarlyStopNASTrainer(
        encoded_network=self.net_nsc,
        input_shape=infer_data_shape(self.train_data),
        n_classes=infer_n_classes(self.train_labels),
        batch_size=self.batch_size,
        log_path=self.training_dir,
        mu=0.5,
        rho=0.5,
        variable_scope="cnn")
    nas_trainer.train(
        train_data=self.train_data,
        train_labels=self.train_labels,
        train_input_fn="default")
    res = nas_trainer.evaluate(
        eval_data=self.eval_data,
        eval_labels=self.eval_labels,
        eval_input_fn="default")

    self.assertTrue(os.path.isdir(self.training_dir))
    # Dedicated unittest asserts give informative failure messages.
    self.assertIn("accuracy", res)
    self.assertIsNotNone(nas_trainer.density)
    self.assertNotEqual(nas_trainer.density, 0.)
    self.assertIsNotNone(nas_trainer.weighted_log_density)
    self.assertNotEqual(nas_trainer.weighted_log_density, 0.)
    self.assertIsNotNone(nas_trainer.flops)
    self.assertNotEqual(nas_trainer.flops, 0.)
    self.assertIsNotNone(nas_trainer.weighted_log_flops)
    self.assertNotEqual(nas_trainer.weighted_log_flops, 0.)