def test_checkpointing(self):
    # make checkpoint directory
    checkpoint_folder = self.base_dir + "/checkpoint/"
    os.mkdir(checkpoint_folder)

    config = get_fast_test_task_config()
    cuda_available = torch.cuda.is_available()
    task = build_task(config)
    task.prepare(use_gpu=cuda_available)

    # create a checkpoint hook
    checkpoint_hook = CheckpointHook(checkpoint_folder, {}, phase_types=["train"])

    # call the on end phase function
    checkpoint_hook.on_phase_end(task)

    # we should be able to train a task using the checkpoint on all available
    # devices
    for use_gpu in {False, cuda_available}:
        # load the checkpoint
        checkpoint = load_checkpoint(checkpoint_folder)

        # create a new task
        task = build_task(config)

        # set the checkpoint
        task.set_checkpoint(checkpoint)
        task.prepare(use_gpu=use_gpu)

        # we should be able to run the trainer using the checkpoint
        trainer = LocalTrainer(use_gpu=use_gpu)
        trainer.train(task)
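# A minimal sketch of the test fixture these checkpoint tests assume
# (hypothetical names): self.base_dir is a per-test scratch directory,
# created in setUp and removed in tearDown. Any tempdir scheme would do.
import shutil
import tempfile
import unittest


class TestCheckpointHook(unittest.TestCase):
    def setUp(self):
        self.base_dir = tempfile.mkdtemp()

    def tearDown(self):
        shutil.rmtree(self.base_dir, ignore_errors=True)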
def test_final_train_checkpoint(self):
    """Test that a train phase checkpoint with a where of 1.0 can be loaded"""
    config = get_fast_test_task_config()
    task = build_task(config).set_hooks(
        [CheckpointHook(self.base_dir, {}, phase_types=["train"])]
    )
    task_2 = build_task(config)

    use_gpu = torch.cuda.is_available()

    trainer = LocalTrainer(use_gpu=use_gpu)
    trainer.train(task)

    # load the final train checkpoint
    checkpoint = load_checkpoint(self.base_dir)

    # make sure fetching the where raises an exception, which means that
    # where is >= 1.0
    with self.assertRaises(Exception):
        task.where

    # set task_2's state as task's final train checkpoint
    task_2.set_checkpoint(checkpoint)
    task_2.prepare(use_gpu=use_gpu)

    # we should be able to train the task
    trainer.train(task_2)
def train(datasets, model, loss, optimizer, meters, args):
    task = (
        ClassificationTask()
        .set_num_epochs(args.num_epochs)
        .set_loss(loss)
        .set_model(model)
        .set_optimizer(optimizer)
        .set_meters(meters)
    )
    for phase in ["train", "test"]:
        task.set_dataset(datasets[phase], phase)

    hooks = [LossLrMeterLoggingHook(log_freq=args.print_freq)]

    # show progress
    hooks.append(ProgressBarHook())

    if not args.skip_tensorboard:
        try:
            from tensorboardX import SummaryWriter

            tb_writer = SummaryWriter(log_dir=args.video_dir + "/tensorboard")
            hooks.append(TensorboardPlotHook(tb_writer))
        except ImportError:
            print("tensorboardX not installed, skipping tensorboard hooks")

    checkpoint_dir = f"{args.video_dir}/checkpoint/classy_checkpoint_{time.time()}"
    os.mkdir(checkpoint_dir)
    hooks.append(CheckpointHook(checkpoint_dir, input_args={}))

    task = task.set_hooks(hooks)

    trainer = LocalTrainer(use_gpu=args.cuda, num_dataloader_workers=args.num_workers)
    trainer.train(task)
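# A hedged sketch of how the arguments to train() might be constructed,
# assuming Classy Vision's build_* helpers and a config dict with "dataset",
# "model", "loss", "optimizer", and "meters" sections. The key layout is an
# assumption for illustration, not taken from the snippet above.
from classy_vision.dataset import build_dataset
from classy_vision.losses import build_loss
from classy_vision.meters import build_meters
from classy_vision.models import build_model
from classy_vision.optim import build_optimizer


def build_components(config):
    # one dataset per phase, matching the ["train", "test"] loop above
    datasets = {
        phase: build_dataset(config["dataset"][phase]) for phase in ["train", "test"]
    }
    model = build_model(config["model"])
    loss = build_loss(config["loss"])
    optimizer = build_optimizer(config["optimizer"])
    meters = build_meters(config.get("meters", {}))
    return datasets, model, loss, optimizer, meters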
def main(local_rank, c10d_backend, rdzv_init_url, max_world_size, classy_args):
    torch.manual_seed(0)
    set_video_backend(classy_args.video_backend)

    # Loads config, sets up task
    config = load_json(classy_args.config_file)

    task = build_task(config)

    # Load checkpoint, if available
    checkpoint = load_checkpoint(classy_args.checkpoint_folder)
    task.set_checkpoint(checkpoint)

    pretrained_checkpoint = load_checkpoint(classy_args.pretrained_checkpoint_folder)
    if pretrained_checkpoint is not None:
        assert isinstance(
            task, FineTuningTask
        ), "Can only use a pretrained checkpoint for fine tuning tasks"
        task.set_pretrained_checkpoint(pretrained_checkpoint)

    hooks = [
        LossLrMeterLoggingHook(classy_args.log_freq),
        ModelComplexityHook(),
        TimeMetricsHook(),
    ]

    if classy_args.checkpoint_folder != "":
        args_dict = vars(classy_args)
        args_dict["config"] = config
        hooks.append(
            CheckpointHook(
                classy_args.checkpoint_folder,
                args_dict,
                checkpoint_period=classy_args.checkpoint_period,
            )
        )
    if classy_args.profiler:
        hooks.append(ProfilerHook())

    task.set_hooks(hooks)

    assert c10d_backend == Backend.NCCL or c10d_backend == Backend.GLOO
    if c10d_backend == torch.distributed.Backend.NCCL:
        # needed to enable NCCL error handling
        os.environ["NCCL_BLOCKING_WAIT"] = "1"

    coordinator = CoordinatorP2P(
        c10d_backend=c10d_backend,
        init_method=rdzv_init_url,
        max_num_trainers=max_world_size,
        process_group_timeout=60000,
    )
    trainer = ElasticTrainer(
        use_gpu=classy_args.device == "gpu",
        num_dataloader_workers=classy_args.num_workers,
        local_rank=local_rank,
        elastic_coordinator=coordinator,
        input_args={},
    )
    trainer.train(task)
def test_failure(self) -> None:
    self.assertFalse(PathManager.exists("test://foo"))
    PathManager.register_handler(TestPathHandler())

    # make sure that TestPathHandler is being used
    self.assertTrue(PathManager.exists("test://foo"))

    checkpoint_folder = "test://root"
    checkpoint_hook = CheckpointHook(checkpoint_folder, {}, phase_types=["train"])

    config = get_test_task_config()
    task = build_task(config)
    task.prepare()

    # we should raise an exception while trying to save the checkpoint
    with self.assertRaises(TestException):
        checkpoint_hook.on_phase_end(task)
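# A sketch of the fixtures test_failure relies on, assuming fvcore's
# PathHandler interface: a handler for the hypothetical "test://" scheme
# whose _open always raises, so writing a checkpoint through it must fail.
from fvcore.common.file_io import PathHandler


class TestException(Exception):
    pass


class TestPathHandler(PathHandler):
    PREFIX = "test://"

    def _get_supported_prefixes(self):
        return [self.PREFIX]

    def _exists(self, *args, **kwargs):
        return True

    def _isdir(self, *args, **kwargs):
        return True

    def _open(self, *args, **kwargs):
        raise TestException()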
def test_from_checkpoint(self):
    config = get_test_task_config()
    for use_head in [True, False]:
        config["model"] = self.get_model_config(use_head)
        task = build_task(config)
        task.prepare()
        checkpoint_folder = f"{self.base_dir}/{use_head}/"
        input_args = {"config": config}

        # Simulate training by setting the model parameters to zero
        for param in task.model.parameters():
            param.data.zero_()

        checkpoint_hook = CheckpointHook(
            checkpoint_folder, input_args, phase_types=["train"]
        )

        # Create checkpoint dir, save checkpoint
        os.mkdir(checkpoint_folder)
        checkpoint_hook.on_start(task)

        task.train = True
        checkpoint_hook.on_phase_end(task)

        # Model should be checkpointed. load and compare
        checkpoint = load_checkpoint(checkpoint_folder)

        model = ClassyModel.from_checkpoint(checkpoint)
        self.assertTrue(isinstance(model, MyTestModel))

        # All parameters must be zero
        for param in model.parameters():
            self.assertTrue(torch.all(param.data == 0))
def test_checkpoint_period(self) -> None:
    """
    Test that the checkpoint_period works as expected.
    """
    config = get_test_task_config()
    task = build_task(config)
    task.prepare()

    checkpoint_folder = self.base_dir + "/checkpoint_end_test/"
    checkpoint_period = 10

    for phase_types in [["train"], ["train", "test"]]:
        # create a checkpoint hook
        checkpoint_hook = CheckpointHook(
            checkpoint_folder,
            {},
            phase_types=phase_types,
            checkpoint_period=checkpoint_period,
        )

        # create checkpoint dir
        os.mkdir(checkpoint_folder)

        # call the on start function
        checkpoint_hook.on_start(task)

        # shouldn't create any checkpoints until there are checkpoint_period
        # phases which are in phase_types
        count = 0
        valid_phase_count = 0
        while valid_phase_count < checkpoint_period - 1:
            task.train = count % 2 == 0

            # call the on end phase function
            checkpoint_hook.on_phase_end(task)
            checkpoint = load_checkpoint(checkpoint_folder)
            self.assertIsNone(checkpoint)
            valid_phase_count += 1 if task.phase_type in phase_types else 0
            count += 1

        # create a phase which is in phase_types
        task.train = True

        # call the on end phase function
        checkpoint_hook.on_phase_end(task)

        # model should be checkpointed. load and compare
        checkpoint = load_checkpoint(checkpoint_folder)
        self.assertIsNotNone(checkpoint)

        # delete the checkpoint dir
        shutil.rmtree(checkpoint_folder)
def test_final_train_checkpoint(self):
    """Test that a train phase checkpoint with a where of 1.0 can be loaded"""

    config = get_fast_test_task_config()
    task = build_task(config).set_hooks(
        [CheckpointHook(self.base_dir, {}, phase_types=["train"])]
    )
    task_2 = build_task(config)

    task.set_use_gpu(torch.cuda.is_available())

    trainer = LocalTrainer()
    trainer.train(task)
    self.assertAlmostEqual(task.where, 1.0, delta=1e-3)

    # set task_2's state as task's final train checkpoint
    task_2.set_checkpoint(self.base_dir)
    task_2.prepare()

    # we should be able to train the task
    trainer.train(task_2)
def configure_hooks(args, config):
    hooks = [LossLrMeterLoggingHook(args.log_freq), ModelComplexityHook()]

    # Make a folder to store checkpoints and tensorboard logging outputs
    suffix = datetime.now().isoformat()
    base_folder = f"{Path(__file__).parent}/output_{suffix}"
    if args.checkpoint_folder == "":
        args.checkpoint_folder = base_folder + "/checkpoints"
        os.makedirs(args.checkpoint_folder, exist_ok=True)

    logging.info(f"Logging outputs to {base_folder}")
    logging.info(f"Logging checkpoints to {args.checkpoint_folder}")

    if not args.skip_tensorboard:
        try:
            from torch.utils.tensorboard import SummaryWriter

            os.makedirs(Path(base_folder) / "tensorboard", exist_ok=True)
            tb_writer = SummaryWriter(log_dir=Path(base_folder) / "tensorboard")
            hooks.append(TensorboardPlotHook(tb_writer))
        except ImportError:
            logging.warning("tensorboard not installed, skipping tensorboard hooks")

    args_dict = vars(args)
    args_dict["config"] = config
    hooks.append(
        CheckpointHook(
            args.checkpoint_folder, args_dict, checkpoint_period=args.checkpoint_period
        )
    )

    if args.profiler:
        hooks.append(ProfilerHook())
    if args.show_progress:
        hooks.append(ProgressBarHook())
    if args.visdom_server != "":
        hooks.append(VisdomHook(args.visdom_server, args.visdom_port))

    return hooks
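# A hedged usage sketch: wiring configure_hooks() into a task. Assumes an
# argparse Namespace carrying the fields referenced above (log_freq,
# checkpoint_folder, skip_tensorboard, checkpoint_period, profiler,
# show_progress, visdom_server, visdom_port) and a task config dict.
task = build_task(config).set_hooks(configure_hooks(args, config))
trainer = LocalTrainer()
trainer.train(task)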
def test_state_checkpointing(self) -> None:
    """
    Test that the state gets checkpointed without any errors, but only
    on the right phase_type and only if the checkpoint directory exists.
    """
    config = get_test_task_config()
    task = build_task(config)
    task.prepare()

    checkpoint_folder = self.base_dir + "/checkpoint_end_test/"
    input_args = {"foo": "bar"}

    # create a checkpoint hook
    checkpoint_hook = CheckpointHook(
        checkpoint_folder, input_args, phase_types=["train"]
    )

    # checkpoint directory doesn't exist
    # call the on start function
    with self.assertRaises(FileNotFoundError):
        checkpoint_hook.on_start(task)

    # call the on end phase function
    with self.assertRaises(AssertionError):
        checkpoint_hook.on_phase_end(task)

    # try loading a non-existent checkpoint
    checkpoint = load_checkpoint(checkpoint_folder)
    self.assertIsNone(checkpoint)

    # create checkpoint dir, verify on_start hook runs
    os.mkdir(checkpoint_folder)
    checkpoint_hook.on_start(task)

    # Phase_type is test, expect no checkpoint
    task.train = False

    # call the on end phase function
    checkpoint_hook.on_phase_end(task)
    checkpoint = load_checkpoint(checkpoint_folder)
    self.assertIsNone(checkpoint)

    task.train = True

    # call the on end phase function
    checkpoint_hook.on_phase_end(task)

    # model should be checkpointed. load and compare
    checkpoint = load_checkpoint(checkpoint_folder)
    self.assertIsNotNone(checkpoint)

    for key in ["input_args", "classy_state_dict"]:
        self.assertIn(key, checkpoint)

    # not testing for equality of classy_state_dict, that is tested in
    # a separate test
    self.assertDictEqual(checkpoint["input_args"], input_args)