def test_shared_variables_over_object_types(self):
    """Check that a tf.Variable shared by a trace and a network op survives save/load.

    After ``load_state`` the saved value (3) must be visible through both owners. Writing
    through one owner must be visible through the other. Assigning to the variable created
    by the test must also be visible through both owners.
    """
    import shutil  # local import: only used to remove the temporary checkpoint directory

    save_path = tempfile.mkdtemp()
    # The checkpoint directory used to be leaked; remove it once the test has finished.
    self.addCleanup(shutil.rmtree, save_path, ignore_errors=True)

    system = sample_system_object()
    shared_var = tf.Variable(initial_value=2, trainable=True)
    system.traces.append(TestTrace(shared_var))
    system.network.ops[0].fe_test_var_1 = shared_var
    shared_var.assign_add(1)  # the value written to disk is 3
    system.save_state(save_dir=save_path)

    # Re-initialize: a fresh system whose shared variable starts back at 2
    system = sample_system_object()
    shared_var = tf.Variable(initial_value=2, trainable=True)
    system.traces.append(TestTrace(shared_var))
    system.network.ops[0].fe_test_var_1 = shared_var
    system.load_state(load_dir=save_path)

    with self.subTest("Check variable value was re-loaded"):
        self.assertEqual(3, system.traces[-1].var1.numpy())
        self.assertEqual(3, system.network.ops[0].fe_test_var_1.numpy())
    with self.subTest("Check that variable is still shared"):
        # A write through the trace must be observed through the op as well.
        system.traces[-1].var1.assign(5)
        self.assertEqual(5, system.traces[-1].var1.numpy())
        self.assertEqual(5, system.network.ops[0].fe_test_var_1.numpy())
    with self.subTest(
            "Check that variable is still linked to outside code"):
        shared_var.assign(7)
        self.assertEqual(7, system.traces[-1].var1.numpy())
        self.assertEqual(7, system.network.ops[0].fe_test_var_1.numpy())
def test_restore_training_old_missing(self):
    """Restoring a run whose history row was deleted should yield one consolidated row.

    The first recorder's run dies with a RuntimeError, and its ``history`` row is then
    deleted. A second recorder runs with its system overwritten by the first system's
    state, which fakes the restore wizard. Afterwards the table must hold:

    * exactly one 'Completed' history row, keyed by the original ``exp_id``, with
      ``n_restarts == 1``;
    * one log row containing only the output printed after the restore.
    """
    system1 = sample_system_object()
    recorder1 = HistoryRecorder(system=system1, est_path="test.py", db_path=self.db_path)
    try:
        with recorder1:
            print("Test Log Capture")
            print("Line 2")
            raise RuntimeError("Training Died")
    except RuntimeError:
        pass  # the crash is the scenario under test

    # closing() releases the connection even if something raises before the end of the test
    # (previously db.close() only ran when every statement above it succeeded).
    with closing(connect(self.db_path)) as db:
        db.execute("DELETE FROM history WHERE pk = (?)", [system1.exp_id])
        db.commit()

        system2 = sample_system_object()
        recorder2 = HistoryRecorder(system=system2, est_path="test.py", db_path=self.db_path)
        with recorder2:
            # Fake a restore wizard
            system2.__dict__.update(system1.__dict__)
            print("Line 3")
            print("Line 4")

        with closing(db.cursor()) as cursor:
            cursor.execute("SELECT * FROM history")
            results = cursor.fetchall()
        with self.subTest("History Captured and Consolidated"):
            self.assertEqual(len(results), 1)
        results = results[0]
        with self.subTest("File name captured"):
            self.assertEqual(results['file'], 'test.py')
        with self.subTest("Status updated"):
            self.assertEqual(results['status'], 'Completed')
        with self.subTest("Correct PK"):
            self.assertEqual(results['pk'], system1.exp_id)
        with self.subTest("Restarts Incremented"):
            self.assertEqual(results['n_restarts'], 1)

        # Logs
        with closing(db.cursor()) as cursor:
            cursor.execute("SELECT * FROM logs WHERE fk = (?)", [system1.exp_id])
            results = cursor.fetchall()
        with self.subTest("Log Captured"):
            self.assertEqual(len(results), 1)
        results = results[0]
        with self.subTest("Complete log captured"):
            self.assertListEqual(results['log'].splitlines(), ["Line 3", "Line 4"])
