def train(
    self,
    training_data: BatchIterator,
    eval_data: BatchIterator,
    model: Model,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model; the model states will be modified.

    Args:
        training_data (BatchIterator): batch iterator of training data
        eval_data (BatchIterator): batch iterator of evaluation data
        model (Model): model to be trained
        metric_reporter (MetricReporter): computes metrics based on training
            output and reports results to console, file, etc.
        train_config (PyTextConfig): training config
        rank (int): only used in distributed training, the rank of the current
            training thread; evaluation is only done in rank 0

    Returns:
        model, best_metric: the trained model together with the best metric
    """
    state = TrainingState(
        model=model,
        optimizer=self.optimizer,
        scheduler=self.scheduler,
        sparsifier=self.sparsifier,
        rank=rank,
    )
    return self.train_from_state(
        state, training_data, eval_data, metric_reporter, train_config
    )
def test(self, test_iter, model, metric_reporter: MetricReporter):
    state = TrainingState(stage=Stage.TEST, model=model, epoch=1)
    if cuda.CUDA_ENABLED:
        state.model.cuda()
    state.model.eval()
    with torch.no_grad():
        return self.run_epoch(state, test_iter, metric_reporter)
def update_best_model(
    self, state: TrainingState, train_config: PyTextConfig, eval_metric
):
    # This should be updated by all workers so they agree on when to stop
    # training when `early_stop_after` is specified.
    state.epochs_since_last_improvement = 0
    state.best_model_metric = eval_metric
    print("Found a better model!")

    # Only one worker should save checkpoints.
    if state.rank != 0:
        return

    model_state = state.model.state_dict()
    # Save to CPU to avoid keeping multiple model copies in GPU memory.
    if cuda.CUDA_ENABLED:
        self.move_state_dict_to_cpu(model_state)
    state.best_model_state = model_state
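# The `move_state_dict_to_cpu` helper called above is defined elsewhere in the
# trainer; the stand-alone sketch below is an assumption about what such an
# in-place CPU copy could look like, shown only to illustrate why the best-model
# snapshot avoids holding a second copy of the parameters in GPU memory.
import torch


def move_state_dict_to_cpu_sketch(state_dict):
    for key, value in state_dict.items():
        if torch.is_tensor(value):
            # Replace each GPU tensor with a host-memory copy.
            state_dict[key] = value.cpu()
        elif isinstance(value, dict):
            # Recurse into nested state dicts (e.g. optimizer state).
            move_state_dict_to_cpu_sketch(value)
    return state_dict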
def test_load_checkpoint(self):
    with tempfile.NamedTemporaryFile() as checkpoint_file:
        train_data = tests_module.test_file("train_data_tiny.tsv")
        eval_data = tests_module.test_file("test_data_tiny.tsv")
        config = PyTextConfig(
            task=DocumentClassificationTask.Config(
                data=Data.Config(
                    source=TSVDataSource.Config(
                        train_filename=train_data,
                        eval_filename=eval_data,
                        field_names=["label", "slots", "text"],
                    )
                )
            ),
            version=LATEST_VERSION,
            save_snapshot_path=checkpoint_file.name,
        )
        task = create_task(config.task)
        model = task.model

        # Test checkpoint saving and loading.
        optimizer = create_optimizer(Adam.Config(), model)
        scheduler = create_scheduler(Scheduler.Config(), optimizer)
        training_state = TrainingState(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            start_time=0,
            epoch=0,
            rank=0,
            stage=Stage.TRAIN,
            epochs_since_last_improvement=0,
            best_model_state=None,
            best_model_metric=None,
            tensorizers=None,
        )
        checkpoint_path = checkpoint_file.name
        save(
            config,
            model,
            None,
            task.data.tensorizers,
            training_state,
            checkpoint_file,
        )
        task_restored, config_restored, training_state_restored = load(
            checkpoint_path
        )
        optimizer_restored = training_state_restored.optimizer
        scheduler_restored = training_state_restored.scheduler
        self.assertOptimizerEqual(optimizer, optimizer_restored)
        self.assertIsNotNone(scheduler_restored)
        self.assertEqual(config, config_restored)
        self.assertModulesEqual(model, task_restored.model)
        model.eval()
        task_restored.model.eval()

        inputs = torch.LongTensor([[1, 2, 3]]), torch.LongTensor([3])
        self.assertEqual(
            model(*inputs).tolist(), task_restored.model(*inputs).tolist()
        )
def set_up_training(self, state: TrainingState, training_data: BatchIterator):
    if cuda.CUDA_ENABLED:
        state.model.cuda()
    state.scheduler.prepare(training_data, self.config.epochs)

    if cuda.DISTRIBUTED_WORLD_SIZE > 1:
        device_id = torch.cuda.current_device()
        state.model = DistributedModel(
            module=state.model,
            device_ids=[device_id],
            output_device=device_id,
            broadcast_buffers=False,
            find_unused_parameters=state.model.find_unused_parameters,
        )
    state.start_time = time.time()

    if self.config.num_batches_per_epoch:
        # Set the training_data iterator to cycle, so it will never run out,
        # but rather after reaching the end will loop back to the beginning.
        training_data = cycle(training_data)
    return training_data
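# Minimal stand-alone sketch (not part of the trainer) of the pattern used above
# and in `train_from_state`: `itertools.cycle` makes the batch iterator restart
# whenever it is exhausted, and `itertools.islice` then caps each epoch at
# `num_batches_per_epoch` batches.
from itertools import cycle, islice

batches = ["batch_0", "batch_1", "batch_2"]  # stands in for a BatchIterator
stream = cycle(batches)  # never runs out; loops back to the beginning
num_batches_per_epoch = 4
epoch_1 = list(islice(stream, num_batches_per_epoch))  # batch_0, batch_1, batch_2, batch_0
epoch_2 = list(islice(stream, num_batches_per_epoch))  # batch_1, batch_2, batch_0, batch_1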
def test_load_checkpoint_in_dist_training(self):
    with tempfile.NamedTemporaryFile() as checkpoint_file:
        train_data = tests_module.test_file("train_data_tiny.tsv")
        eval_data = tests_module.test_file("test_data_tiny.tsv")
        config = PyTextConfig(
            task=DocumentClassificationTask.Config(
                data=Data.Config(
                    source=BlockShardedTSVDataSource.Config(
                        train_filename=train_data,
                        eval_filename=eval_data,
                        field_names=["label", "slots", "text"],
                    )
                )
            ),
            version=LATEST_VERSION,
            save_snapshot_path=checkpoint_file.name,
        )
        task = create_task(config.task)
        model = task.model

        # Test checkpoint saving and loading.
        optimizer = create_optimizer(Adam.Config(), model)
        scheduler = create_scheduler(Scheduler.Config(), optimizer)
        training_state = TrainingState(
            model=model,
            optimizer=optimizer,
            scheduler=scheduler,
            start_time=0,
            epoch=0,
            rank=0,
            stage=Stage.TRAIN,
            epochs_since_last_improvement=0,
            best_model_state=None,
            best_model_metric=None,
            tensorizers=task.data.tensorizers,
        )
        checkpoint_id = "epoch-1"
        saved_path = save(
            config, model, None, task.data.tensorizers, training_state, checkpoint_id
        )

        new_rank = 2
        new_world_size = 4
        task_restored, config_restored, training_state_restored = load(
            saved_path, rank=new_rank, world_size=new_world_size
        )
        self.assertCheckpointEqual(
            model,
            config,
            training_state,
            task_restored.model,
            config_restored,
            training_state_restored,
        )
        self.assertEqual(task_restored.data.data_source.rank, new_rank)
        self.assertEqual(task_restored.data.data_source.world_size, new_world_size)
def train(
    self,
    training_data: DataLoader,
    eval_data: DataLoader,
    model: Model,
    optimizer: Optimizer,
    label_names: List[str],
    scheduler: Scheduler = None,
    sparsifier: Sparsifier = None,
    metric_reporter: MetricReporter = None,
    train_config: PyTextConfig = None,
    rank: int = 0,
) -> Tuple[torch.nn.Module, Any]:
    # Temp workaround to minimize changes to TaskTrainer.
    if not train_config:
        train_config = PyTextConfig(
            task=NewTask.Config(model=RoBERTa.Config), version=20
        )
    if scheduler:
        self.scheduler = scheduler
    if sparsifier:
        self.sparsifier = sparsifier

    state = TrainingState(
        model=model,
        optimizer=optimizer,
        scheduler=self.scheduler,
        sparsifier=self.sparsifier,
        rank=rank,
    )

    metric_reporter_config = ClassificationMetricReporter.Config(
        output_path="/tmp/test_out.txt",
        pep_format=False,
        model_select_metric=ComparableClassificationMetric.ACCURACY,  # in json: "accuracy"
        target_label=None,
        text_column_names=["text"],
        additional_column_names=[],
        recall_at_precision_thresholds=[0.2, 0.4, 0.6, 0.8, 0.9],
    )
    metric_reporter = ClassificationMetricReporter.from_config_and_label_names(
        config=metric_reporter_config, label_names=label_names
    )
    return self.train_from_state(
        state, training_data, eval_data, metric_reporter, train_config
    )
def train_from_state(
    self,
    state: TrainingState,
    training_data: BatchIterator,
    eval_data: BatchIterator,
    metric_reporter: MetricReporter,
    train_config: PyTextConfig,
) -> Tuple[torch.nn.Module, Any]:
    """
    Train and eval a model from a given training state; the state will be
    modified in place. This function iterates over the number of epochs
    specified in the config, and for each epoch:

    1. Train the model using training data, aggregate and report training results
    2. Adjust the learning rate if a scheduler is specified
    3. Evaluate the model using evaluation data
    4. Calculate metrics based on evaluation results and select the best model

    Args:
        state (TrainingState): contains stateful information needed to restore
            a training job
        training_data (BatchIterator): batch iterator of training data
        eval_data (BatchIterator): batch iterator of evaluation data
        metric_reporter (MetricReporter): computes metrics based on training
            output and reports results to console, file, etc.
        train_config (PyTextConfig): training config

    Returns:
        model, best_metric: the trained model together with the best metric
    """
    training_data = self.set_up_training(state, training_data)
    model = state.model
    rank = state.rank

    trainable_params = sum(
        p.numel() for p in state.model.parameters() if p.requires_grad
    )
    print(f"Model: {model}")
    print(f"Num trainable parameters: {trainable_params}")

    self.sparsifier.initialize(self, state, eval_data, metric_reporter, train_config)

    while self.continue_training(state):
        self.sparsifier.op_pre_epoch(self, state)
        state.epoch += 1
        state.epochs_since_last_improvement += 1
        lrs = learning_rates(state.optimizer)
        print(f"\nWorker {state.rank} starting epoch {state.epoch}")
        print(f"Learning rate(s): {', '.join(map(str, lrs))}")

        with timing.time("train epoch"):
            state.stage = Stage.TRAIN
            state.model.train()
            print(f"start training epoch {state.epoch}")
            epoch_data = training_data
            if self.config.num_batches_per_epoch:
                # We want to limit the number of batches in the epoch;
                # equivalent to epoch_data[:num_batches_per_epoch] for iterators.
                # In this case we set the training data iterator to cycle earlier
                # in the training process, so when it reaches the end it will
                # loop back to the beginning.
                epoch_data = itertools.islice(
                    epoch_data, self.config.num_batches_per_epoch
                )
            self.run_epoch(state, epoch_data, metric_reporter)

        if not self.config.do_eval:
            continue

        with timing.time("eval epoch"):
            state.stage = Stage.EVAL
            model.eval(Stage.EVAL)
            print(f"start evaluating epoch {state.epoch}")
            with torch.no_grad():
                eval_metric = self.run_epoch(state, eval_data, metric_reporter)

        # Step the learning rate scheduler(s).
        assert eval_metric is not None
        state.scheduler.step_epoch(
            metrics=metric_reporter.get_model_select_metric(eval_metric),
            epoch=state.epoch,
        )

        # Did we train a better model?
        better_model = metric_reporter.compare_metric(
            eval_metric, state.best_model_metric
        )
        if better_model:
            self.update_best_model(state, train_config, eval_metric)
        if better_model or train_config.save_all_checkpoints:
            self.save_checkpoint(state, train_config)

    if self.optimizer.finalize():
        should_update_model = True
        eval_metric = None
        if self.config.do_eval:
            state.stage = Stage.EVAL
            model.eval(Stage.EVAL)
            print("start evaluating finalized state")
            with torch.no_grad():
                eval_metric = self.run_epoch(state, eval_data, metric_reporter)
            should_update_model = metric_reporter.compare_metric(
                eval_metric, state.best_model_metric
            )
        if should_update_model:
            self.update_best_model(state, train_config, eval_metric)
        if should_update_model or train_config.save_all_checkpoints:
            self.save_checkpoint(state, train_config)

    # Only bother loading the best model for the master worker.
    if (
        rank == 0
        and state.best_model_state is not None
        and self.config.load_best_model_after_train
    ):
        self.load_best_model(state)

    return state.model, state.best_model_metric
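# The `learning_rates` helper used in `train_from_state` for the per-epoch log line
# is imported from a PyText optimizer utility; the sketch below is an assumption
# about its behavior, shown only to illustrate where the logged values come from.
def learning_rates_sketch(optimizer):
    # Each optimizer param group carries its own "lr" entry; yield them all so the
    # caller can log one value per group.
    for param_group in optimizer.param_groups:
        yield param_group["lr"]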
def test(self, test_iter, model, metric_reporter: MetricReporter):
    state = TrainingState(stage=Stage.TEST, model=model, epoch=1)
    if cuda.CUDA_ENABLED:
        state.model.cuda()
    state.model.eval()
    return self.eval_from_state(state, test_iter, metric_reporter)