def build_meters(
    run_type: str,
) -> Tuple[Optional[Meter], Optional[Meter], Optional[Meter]]:
    train_meter, val_meter, test_meter = None, None, None
    if "train" in run_type:
        train_meter = Meter()
        # val_meter used for validation after training loop
        val_meter = Meter()
    elif "val" in run_type or "inference" in run_type:
        val_meter = Meter()

    if "test" in run_type:
        test_meter = Meter()

    return train_meter, val_meter, test_meter
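# --- Usage sketch (illustrative; not part of the original source) ---
# `build_meters` keys off substrings of `run_type`, so meters that do not
# apply to the run come back as None. The run_type strings below assume the
# usual "train" / "train_val" / "val" / "test" naming convention.
train_meter, val_meter, test_meter = build_meters("train_val")
assert train_meter is not None and val_meter is not None
assert test_meter is None  # "test" is not a substring of "train_val"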
class TrainerReportingMixin(ABC):
    meter: Meter = Meter()

    def update_meter(
        self,
        report: Dict[str, Any],
        meter: Optional[Meter] = None,
        eval_mode: bool = False,
    ) -> None:
        if meter is None:
            meter = self.meter

        if hasattr(report, "metrics"):
            metrics_dict = report.metrics
            reduced_metrics_dict = reduce_dict(metrics_dict)

        if not eval_mode:
            loss_dict = report.losses
            reduced_loss_dict = reduce_dict(loss_dict)

        with torch.no_grad():
            # In eval mode only the metrics are added to the meter;
            # losses are tracked only during training
            meter_update_dict = {}
            if not eval_mode:
                loss_key = report.dataset_type + "/total_loss"
                reduced_loss = sum(loss.mean() for loss in reduced_loss_dict.values())
                if hasattr(reduced_loss, "item"):
                    reduced_loss = reduced_loss.item()

                registry.register(loss_key, reduced_loss)
                meter_update_dict.update({loss_key: reduced_loss})
                meter_update_dict.update(reduced_loss_dict)

            if hasattr(report, "metrics"):
                meter_update_dict.update(reduced_metrics_dict)
            meter.update(meter_update_dict, report.batch_size)
def evaluate(self, loader, use_tqdm=False, single_batch=False):
    meter = Meter()

    with torch.no_grad():
        self.model.eval()
        disable_tqdm = not use_tqdm or not is_master()
        combined_report = None

        for batch in tqdm(loader, disable=disable_tqdm):
            report = self._forward_pass(batch)
            self._update_meter(report, meter)

            # accumulate necessary params for metric calculation
            if combined_report is None:
                combined_report = report
            else:
                combined_report.accumulate_tensor_fields(
                    report, self.metrics.required_params
                )
                combined_report.batch_size += report.batch_size

            if single_batch is True:
                break

        combined_report.metrics = self.metrics(combined_report, combined_report)
        self._update_meter(combined_report, meter, eval_mode=True)

        self.model.train()

    return combined_report, meter
def evaluation_loop(
    self, loader, use_tqdm: bool = False, single_batch: bool = False
) -> Tuple[Dict[str, Any], Meter]:
    meter = Meter()

    with torch.no_grad():
        self.model.eval()
        disable_tqdm = not use_tqdm or not is_master()
        combined_report = None

        for batch in tqdm.tqdm(loader, disable=disable_tqdm):
            report = self._forward(batch)
            self.update_meter(report, meter)

            # accumulate necessary params for metric calculation
            if combined_report is None:
                combined_report = report
            else:
                combined_report.accumulate_tensor_fields(
                    report, self.metrics.required_params
                )
                combined_report.batch_size += report.batch_size

            if single_batch is True:
                break

        combined_report.metrics = self.metrics(combined_report, combined_report)
        self.update_meter(combined_report, meter, eval_mode=True)

        # enable train mode again
        self.model.train()

    return combined_report, meter
def evaluation_loop(
    self, dataset_type: str, use_tqdm: bool = False, single_batch: bool = False
) -> Tuple[Dict[str, Any], Meter]:
    meter = Meter()
    reporter = self.dataset_loader.get_test_reporter(dataset_type)

    with torch.no_grad():
        self.model.eval()
        disable_tqdm = not use_tqdm or not is_master()
        while reporter.next_dataset(flush_report=False):
            dataloader = reporter.get_dataloader()
            combined_report = None

            for batch in tqdm.tqdm(dataloader, disable=disable_tqdm):
                prepared_batch = reporter.prepare_batch(batch)
                prepared_batch = to_device(prepared_batch, self.device)
                model_output = self.model(prepared_batch)
                report = Report(prepared_batch, model_output)

                self.update_meter(report, meter)

                # accumulate necessary params for metric calculation
                if combined_report is None:
                    # make a copy of report since `reporter.add_to_report` will
                    # change some of the report keys later
                    combined_report = Report(report)
                else:
                    combined_report.accumulate_tensor_fields_and_loss(
                        report, self.metrics.required_params
                    )
                    combined_report.batch_size += report.batch_size

                # Each node generates a separate copy of predict JSON from the
                # report, which will be used to evaluate dataset-level metrics
                # (such as mAP in object detection or CIDEr in image captioning)
                # Since `reporter.add_to_report` changes report keys (e.g. scores),
                # do this after `combined_report.accumulate_tensor_fields_and_loss`
                if "__prediction_report__" in self.metrics.required_params:
                    reporter.add_to_report(
                        report, self.model, execute_on_master_only=False
                    )

                if single_batch is True:
                    break

            reporter.postprocess_dataset_report()
            # prediction_report is used for dataset-level metrics
            combined_report.prediction_report = reporter.report

            combined_report.metrics = self.metrics(combined_report, combined_report)
            self.update_meter(combined_report, meter, eval_mode=True)

        # enable train mode again
        self.model.train()

    return combined_report, meter
def load_extras(self):
    self.writer.write("Torch version is: " + torch.__version__)
    self.checkpoint = Checkpoint(self)
    self.meter = Meter()

    self.training_config = self.config.training

    early_stop_criteria = self.training_config.early_stop.criteria
    early_stop_minimize = self.training_config.early_stop.minimize
    early_stop_enabled = self.training_config.early_stop.enabled
    early_stop_patience = self.training_config.early_stop.patience

    self.log_interval = self.training_config.log_interval
    self.evaluation_interval = self.training_config.evaluation_interval
    self.checkpoint_interval = self.training_config.checkpoint_interval
    self.max_updates = self.training_config.max_updates
    self.should_clip_gradients = self.training_config.clip_gradients
    self.max_epochs = self.training_config.max_epochs

    self.early_stopping = EarlyStopping(
        self.model,
        self.checkpoint,
        early_stop_criteria,
        patience=early_stop_patience,
        minimize=early_stop_minimize,
        should_stop=early_stop_enabled,
    )
    self.current_epoch = 0
    self.current_iteration = 0
    self.num_updates = 0

    self.checkpoint.load_state_dict()

    self.not_debug = self.training_config.logger_level != "debug"

    self.lr_scheduler = None
    if self.training_config.lr_scheduler is True:
        self.lr_scheduler = build_scheduler(self.optimizer, self.config)

    self.tb_writer = None
    if self.training_config.tensorboard:
        log_dir = self.writer.log_dir
        env_tb_logdir = get_mmf_env(key="tensorboard_logdir")
        if env_tb_logdir:
            log_dir = env_tb_logdir

        self.tb_writer = TensorboardLogger(log_dir, self.current_iteration)
def test_meter_update_from_report(self):
    meter = Meter()
    prepared_batch = SampleList(
        {"targets": torch.tensor([1, 2, 3, 4]), "dataset_type": "val"}
    )
    for idx in range(5):
        model_output = {
            "scores": torch.tensor([0, 1, 2, 3]),
            "losses": {"loss": float(idx)},
        }
        report = Report(prepared_batch, model_output)
        meter.update_from_report(report)

    self.assertEqual(meter.loss.global_avg, 2.0)
    self.assertEqual(meter.loss.avg, 2.0)
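# Note on the assertions above (added for clarity): with losses 0.0..4.0 over
# five updates, both the all-time average (`global_avg`) and the smoothed
# window average (`avg`) equal (0 + 1 + 2 + 3 + 4) / 5 = 2.0, assuming the
# meter's smoothing window spans at least five updates.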
class TrainerReportingMixin(ABC):
    meter: Meter = Meter()

    def update_meter(
        self,
        report: Dict[str, Any],
        meter: Optional[Meter] = None,
        eval_mode: bool = False,
    ) -> None:
        if meter is None:
            meter = self.meter

        if hasattr(report, "metrics"):
            metrics_dict = report.metrics
            reduced_metrics_dict = reduce_dict(metrics_dict)

        if not eval_mode:
            loss_dict = report.losses
            reduced_loss_dict = reduce_dict(loss_dict)

        with torch.no_grad():
            # In eval mode only the metrics are added to the meter;
            # losses are tracked only during training
            meter_update_dict = {}
            if not eval_mode:
                total_loss_key = report.dataset_type + "/total_loss"
                meter_update_dict, total_loss = self.update_dict(
                    meter_update_dict, reduced_loss_dict
                )
                registry.register(total_loss_key, total_loss)
                meter_update_dict.update({total_loss_key: total_loss})

            if hasattr(report, "metrics"):
                meter_update_dict, _ = self.update_dict(
                    meter_update_dict, reduced_metrics_dict
                )
            meter.update(meter_update_dict, report.batch_size)

    def update_dict(self, meter_update_dict, values_dict):
        total_val = 0
        for key, val in values_dict.items():
            if torch.is_tensor(val):
                if val.dim() == 1:
                    val = val.mean()

            if hasattr(val, "item"):
                val = val.item()

            meter_update_dict.update({key: val})
            total_val += val

        return meter_update_dict, total_val
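# --- Illustrative sketch of `update_dict` (keys and values made up) ---
# 1-d tensors are reduced to their mean, anything exposing `.item()` is
# converted to a Python scalar, and the running total of all values is
# returned alongside the updated dict.
import torch

mixin = TrainerReportingMixin()  # instantiable as long as no abstract methods exist
out, total = mixin.update_dict(
    {}, {"loss_a": torch.tensor([1.0, 3.0]), "loss_b": torch.tensor(0.5)}
)
assert out == {"loss_a": 2.0, "loss_b": 0.5} and total == 2.5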
def setUp(self):
    self.tmpdir = tempfile.mkdtemp()
    self.trainer = argparse.Namespace()
    self.config = load_yaml(os.path.join("configs", "defaults.yaml"))
    self.config = OmegaConf.merge(
        self.config,
        {
            "model": "simple",
            "model_config": {},
            "training": {
                "checkpoint_interval": 1,
                "evaluation_interval": 10,
                "early_stop": {"criteria": "val/total_loss"},
                "batch_size": 16,
                "log_interval": 10,
                "logger_level": "info",
            },
            "env": {"save_dir": self.tmpdir},
        },
    )
    # Keep original copy for testing purposes
    self.trainer.config = deepcopy(self.config)
    registry.register("config", self.trainer.config)
    setup_logger()

    self.report = Mock(spec=Report)
    self.report.dataset_name = "abcd"
    self.report.dataset_type = "test"

    self.trainer.model = SimpleModule()
    self.trainer.val_loader = torch.utils.data.DataLoader(
        NumbersDataset(), batch_size=self.config.training.batch_size
    )

    self.trainer.optimizer = torch.optim.Adam(
        self.trainer.model.parameters(), lr=1e-01
    )
    self.trainer.device = "cpu"
    self.trainer.num_updates = 0
    self.trainer.current_iteration = 0
    self.trainer.current_epoch = 0
    self.trainer.max_updates = 0
    self.trainer.meter = Meter()
    self.cb = LogisticsCallback(self.config, self.trainer)
def evaluation_loop(
    self, dataset_type: str, use_tqdm: bool = False, single_batch: bool = False
) -> Tuple[Dict[str, Any], Meter]:
    meter = Meter()
    reporter = self.dataset_loader.get_test_reporter(dataset_type)
    use_cpu = self.config.evaluation.get("use_cpu", False)
    loaded_batches = 0
    skipped_batches = 0

    with torch.no_grad():
        self.model.eval()
        disable_tqdm = not use_tqdm or not is_master()
        while reporter.next_dataset(flush_report=False):
            dataloader = reporter.get_dataloader()
            combined_report = None

            if self._can_use_tqdm(dataloader):
                dataloader = tqdm.tqdm(dataloader, disable=disable_tqdm)
            for batch in dataloader:
                # Do not timeout quickly on first batch, as workers might start at
                # very different times.
                with CompleteInTimeOrDie(600 if loaded_batches else 3600 * 24):
                    loaded_batches += 1
                    prepared_batch = reporter.prepare_batch(batch)
                    prepared_batch = to_device(prepared_batch, self.device)
                    if not validate_batch_sizes(prepared_batch.get_batch_size()):
                        logger.info("Skip batch due to uneven batch sizes.")
                        skipped_batches += 1
                        continue
                    model_output = self.model(prepared_batch)
                    report = Report(prepared_batch, model_output)
                    report = report.detach()

                    meter.update_from_report(report)

                    moved_report = report
                    # Move to CPU for metrics calculation later if needed
                    # Explicitly use `non_blocking=False` as this can cause
                    # race conditions in next accumulate
                    if use_cpu:
                        moved_report = report.copy().to("cpu", non_blocking=False)

                    # accumulate necessary params for metric calculation
                    if combined_report is None:
                        # make a copy of report since `reporter.add_to_report` will
                        # change some of the report keys later
                        combined_report = moved_report.copy()
                    else:
                        combined_report.accumulate_tensor_fields_and_loss(
                            moved_report, self.metrics.required_params
                        )
                        combined_report.batch_size += moved_report.batch_size

                    # Each node generates a separate copy of predict JSON from the
                    # report, which will be used to evaluate dataset-level metrics
                    # (such as mAP in object detection or CIDEr in image captioning)
                    # Since `reporter.add_to_report` changes report keys
                    # (e.g. scores), do this after
                    # `combined_report.accumulate_tensor_fields_and_loss`
                    if "__prediction_report__" in self.metrics.required_params:
                        # Still need to use original report here on GPU/TPU since
                        # it will be gathered
                        reporter.add_to_report(
                            report, self.model, execute_on_master_only=False
                        )

                    if single_batch is True:
                        break

            logger.info(f"Finished evaluation. Loaded {loaded_batches} batches,")
            logger.info(f" -- skipped {skipped_batches} batches.")

            reporter.postprocess_dataset_report()
            assert (
                combined_report is not None
            ), "Please check if your validation set is empty!"
            # prediction_report is used for set-level metrics
            combined_report.prediction_report = reporter.report

            combined_report.metrics = self.metrics(combined_report, combined_report)

            # Since update_meter will reduce the metrics over GPUs, we need to
            # move them back to GPU but we will only move metrics and losses
            # which are needed by update_meter to avoid OOM
            # Furthermore, do it in a non_blocking way to avoid any issues
            # in device to host or host to device transfer
            if use_cpu:
                combined_report = combined_report.to(
                    self.device, fields=["metrics", "losses"], non_blocking=False
                )

            meter.update_from_report(combined_report, should_update_loss=False)

        # enable train mode again
        self.model.train()

    return combined_report, meter
class TrainerTrainingLoopMixin(ABC):
    current_epoch: int = 0
    current_iteration: int = 0
    num_updates: int = 0
    meter: Meter = Meter()

    def training_loop(self) -> None:
        self.max_updates = self._calculate_max_updates()
        torch.autograd.set_detect_anomaly(self.training_config.detect_anomaly)

        logger.info("Starting training...")
        self.model.train()
        self.run_training_epoch()
        self.after_training_loop()

    def after_training_loop(self) -> None:
        logger.info("Stepping into final validation check")
        # Only do when run_type has train as it shouldn't happen on validation and
        # inference runs. Inference will take care of this anyways. Also, don't run
        # if current iteration is divisible by snapshot interval as it will just
        # be a repeat
        if (
            "train" in self.run_type
            and "val" in self.run_type
            and self.num_updates % self.training_config.evaluation_interval != 0
        ):
            # Create a new meter for this case
            report, meter = self.evaluation_loop("val")

            # Validation end callbacks
            self.on_validation_end(report=report, meter=meter)

    def run_training_epoch(self) -> None:
        should_break = False
        while self.num_updates < self.max_updates and not should_break:
            self.current_epoch += 1
            registry.register("current_epoch", self.current_epoch)

            # Seed the sampler in case if it is distributed
            self.dataset_loader.seed_sampler("train", self.current_epoch)

            # For iterable datasets we cannot determine length of dataset properly.
            # For those cases we set num_remaining_batches to be the (number of
            # updates remaining x update_frequency)
            num_remaining_batches = (
                (self.max_updates - self.num_updates)
                * self.training_config.update_frequency
                if isinstance(
                    self.train_loader.current_dataset, torch.utils.data.IterableDataset
                )
                else len(self.train_loader)
            )

            should_start_update = True
            for idx, batch in enumerate(self.train_loader):
                if should_start_update:
                    combined_report = None
                    self._start_update()
                    num_batches_for_this_update = min(
                        self.training_config.update_frequency, num_remaining_batches
                    )
                    should_start_update = False

                self.current_iteration += 1
                # batch execution starts here
                self.on_batch_start()
                self.profile("Batch load time")

                report = self.run_training_batch(batch, num_batches_for_this_update)
                report = report.detach()

                # accumulate necessary params (including loss) for metric calculation
                if combined_report is None:
                    combined_report = report
                else:
                    combined_report.accumulate_tensor_fields_and_loss(
                        report, self.metrics.required_params
                    )
                    combined_report.batch_size += report.batch_size

                # batch execution ends here
                self.on_batch_end(report=combined_report, meter=self.meter)

                # check if an update has finished or if it is the last, if not continue
                if (
                    (idx + 1) % self.training_config.update_frequency
                    and num_remaining_batches != num_batches_for_this_update
                ):
                    continue

                self._finish_update()
                should_start_update = True

                should_log = False
                if self.num_updates % self.logistics_callback.log_interval == 0:
                    should_log = True
                    # Calculate metrics every log interval for debugging
                    if self.training_config.evaluate_metrics:
                        combined_report.metrics = self.metrics(
                            combined_report, combined_report
                        )
                    self.meter.update_from_report(combined_report)

                self.on_update_end(
                    report=combined_report, meter=self.meter, should_log=should_log
                )

                num_remaining_batches -= num_batches_for_this_update

                # Check if training should be stopped
                should_break = False

                if self.num_updates % self.training_config.evaluation_interval == 0:
                    # Validation begin callbacks
                    self.on_validation_start()

                    logger.info("Evaluation time. Running on full validation set...")
                    # Validation and Early stopping
                    # Create a new meter for this case
                    report, meter = self.evaluation_loop("val")

                    # Validation end callbacks
                    stop = self.early_stop_callback.on_validation_end(
                        report=report, meter=meter
                    )
                    self.on_validation_end(report=report, meter=meter)

                    gc.collect()

                    if "cuda" in str(self.device):
                        torch.cuda.empty_cache()

                    if stop is True:
                        logger.info("Early stopping activated")
                        should_break = True

                if self.num_updates >= self.max_updates:
                    should_break = True

                if should_break:
                    break

    def run_training_batch(self, batch: Dict[str, Tensor], loss_divisor: int) -> Report:
        report = self._forward(batch)
        if self.training_config.exit_on_nan_losses:
            self._check_nan_losses(report)
        loss = extract_loss(report, loss_divisor)
        self._backward(loss)
        return report

    def _check_nan_losses(self, report):
        # skip this check in XLA mode as calling .item() in forward pass
        # greatly slows down the training
        if not is_xla():
            # check whether NaN has occurred in the losses, and exit the training
            # when NaN happens
            loss_dict = report.losses
            nan_loss_keys = []
            for key, value in loss_dict.items():
                if torch.any(torch.isnan(value)).item():
                    nan_loss_keys.append(key)
            if len(nan_loss_keys) > 0:
                keys_str = ", ".join(nan_loss_keys)
                error_msg = (
                    f"NaN occurred in the following loss(es): {keys_str}; "
                    f"exiting the training"
                )
                logger.info(error_msg)
                raise RuntimeError(error_msg)

    def _forward(self, batch: Dict[str, Tensor]) -> Report:
        # Move the sample list to device if it isn't as of now.
        prepared_batch = to_device(batch, self.device)
        self.profile("Batch prepare time")
        # Arguments should be a dict at this point
        with torch.cuda.amp.autocast(enabled=self.training_config.fp16):
            model_output = self.model(prepared_batch)
            report = Report(prepared_batch, model_output)
        self.profile("Forward time")
        return report

    def _start_update(self):
        logger.debug(self.num_updates + 1)
        self.on_update_start()
        self.optimizer.zero_grad()

    def _backward(self, loss: Tensor) -> None:
        self.scaler.scale(loss).backward()
        self.profile("Backward time")

    def _finish_update(self):
        if self.training_config.clip_gradients:
            clip_gradients(
                self.model,
                self.optimizer,
                self.num_updates,
                self.logistics_callback.tb_writer,
                self.config,
                scale=self.scaler.get_scale(),
            )
        if is_xla():
            import torch_xla.core.xla_model as xm

            # Assumes no model parallel
            xm.reduce_gradients(self.optimizer)

        self.scaler.step(self.optimizer)
        self.scaler.update()
        self.num_updates += 1
        self.profile("Finished update")

    def _calculate_max_updates(self):
        config_max_updates = self.training_config.max_updates
        config_max_epochs = self.training_config.max_epochs
        max_updates, _ = get_max_updates(
            config_max_updates,
            config_max_epochs,
            self.train_loader,
            self.training_config.update_frequency,
        )
        return max_updates
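# Gradient-accumulation arithmetic (illustrative note, assuming `extract_loss`
# divides the summed loss by `loss_divisor`, as the `num_batches_for_this_update`
# call site suggests): with update_frequency=4, each of the four backward passes
# contributes loss/4, so the accumulated gradient matches one full-size batch;
# `_finish_update` then performs exactly one optimizer step and increments
# num_updates once per accumulated update.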
def __init__(
    self,
    num_train_data,
    max_updates,
    max_epochs,
    config=None,
    optimizer=None,
    update_frequency=1,
    batch_size=1,
    batch_size_per_device=None,
    fp16=False,
    on_update_end_fn=None,
    scheduler_config=None,
    grad_clipping_config=None,
):
    if config is None:
        self.config = OmegaConf.create(
            {
                "training": {
                    "detect_anomaly": False,
                    "evaluation_interval": 10000,
                    "update_frequency": update_frequency,
                    "fp16": fp16,
                    "batch_size": batch_size,
                    "batch_size_per_device": batch_size_per_device,
                }
            }
        )
        self.training_config = self.config.training
    else:
        self.training_config = config.training
        self.config = config

    # Load batch size with custom config and cleanup
    original_config = registry.get("config")
    registry.register("config", self.config)
    batch_size = get_batch_size()
    registry.register("config", original_config)

    if max_updates is not None:
        self.training_config["max_updates"] = max_updates
    if max_epochs is not None:
        self.training_config["max_epochs"] = max_epochs

    self.model = SimpleModel({"in_dim": 1})
    self.model.build()
    if torch.cuda.is_available():
        self.model = self.model.cuda()
        self.device = "cuda"
    else:
        self.device = "cpu"

    self.distributed = False
    self.dataset_loader = MagicMock()
    self.dataset_loader.seed_sampler = MagicMock(return_value=None)
    self.dataset_loader.prepare_batch = lambda x: SampleList(x)

    if optimizer is None:
        self.optimizer = MagicMock()
        self.optimizer.step = MagicMock(return_value=None)
        self.optimizer.zero_grad = MagicMock(return_value=None)
    else:
        self.optimizer = optimizer

    # Callback registry; the optional scheduler callback below appends to it
    self.callbacks = []
    if scheduler_config:
        config.training.lr_scheduler = True
        config.scheduler = scheduler_config
        self.lr_scheduler_callback = LRSchedulerCallback(config, self)
        self.callbacks.append(self.lr_scheduler_callback)
        on_update_end_fn = (
            on_update_end_fn
            if on_update_end_fn
            else self.lr_scheduler_callback.on_update_end
        )

    if grad_clipping_config:
        self.training_config.clip_gradients = True
        self.training_config.max_grad_l2_norm = grad_clipping_config[
            "max_grad_l2_norm"
        ]
        self.training_config.clip_norm_mode = grad_clipping_config["clip_norm_mode"]

    dataset = NumbersDataset(num_train_data)
    self.train_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
        drop_last=False,
    )
    self.train_loader.current_dataset = dataset

    self.on_batch_start = MagicMock(return_value=None)
    self.on_update_start = MagicMock(return_value=None)
    self.logistics_callback = MagicMock(return_value=None)
    self.logistics_callback.log_interval = MagicMock(return_value=None)
    self.on_batch_end = MagicMock(return_value=None)
    self.on_update_end = (
        on_update_end_fn if on_update_end_fn else MagicMock(return_value=None)
    )
    self.meter = Meter()
    self.after_training_loop = MagicMock(return_value=None)
    self.on_validation_start = MagicMock(return_value=None)
    self.evaluation_loop = MagicMock(return_value=(None, None))
    self.scaler = torch.cuda.amp.GradScaler(enabled=False)
    self.val_loader = MagicMock(return_value=None)
    self.early_stop_callback = MagicMock(return_value=None)
    self.on_validation_end = MagicMock(return_value=None)
    self.metrics = MagicMock(return_value=None)
def __init__(
    self,
    num_train_data,
    max_updates,
    max_epochs,
    config=None,
    optimizer=None,
    update_frequency=1,
    batch_size=1,
    fp16=False,
    on_update_end_fn=None,
):
    if config is None:
        self.training_config = OmegaConf.create(
            {
                "detect_anomaly": False,
                "evaluation_interval": 10000,
                "update_frequency": update_frequency,
                "fp16": fp16,
                "batch_size": batch_size,
            }
        )
    else:
        self.training_config = config.training

    if max_updates is not None:
        self.training_config["max_updates"] = max_updates
    if max_epochs is not None:
        self.training_config["max_epochs"] = max_epochs

    self.model = SimpleModel(1)
    if torch.cuda.is_available():
        self.model = self.model.cuda()
        self.device = "cuda"
    else:
        self.device = "cpu"

    self.dataset_loader = MagicMock()
    self.dataset_loader.seed_sampler = MagicMock(return_value=None)
    self.dataset_loader.prepare_batch = lambda x: SampleList(x)

    if optimizer is None:
        self.optimizer = MagicMock()
        self.optimizer.step = MagicMock(return_value=None)
        self.optimizer.zero_grad = MagicMock(return_value=None)
    else:
        self.optimizer = optimizer

    dataset = NumbersDataset(num_train_data)
    self.train_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=1,
        drop_last=False,
    )
    self.train_loader.current_dataset = dataset

    self.on_batch_start = MagicMock(return_value=None)
    self.on_update_start = MagicMock(return_value=None)
    self.logistics_callback = MagicMock(return_value=None)
    self.logistics_callback.log_interval = MagicMock(return_value=None)
    self.on_batch_end = MagicMock(return_value=None)
    self.on_update_end = (
        on_update_end_fn if on_update_end_fn else MagicMock(return_value=None)
    )
    self.meter = Meter()
    self.after_training_loop = MagicMock(return_value=None)
    self.on_validation_start = MagicMock(return_value=None)
    self.evaluation_loop = MagicMock(return_value=(None, None))
    self.scaler = torch.cuda.amp.GradScaler(enabled=False)
    self.val_loader = MagicMock(return_value=None)
    self.early_stop_callback = MagicMock(return_value=None)
    self.on_validation_end = MagicMock(return_value=None)
    self.metrics = MagicMock(return_value=None)