def test_basic_happy_path(self):
    """An unbounded LabelTracker should publish one 'acc' summary per label under 'out'."""
    tracker = LabelTracker(label='y', metric='acc', bounds=None, outputs='out')
    system = sample_system_object()
    tracker.system = system
    data = Data()
    self._simulate_training(tracker, data)

    # Check responses
    with self.subTest('Check that a response was written'):
        self.assertIn('out', data)
    with self.subTest('Check that data was written to the system'):
        self.assertIn('out', system.custom_graphs)
    summaries = data['out']
    with self.subTest('Check consistency of outputs'):
        self.assertEqual(summaries, system.custom_graphs['out'])
    with self.subTest('Check that all 3 labels have summaries'):
        self.assertEqual(3, len(summaries))

    def rounded_acc_at(step):
        # Map each summary's name to its train 'acc' entry at ``step``, rounded to 6 places.
        return {
            summary.name: round(summary.history['train']['acc'][step], 6)
            for summary in summaries
        }

    with self.subTest('Check correct mean values (epoch 1)'):
        self.assertDictEqual(self.epoch_1_means, rounded_acc_at(3))
    with self.subTest('Check correct mean values (epoch 2)'):
        self.assertDictEqual(self.epoch_2_means, rounded_acc_at(6))
def test_multiple_bounds(self):
    """With bounds ['std', 'range'] only the two bounded 'acc' series should be reported."""
    tracker = LabelTracker(label='y', metric='acc', bounds=['std', 'range'], outputs='out')
    system = sample_system_object()
    tracker.system = system
    data = Data()
    self._simulate_training(tracker, data)

    # Check responses
    with self.subTest('Check that a response was written'):
        self.assertIn('out', data)
    with self.subTest('Check that data was written to the system'):
        self.assertIn('out', system.custom_graphs)
    summaries = data['out']
    with self.subTest('Check consistency of outputs'):
        self.assertEqual(summaries, system.custom_graphs['out'])
    with self.subTest('Check that all 3 labels have summaries'):
        self.assertEqual(3, len(summaries))
    with self.subTest('Check that regular mean is not present'):
        for summary in summaries:
            self.assertNotIn('acc', summary.history['train'])
    with self.subTest('Check that stddev and range are both present'):
        for summary in summaries:
            train_history = summary.history['train']
            self.assertIn('acc ($\\mu \\pm \\sigma$)', train_history)
            self.assertIn('acc ($min, \\mu, max$)', train_history)
def test_on_epoch_end(self):
    """TestEssential should write the mean of the recorded test losses (10, 20) as 15.0."""
    essential = TestEssential(monitor_names={'loss'})
    essential.system = sample_system_object()
    essential.test_results['loss'][''].extend([10, 20])
    result = Data({})
    essential.on_epoch_end(data=result)
    self.assertEqual(result['loss'], 15.0)
def test_basic_happy_path(self):
    """InstanceTracker without explicit keep counts should track 10 extreme 'ce' indices."""
    tracker = InstanceTracker(index='idx', metric='ce', outputs='out', mode='train')
    system = sample_system_object()
    tracker.system = system
    data = Data()
    self._simulate_training(tracker, data)

    # Check responses
    with self.subTest('Check that a response was written'):
        self.assertIn('out', data)
    with self.subTest('Check that data was written to the system'):
        self.assertIn('out', system.custom_graphs)
    summaries = data['out']
    with self.subTest('Check consistency of outputs'):
        self.assertEqual(summaries, system.custom_graphs['out'])
    with self.subTest('Check that 10 indices were tracked'):
        self.assertEqual(10, len(summaries))
    tracked = {summary.name for summary in summaries}
    # NOTE(review): the labels below say 5 min and 5 max keys, but the max tuple has 6
    # entries and shares index 8 with the min tuple -- confirm the intended keys.
    with self.subTest('Check that 5 min keys were tracked'):
        for idx in (3, 4, 8, 11, 14):
            self.assertIn(idx, tracked)
    with self.subTest('Check that 5 max keys were tracked'):
        for idx in (7, 8, 24, 25, 27, 28):
            self.assertIn(idx, tracked)
