def test(context, model_snapshot, test_path, use_cuda, use_tensorboard, field_names):
    """Test a trained model snapshot.

    If model-snapshot is provided, the models and configuration will then be
    loaded from the snapshot rather than any passed config file.
    Otherwise, a config file will be loaded.

    The snapshot, CUDA flag and TensorBoard flag are resolved by
    ``_get_model_snapshot``; ``test_path`` and ``field_names`` are forwarded
    unchanged to ``test_model_from_snapshot_path``.
    """
    resolved = _get_model_snapshot(
        context, model_snapshot, use_cuda, use_tensorboard
    )
    model_snapshot, use_cuda, use_tensorboard = resolved

    print("\n=== Starting testing...")
    channels = [TensorBoardChannel()] if use_tensorboard else []
    try:
        test_model_from_snapshot_path(
            model_snapshot,
            use_cuda,
            test_path,
            channels,
            field_names=field_names,
        )
    finally:
        # Close every channel even when testing fails.
        for channel in channels:
            channel.close()
def test(context, model_snapshot, test_path, use_cuda, use_tensorboard):
    """Test a trained model snapshot.

    If model-snapshot is provided, the models and configuration will then be
    loaded from the snapshot rather than any passed config file.
    Otherwise, a config file will be loaded.

    Args:
        context: Object whose ``obj.load_config()`` returns the run
            configuration; only consulted when no snapshot is given.
        model_snapshot: Snapshot to test. When falsy, it is replaced by
            ``config.save_snapshot_path``.
        test_path: Forwarded unchanged to ``test_model_from_snapshot_path``.
        use_cuda: CUDA flag. Must not be ``None`` when ``model_snapshot`` is
            given; otherwise replaced by ``config.use_cuda_if_available``.
        use_tensorboard: Whether to attach a ``TensorBoardChannel``. Used as
            passed when ``model_snapshot`` is given; otherwise replaced by
            ``config.use_tensorboard``.

    Raises:
        Exception: If ``model_snapshot`` is given but ``use_cuda`` is ``None``.
    """
    if model_snapshot:
        print(f"Loading model snapshot and config from {model_snapshot}")
        if use_cuda is None:
            raise Exception(
                "if --model-snapshot is set --use-cuda/--no-cuda must be set"
            )
    else:
        print(f"No model snapshot provided, loading from config")
        config = context.obj.load_config()
        model_snapshot = config.save_snapshot_path
        use_cuda = config.use_cuda_if_available
        # `config` exists only on this branch, so capture the TensorBoard
        # setting here instead of reading `config` after the if/else (which
        # raised UnboundLocalError whenever a snapshot was supplied).
        use_tensorboard = config.use_tensorboard
        print(f"Configured model snapshot {model_snapshot}")

    print("\n=== Starting testing...")
    metric_channels = []
    if use_tensorboard:
        metric_channels.append(TensorBoardChannel())
    try:
        test_model_from_snapshot_path(
            model_snapshot, use_cuda, test_path, metric_channels
        )
    finally:
        # Close every channel even when testing fails.
        for mc in metric_channels:
            mc.close()
def test(context, model_snapshot, test_path, use_cuda, use_tensorboard):
    """Test a trained model snapshot.

    If model-snapshot is provided, the models and configuration will then be
    loaded from the snapshot rather than any passed config file.
    Otherwise, a config file will be loaded.

    Args:
        context: Object whose ``obj.load_config()`` returns the run
            configuration; only consulted when no snapshot is given.
        model_snapshot: Snapshot to test. When falsy, it is replaced by
            ``config.save_snapshot_path``.
        test_path: Forwarded unchanged to ``test_model_from_snapshot_path``.
        use_cuda: CUDA flag. Must not be ``None`` when ``model_snapshot`` is
            given; otherwise replaced by ``config.use_cuda_if_available``.
        use_tensorboard: When truthy, a ``SummaryWriter`` is created and
            passed to ``test_model_from_snapshot_path``; otherwise ``None``
            is passed.

    Raises:
        Exception: If ``model_snapshot`` is given but ``use_cuda`` is ``None``.
    """
    if model_snapshot:
        print(f"Loading model snapshot and config from {model_snapshot}")
        if use_cuda is None:
            raise Exception(
                "if --model-snapshot is set --use-cuda/--no-cuda must be set"
            )
    else:
        print(f"No model snapshot provided, loading from config")
        config = context.obj.load_config()
        model_snapshot = config.save_snapshot_path
        use_cuda = config.use_cuda_if_available
        print(f"Configured model snapshot {model_snapshot}")

    print("\n=== Starting testing...")
    # Create the writer only once argument validation and config loading have
    # succeeded, directly before the try/finally that closes it. Creating it
    # earlier left it unclosed whenever one of those steps raised.
    summary_writer = SummaryWriter() if use_tensorboard else None
    try:
        test_model_from_snapshot_path(
            model_snapshot, use_cuda, test_path, summary_writer
        )
    finally:
        if summary_writer is not None:
            summary_writer.close()
