示例#1
0
def main(unused_argv):
    """Configure the runtime from command-line flags, then train and evaluate.

    Optionally requests virtual GPUs, optionally disables tf.function
    compilation, parses the gin config files and bindings named by the
    flags, and finally calls `smurf_trainer.train_eval()`.

    Args:
      unused_argv: Positional command-line arguments; ignored.
    """
    del unused_argv  # Unused.

    if FLAGS.virtual_gpus > 1:
        smurf_trainer.set_virtual_gpus_to_at_least(FLAGS.virtual_gpus)

    if FLAGS.no_tf_function:
        # `tf.config.run_functions_eagerly` replaces the deprecated
        # `experimental_run_functions_eagerly`; fall back to the old name on
        # TensorFlow releases that predate the stable API.
        run_eagerly = getattr(tf.config, 'run_functions_eagerly', None)
        if run_eagerly is None:
            run_eagerly = tf.config.experimental_run_functions_eagerly
        run_eagerly(True)
        logging.info('TFFUNCTION DISABLED')

    logging.info('Parsing gin flags...')
    gin.parse_config_files_and_bindings(FLAGS.config_file, FLAGS.gin_bindings)

    smurf_trainer.train_eval()
    def test_evaluating_on_spoof(self):
        """Run train_eval with eval_on='spoof:unused' and check eval output.

        Sets eval_on, check_data, train_on, plot_dir, num_train_steps and
        evaluate_during_train on the module-level FLAGS, runs
        `smurf_trainer.train_eval()` while capturing stdout, and asserts that
        every expected evaluation-metric label appears in the captured text.
        """
        # NOTE(review): these assignments mutate the global FLAGS and are not
        # restored in this method; confirm that a setUp/tearDown outside this
        # view resets them, otherwise later tests inherit these values.
        flag_overrides = {
            'eval_on': 'spoof:unused',
            'check_data': False,
            'train_on': '',
            'plot_dir': '/tmp/spoof_eval',
            'num_train_steps': 1,
            'evaluate_during_train': True,
        }
        # Dicts keep insertion order, so flags are set in the same sequence.
        for flag_name, flag_value in flag_overrides.items():
            setattr(FLAGS, flag_name, flag_value)

        with contextlib.redirect_stdout(io.StringIO()) as captured:
            smurf_trainer.train_eval()
        printed = captured.getvalue()

        # Every evaluation metric label must show up in the captured stdout.
        expected_labels = (
            'spoof-EPE: ',
            'spoof-occl-f-max: ',
            'spoof-ER: ',
            'spoof-best-occl-thresh: ',
            'spoof-eval-time(s): ',
            'spoof-inf-time(ms): ',
        )
        for label in expected_labels:
            self.assertIn(label, printed)

    def test_training_on_spoof(self):
        """Run train_eval with train_on='spoof:unused' and check train output.

        Sets eval_on, train_on, plot_dir, check_data, num_train_steps,
        epoch_length and evaluate_during_train on the module-level FLAGS, runs
        `smurf_trainer.train_eval()` while capturing stdout, and asserts that
        every expected training-statistic label appears in the captured text.
        """
        # NOTE(review): these assignments mutate the global FLAGS and are not
        # restored in this method; confirm that a setUp/tearDown outside this
        # view resets them, otherwise later tests inherit these values.
        flag_overrides = {
            'eval_on': '',
            'train_on': 'spoof:unused',
            'plot_dir': '/tmp/spoof_train',
            'check_data': True,
            'num_train_steps': 1,
            'epoch_length': 1,
            'evaluate_during_train': False,
        }
        # Dicts keep insertion order, so flags are set in the same sequence.
        for flag_name, flag_value in flag_overrides.items():
            setattr(FLAGS, flag_name, flag_value)

        with contextlib.redirect_stdout(io.StringIO()) as captured:
            smurf_trainer.train_eval()
        printed = captured.getvalue()

        # Every training statistic label must show up in the captured stdout.
        expected_labels = (
            'total-loss: ',
            'data-time: ',
            'learning-rate: ',
            'train-time: ',
        )
        for label in expected_labels:
            self.assertIn(label, printed)