def test_specific_indices(self):
    """With min/max keeping disabled, only the indices in ``list_to_keep`` are tracked."""
    tracker = InstanceTracker(index='idx',
                              metric='ce',
                              n_max_to_keep=0,
                              n_min_to_keep=0,
                              list_to_keep=[1, 5, 17],
                              outputs='out',
                              mode='train')
    system = sample_system_object()
    tracker.system = system
    data = Data()
    self._simulate_training(tracker, data)

    # Check responses
    with self.subTest('Check that a response was written'):
        self.assertIn('out', data)
    with self.subTest('Check that data was written to the system'):
        self.assertIn('out', system.custom_graphs)
    summaries = data['out']
    with self.subTest('Check consistency of outputs'):
        self.assertEqual(summaries, system.custom_graphs['out'])
    with self.subTest('Check that 3 indices were tracked'):
        self.assertEqual(3, len(summaries))
    tracked = {summary.name for summary in summaries}
    with self.subTest('Check that the correct keys were tracked'):
        for idx in (1, 5, 17):
            self.assertIn(idx, tracked)
def test_on_epoch_begin_loss_keys(self):
    """With monitor_names=None, on_epoch_begin should set monitor_keys to the expected loss keys."""
    terminator = TerminateOnNaN(monitor_names=None)
    terminator.system = sample_system_object()
    terminator.system.network = self.network
    terminator.on_epoch_begin(data={})
    self.assertEqual(terminator.monitor_keys, self.expected_loss_keys)
def test_tf_on_epoch_end(self):
    """TensorBoard.on_epoch_end should write an embedding ``tensors.tsv`` and ``sprite.png``.

    The last TSV row must be 27 copies of '1.0'. The sprite image must be a 3x3x3 array of
    255s.
    """
    tensorboard = TensorBoard(log_dir=self.log_dir,
                              weight_histogram_freq=1,
                              update_freq=1,
                              write_images='images',
                              write_embeddings='embed',
                              embedding_images='embed_images')
    tensorboard.system = sample_system_object()
    tensorboard.system.global_step = 1
    tensorboard.writer = _TfWriter(self.log_dir, '', tensorboard.system.network)
    model = fe.build(model_fn=fe.architecture.tensorflow.LeNet, optimizer_fn='adam')
    tensorboard.system.network.epoch_models = {model}
    if os.path.exists(self.train_path):
        shutil.rmtree(self.train_path)
    tensorboard.on_epoch_end(data=self.tf_data)
    tsv_path = os.path.join(self.embed_path, 'tensors.tsv')
    embed_img_path = os.path.join(self.embed_path, 'sprite.png')

    # Get tensor data (last row) from the tsv file. The with-block guarantees the handle is
    # closed, and tsv_data stays None instead of being unbound if the file is empty.
    tsv_data = None
    with open(tsv_path, newline='') as fo:
        for row in csv.reader(fo, delimiter='\t'):
            tsv_data = row

    # Get the image data. np.asarray copies the pixels, so closing the image afterwards is safe.
    with Image.open(embed_img_path) as img:
        output_img = np.asarray(img)

    with self.subTest('Check if tensors.tsv was generated'):
        self.assertTrue(os.path.exists(tsv_path))
    with self.subTest('Check if embed image was generated'):
        self.assertTrue(os.path.exists(embed_img_path))
    with self.subTest('Check content of tensors.tsv'):
        self.assertEqual(tsv_data, 27 * ['1.0'])
    with self.subTest('Check embed image content'):
        # np.int was only an alias of the builtin int and was removed in NumPy 1.24.
        self.assertTrue(is_equal(output_img, 255 * np.ones(shape=(3, 3, 3), dtype=int)))
def setUpClass(cls):
    """Create a shared PBMCalibrator whose save_path points into a fresh temporary directory."""
    cls.save_file = os.path.join(tempfile.mkdtemp(), 'calibrator.pkl')
    calibrator = PBMCalibrator(true_key='y', pred_key='y_pred', save_path=cls.save_file)
    calibrator.system = sample_system_object()
    cls.pbm_calibrator = calibrator