def train(context):
    """Train a model and save the best snapshot.

    Loads and parses the config from ``context``, trains (single-process or
    distributed, depending on ``config.distributed_world_size``) and then
    tests the snapshot at ``config.save_snapshot_path``.
    """
    raw_config = context.obj.load_config()
    config = parse_config(raw_config)

    print("\n===Starting training...")
    single_process = config.distributed_world_size == 1
    run_training = train_model if single_process else train_model_distributed
    run_training(config)

    print("\n=== Starting testing...")
    snapshot_path = config.save_snapshot_path
    cuda_flag = config.use_cuda_if_available
    data_path = config.task.data_handler.test_path
    test_model_from_snapshot_path(snapshot_path, cuda_flag, data_path)
def test(context, model_snapshot, test_path, use_cuda):
    """Test a trained model snapshot.

    If model-snapshot is provided, the models and configuration will then be
    loaded from the snapshot rather than any passed config file.
    Otherwise, a config file will be loaded.

    Raises ``Exception`` when a snapshot is given but ``use_cuda`` is ``None``.
    """
    if not model_snapshot:
        # No explicit snapshot: take snapshot path and CUDA flag from config.
        print(f"No model snapshot provided, loading from config")
        loaded = context.obj.load_config()
        config = parse_config(loaded)
        model_snapshot = config.save_snapshot_path
        use_cuda = config.use_cuda_if_available
        print(f"Configured model snapshot {model_snapshot}")
    else:
        print(f"Loading model snapshot and config from {model_snapshot}")
        if use_cuda is None:
            message = "if --model-snapshot is set --use-cuda/--no-cuda must be set"
            raise Exception(message)

    print("\n=== Starting testing...")
    test_model_from_snapshot_path(model_snapshot, use_cuda, test_path)
def train(context):
    """Train a model and save the best snapshot.

    A ``SummaryWriter`` is created when ``config.use_tensorboard`` is truthy,
    shared by training and the follow-up test run, and closed afterwards even
    if either step fails.
    """
    config = context.obj.load_config()
    print("\n===Starting training...")

    writer = None
    if config.use_tensorboard:
        writer = SummaryWriter()

    try:
        if config.distributed_world_size != 1:
            train_model_distributed(config, writer)
        else:
            train_model(config, summary_writer=writer)

        print("\n=== Starting testing...")
        test_model_from_snapshot_path(
            config.save_snapshot_path,
            config.use_cuda_if_available,
            config.task.data_handler.test_path,
            writer,
        )
    finally:
        if writer is not None:
            writer.close()
def train(context):
    """Train a model and save the best snapshot.

    A ``TensorBoardChannel`` is attached when ``config.use_tensorboard`` is
    truthy. The same channel list is used for training and for the follow-up
    test run, and every channel is closed afterwards even on failure.
    """
    config = context.obj.load_config()
    print("\n===Starting training...")

    channels = [TensorBoardChannel()] if config.use_tensorboard else []
    try:
        if config.distributed_world_size != 1:
            train_model_distributed(config, channels)
        else:
            train_model(config, metric_channels=channels)

        print("\n=== Starting testing...")
        test_model_from_snapshot_path(
            config.save_snapshot_path,
            config.use_cuda_if_available,
            test_path=None,
            metric_channels=channels,
        )
    finally:
        for channel in channels:
            channel.close()