def main(unused_argv):
    """Configure the runtime from flags, parse gin configs, and run train/eval.

    Steps, in order:
      1. If ``FLAGS.virtual_gpus`` is greater than 1, ask ``smurf_trainer`` to
         provide at least that many virtual GPUs.
      2. If ``FLAGS.no_tf_function`` is set, make ``tf.function``-decorated
         code run eagerly and log that fact.
      3. Parse the gin config files and bindings given by
         ``FLAGS.config_file`` and ``FLAGS.gin_bindings``.
      4. Call ``smurf_trainer.train_eval()``.

    Args:
      unused_argv: Ignored. NOTE(review): presumably the leftover command-line
        arguments handed over by the app runner -- confirm against the caller.
    """
    if FLAGS.virtual_gpus > 1:
        smurf_trainer.set_virtual_gpus_to_at_least(FLAGS.virtual_gpus)

    if FLAGS.no_tf_function:
        # `experimental_run_functions_eagerly` is deprecated in favour of
        # `run_functions_eagerly`; prefer the stable name and fall back to the
        # experimental one only on TensorFlow builds that lack it.
        run_eagerly = getattr(tf.config, 'run_functions_eagerly', None)
        if run_eagerly is None:
            run_eagerly = tf.config.experimental_run_functions_eagerly
        run_eagerly(True)
        logging.info('TFFUNCTION DISABLED')

    logging.info('Parsing gin flags...')
    gin.parse_config_files_and_bindings(FLAGS.config_file, FLAGS.gin_bindings)

    smurf_trainer.train_eval()
def test_evaluating_on_spoof(self):
    """Run ``train_eval`` with ``eval_on='spoof:unused'`` and check its stdout.

    Sets the flags for an evaluation-only run (empty ``train_on``, one train
    step, evaluation during training enabled), captures everything written
    to stdout while ``smurf_trainer.train_eval()`` executes, and asserts
    that every expected evaluation metric label appears in that output.
    """
    # Flag assignments are kept in their original order.
    FLAGS.eval_on = 'spoof:unused'
    FLAGS.check_data = False
    FLAGS.train_on = ''
    FLAGS.plot_dir = '/tmp/spoof_eval'
    FLAGS.num_train_steps = 1
    FLAGS.evaluate_during_train = True

    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        smurf_trainer.train_eval()
    output = captured.getvalue()

    # Check that the relevant metrics are printed to stdout.
    expected_labels = (
        'spoof-EPE: ',
        'spoof-occl-f-max: ',
        'spoof-ER: ',
        'spoof-best-occl-thresh: ',
        'spoof-eval-time(s): ',
        'spoof-inf-time(ms): ',
    )
    for label in expected_labels:
        self.assertIn(label, output)
def test_training_on_spoof(self):
    """Run ``train_eval`` with ``train_on='spoof:unused'`` and check its stdout.

    Sets the flags for a training-only run (empty ``eval_on``, one train
    step, epoch length of one, no evaluation during training), captures
    everything written to stdout while ``smurf_trainer.train_eval()``
    executes, and asserts that every expected training metric label appears
    in that output.
    """
    # Flag assignments are kept in their original order.
    FLAGS.eval_on = ''
    FLAGS.train_on = 'spoof:unused'
    FLAGS.plot_dir = '/tmp/spoof_train'
    FLAGS.check_data = True
    FLAGS.num_train_steps = 1
    FLAGS.epoch_length = 1
    FLAGS.evaluate_during_train = False

    captured = io.StringIO()
    with contextlib.redirect_stdout(captured):
        smurf_trainer.train_eval()
    output = captured.getvalue()

    # Check that the relevant metrics are printed to stdout.
    expected_labels = (
        'total-loss: ',
        'data-time: ',
        'learning-rate: ',
        'train-time: ',
    )
    for label in expected_labels:
        self.assertIn(label, output)