def test_happy_path(self):
    """A successful run should store one 'Completed' history row and the printed log lines."""
    system = sample_system_object()
    recorder = HistoryRecorder(system=system, est_path="test.py", db_path=self.db_path)
    with recorder:
        print("Test Log Capture")
        print("Line 2")

    # closing() releases the connection even if something raises before the end of the test
    # (previously db.close() only ran when every statement above it succeeded).
    with closing(connect(self.db_path)) as db:
        with closing(db.cursor()) as cursor:
            cursor.execute("SELECT * FROM history WHERE pk = (?)", [system.exp_id])
            results = cursor.fetchall()
        with self.subTest("History Captured"):
            self.assertEqual(len(results), 1)
        results = results[0]
        with self.subTest("File name captured"):
            self.assertEqual(results['file'], 'test.py')
        with self.subTest("Status updated"):
            self.assertEqual(results['status'], 'Completed')

        # Logs
        with closing(db.cursor()) as cursor:
            cursor.execute("SELECT * FROM logs WHERE fk = (?)", [system.exp_id])
            results = cursor.fetchall()
        with self.subTest("Log Captured"):
            self.assertEqual(len(results), 1)
        results = results[0]
        with self.subTest("Complete log captured"):
            self.assertListEqual(results['log'].splitlines(), ["Test Log Capture", "Line 2"])
def instantiate_system():
    """Build a sample system whose pipeline holds a single Sometimes-wrapped TestNumpyOp."""
    system = sample_system_object()
    inner_op = TestNumpyOp(inputs="x", outputs="x", mode="train", var=1)
    system.pipeline.ops = [fe.op.numpyop.meta.Sometimes(inner_op)]
    return system
def setUpClass(cls):
    """Build the shared data, expected output and Dice trace used by the Dice tests."""
    truth = np.array([1, 2, 3])
    prediction = np.array([[1, 1, 3], [2, 3, 4], [1, 1, 0]])
    cls.data = Data({'x': truth, 'x_pred': prediction})
    # Presumably the expected Dice output for the arrays above -- values kept verbatim.
    cls.dice_output = [1.4999999987500001, 2.3999999972, 2.3999999972]
    cls.dice = Dice(true_key='x', pred_key='x_pred')
    cls.dice.system = sample_system_object()
def test_on_epoch_end(self):
    """EvalEssential should write the mean of the eval losses (10, 20) as 15.0."""
    essential = EvalEssential(monitor_names='loss')
    essential.system = sample_system_object()
    essential.eval_results = {'loss': [10, 20]}
    result = Data({})
    essential.on_epoch_end(data=result)
    self.assertEqual(result['loss'], 15.0)
def test_on_epoch_end_mode_test(self):
    """In 'test' mode, Logger.on_epoch_end should print ``self.on_epoch_end_test_msg``."""
    logger = Logger()
    logger.system = sample_system_object()
    system = logger.system
    system.mode = 'test'
    system.global_step = 2
    system.log_steps = 3
    self._test_print_msg(func=logger.on_epoch_end,
                         data=self.data,
                         msg=self.on_epoch_end_test_msg)
def test_on_begin_global_step(self):
    """At global step 2, Logger.on_begin should print ``self.on_begin_global_step_msg``."""
    logger = Logger()
    logger.system = sample_system_object()
    logger.system.global_step = 2
    self._test_print_msg(func=logger.on_begin,
                         data=self.data,
                         msg=self.on_begin_global_step_msg)
def test_max_to_keep_tf_architecture(self):
    """With max_to_keep=2 and save_architecture=True, only the two newest epochs remain.

    Each kept epoch contributes an ``<name>.h5`` file and an ``<name>`` directory. After
    four saved epochs, exactly four entries should therefore be left in ``save_dir``.
    """
    import shutil  # local import: only used to remove the temporary save directory

    save_dir = tempfile.mkdtemp()
    # The save directory used to be leaked; remove it once the test has finished.
    self.addCleanup(shutil.rmtree, save_dir, ignore_errors=True)
    model = fe.build(model_fn=one_layer_tf_model, optimizer_fn='adam')
    model_saver = ModelSaver(model=model, save_dir=save_dir, max_to_keep=2, save_architecture=True)
    model_saver.system = sample_system_object()

    def current_paths():
        # (weights .h5 path, architecture path) for the saver's current epoch.
        model_name = "{}_epoch_{}".format(model_saver.model.model_name,
                                          model_saver.system.epoch_idx)
        return os.path.join(save_dir, model_name + '.h5'), os.path.join(save_dir, model_name)

    # Save three consecutive epochs.
    model_saver.on_epoch_end(data=Data())
    for _ in range(2):
        model_saver.system.epoch_idx += 1
        model_saver.on_epoch_end(data=Data())
    tf_model_path1, tf_architecture_path1 = current_paths()

    # A fourth save should evict everything except the last two epochs.
    model_saver.system.epoch_idx += 1
    model_saver.on_epoch_end(data=Data())
    tf_model_path2, tf_architecture_path2 = current_paths()

    with self.subTest('Check only four files are kept'):
        self.assertEqual(len(os.listdir(save_dir)), 4)
    with self.subTest('Check two latest models are kept'):
        self.assertTrue(os.path.exists(tf_model_path1))
        self.assertTrue(os.path.exists(tf_model_path2))
        self.assertTrue(os.path.exists(tf_architecture_path1))
        self.assertTrue(os.path.isdir(tf_architecture_path1))
        self.assertTrue(os.path.exists(tf_architecture_path2))
        self.assertTrue(os.path.isdir(tf_architecture_path2))
def setUpClass(cls):
    """Create a TrainEssential whose epoch/step timers are back-dated by 500s/300s."""
    cls.data = Data({'loss': 10})
    essential = TrainEssential(monitor_names='loss')
    essential.system = sample_system_object()
    essential.system.log_steps = 5
    essential.system.global_step = 10
    # Back-dated start times; presumably so elapsed-time output is non-trivial.
    essential.epoch_start = time.perf_counter() - 500
    essential.step_start = time.perf_counter() - 300
    cls.train_essential = essential
def instantiate_system():
    """Build a sample system whose network runs a Sometimes-wrapped TestTensorOp, then a TF LeNet."""
    system = sample_system_object()
    model = fe.build(model_fn=fe.architecture.tensorflow.LeNet, optimizer_fn='adam', model_name='tf')
    maybe_op = fe.op.tensorop.meta.Sometimes(
        TestTensorOp(inputs="x_out", outputs="x_out", mode="train", var=1))
    network_ops = [maybe_op, ModelOp(model=model, inputs="x_out", outputs="y_pred")]
    system.network = fe.Network(ops=network_ops)
    return system
def test_tf_on_begin(self):
    """TensorBoard.on_begin should print exactly ``self.on_begin_msg``."""
    tensorboard = TensorBoard(log_dir=self.log_dir)
    tensorboard.system = sample_system_object()
    tensorboard.system.global_step = 1
    with patch('sys.stdout', new=StringIO()) as captured:
        tensorboard.on_begin(data=self.tf_data)
        printed = captured.getvalue().strip()
    self.assertEqual(printed, self.on_begin_msg)
def test_on_batch_end(self):
    """At global step 1 with log_steps=3, Logger.on_batch_end should print ``self.on_batch_end_msg``."""
    logger = Logger()
    logger.system = sample_system_object()
    system = logger.system
    system.global_step = 1
    system.log_steps = 3
    self._test_print_msg(func=logger.on_batch_end,
                         data=self.data,
                         msg=self.on_batch_end_msg)
def instantiate_system():
    """Build a sample system whose pipeline batches two copies of a 2-sample TestDataset."""
    system = sample_system_object()
    images = np.ones((2, 28, 28, 3))
    labels = np.ones((2, ))
    dataset = TestDataset(data={'x': images, 'y': labels}, var=1)
    batched = fe.dataset.BatchDataset(datasets=[dataset, dataset], num_samples=[1, 1])
    system.pipeline = fe.Pipeline(train_data=batched, batch_size=2)
    return system
def test_on_epoch_begin_all_keys(self):
    """With monitor_names="*" in eval mode, monitor_keys should equal ``self.expected_all_keys``."""
    terminator = TerminateOnNaN(monitor_names="*")
    terminator.system = sample_system_object()
    system = terminator.system
    system.mode = "eval"
    system.network = self.network
    system.traces = self.traces
    terminator.on_epoch_begin(data={})
    self.assertEqual(terminator.monitor_keys, self.expected_all_keys)
def instantiate_system():
    """Build a sample system whose network runs a PyTorch LeNet followed by SuperLoss(CrossEntropy)."""
    system = sample_system_object()
    # NOTE(review): model_name='tf' on a PyTorch model looks copy-pasted; kept unchanged
    # because saved state may be keyed by model name -- confirm before renaming.
    model = fe.build(model_fn=fe.architecture.pytorch.LeNet, optimizer_fn='adam', model_name='tf')
    network_ops = [
        ModelOp(model=model, inputs="x_out", outputs="y_pred"),
        SuperLoss(CrossEntropy(inputs=['y_pred', 'y'], outputs='ce')),
    ]
    system.network = fe.Network(ops=network_ops)
    return system
def instantiate_system():
    """Build a sample system whose pipeline holds an EpochScheduler starting a TestNumpyOp at epoch 1."""
    system = sample_system_object()
    scheduled_op = TestNumpyOp(inputs="x", outputs="x", mode="train", var=1)
    system.pipeline.ops = [EpochScheduler(epoch_dict={1: scheduled_op})]
    return system
def instantiate_system():
    """Build a sample system with a LabelTracker attached as its last trace.

    Returns:
        A ``(system, tracker)`` tuple; the tracker's ``system`` points back at the system.
    """
    label_tracker = LabelTracker(label='y', metric='acc', bounds=[None, 'range'], outputs='out')
    system = sample_system_object()
    system.traces.append(label_tracker)
    label_tracker.system = system
    return (system, label_tracker)
def test_torch_model_on_batch_begin(self):
    """At global step 3, LRScheduler should set the first param group's lr via cosine decay."""

    # Parameter name kept as ``step`` in case the scheduler inspects it -- do not rename.
    def cosine_lr(step):
        return fe.schedule.cosine_decay(step, cycle_length=3750, init_lr=1e-3)

    lr_scheduler = LRScheduler(model=self.torch_model, lr_fn=cosine_lr)
    lr_scheduler.system = sample_system_object()
    lr_scheduler.system.global_step = 3
    lr_scheduler.on_batch_begin(data=self.data)
    param_groups = list(self.torch_model.optimizer.param_groups)
    new_lr = param_groups[0]['lr']
    self.assertTrue(math.isclose(new_lr, 0.0009999993, rel_tol=1e-5))
def instantiate_system():
    """Build a sample system whose pipeline holds a RepeatScheduler over two TestNumpyOp instances."""
    system = sample_system_object()
    # Two separate (not shared) op instances, one per scheduler slot.
    cycle = [TestNumpyOp(inputs="x", outputs="x", mode="train", var=1) for _ in range(2)]
    system.pipeline.ops = [RepeatScheduler(cycle)]
    return system
def test_on_epoch_end(self):
    """ImageSaver.on_epoch_end should write ``self.image_path`` with no NaN in column 0.

    The previous check compared pixels to ``np.nan`` with ``==``. NaN never compares equal
    to anything, so that comparison is always False and the check could never fail.
    ``np.isnan`` performs the intended test.
    """
    image_saver = ImageSaver(inputs='img', save_dir=self.image_dir)
    image_saver.system = sample_system_object()
    image_saver.on_epoch_end(data=self.data)
    with self.subTest('Check if image is saved'):
        self.assertTrue(os.path.exists(self.image_path))
    with self.subTest('Check image is valid or not'):
        im = plt.imread(self.image_path)
        self.assertFalse(np.any(np.isnan(im[:, 0])))
def test_on_epoch_end_early_stopping_msg(self):
    """EarlyStopping(baseline=5.0), at epoch 3 with best=2, should print ``self.expected_msg``."""
    with patch('sys.stdout', new=StringIO()) as captured:
        stopper = EarlyStopping(baseline=5.0)
        stopper.system = sample_system_object()
        stopper.system.epoch_idx = 3
        stopper.best = 2
        stopper.on_epoch_end(data=self.data)
        printed = captured.getvalue().strip()
    self.assertEqual(printed, self.expected_msg)