class Trainer:
    out: Path
    out_prefix: str
    data_loaders: DataLoaderCollection
    metrics: MetricsCollection
    train_rate: float = 0.1
    max_epochs: int
    start_epoch: int
    save_checkpoint_epochs: int
    loss_metric: Loss
    optimizer: Optimizer
    model: Module
    device: torch.device
    half: bool
    ddp: bool
    test_no_grad: bool
    checkpoint: Checkpoint
    best_checkpoint: Checkpoint
    _training_model: Module
    _training_optimizer: Optimizer

    def __init__(self, out_prefix: str, data_loaders: DataLoaderCollection,
                 metrics: MetricsCollection,
                 loss_function: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
                 model: Module, optimizer: Optimizer, out: Path,
                 train_rate: float, max_epochs: int,
                 save_checkpoint_epochs: int, device: torch.device,
                 half: bool, ddp: bool, test_no_grad: bool = True):
        self.out_prefix = out_prefix
        self.data_loaders = data_loaders
        self.metrics = deepcopy(metrics)
        self.loss_metric = Loss(loss_function)
        self.metrics.train_metrics.insert(0, self.loss_metric)
        # The test loss metric must accumulate its state independently of the
        # train loss metric, so it gets its own Loss instance.
        self.metrics.test_metrics.insert(0, Loss(loss_function))
        self.model = model
        self.optimizer = optimizer
        self.out = out
        self.train_rate = train_rate
        self.max_epochs = max_epochs
        self.save_checkpoint_epochs = save_checkpoint_epochs
        self.device = device
        self.half = half
        self.ddp = ddp
        self.test_no_grad = test_no_grad
        self.start_epoch = 0
        self.checkpoint = Checkpoint()
        self.best_checkpoint = Checkpoint()
        self.best_checkpoint.indicate = inf
        self._training_model = self.model
        self._training_optimizer = self.optimizer
        self.setup_training_model_and_optimizer()

    def train(self):
        self.out.mkdir(parents=True, exist_ok=True)
        for epoch in range(self.start_epoch, self.max_epochs):
            self.checkpoint.set_epoch(epoch)
            self.train_epoch(epoch)
            self.test(self.test_no_grad)
            torch.cuda.empty_cache()
            self.checkpoint.add_metrics(self.metrics)
            self.checkpoint.set_indicate(self.calc_indicate(self.train_rate))
            if self.checkpoint.indicate <= self.best_checkpoint.indicate:
                self.checkpoint.set_states(self.model, self.optimizer)
                self.best_checkpoint = self.checkpoint.copy()
                self.save_model(self.out / f'{self.out_prefix}_backup_best.pt',
                                self.best_checkpoint)
            if (epoch + 1) % self.save_checkpoint_epochs == 0:
                self.checkpoint.set_states(self.model, self.optimizer)
                self.save_model(
                    self.out / f'{self.out_prefix}_backup_{epoch + 1}.pt',
                    self.checkpoint)
            print(
                f"\r"
                f"epoch {epoch + 1}/{self.max_epochs} / "
                f"train({', '.join([str(metric) for metric in self.metrics.train_metrics])}) / "
                f"test({', '.join([str(metric) for metric in self.metrics.test_metrics])}) / "
                f"indicate(current={self.checkpoint.indicate:.6f}, best={self.best_checkpoint.indicate:.6f})"
            )
        self.save_model(self.out / f'{self.out_prefix}_best.pt',
                        self.best_checkpoint)
        self.checkpoint.set_states(self.model, self.optimizer)
        self.save_model(self.out / f'{self.out_prefix}_last.pt', self.checkpoint)

    def setup_training_model_and_optimizer(self):
        self.model.to(self.device)
        if self.half:
            from apex import amp

            # Initialization
            opt_level = 'O1'
            self._training_model, self._training_optimizer = \
                amp.initialize(self._training_model,
                               self._training_optimizer,
                               opt_level=opt_level,
                               verbosity=0)
        if self.ddp:
            # Try port numbers incrementally until we find a free one.
            port = 9999
            while True:
                try:
                    dist.init_process_group(
                        backend='nccl',
                        init_method=f'tcp://127.0.0.1:{port}',
                        world_size=1,
                        rank=0)
                except RuntimeError as e:
                    # The bind failure reason is part of the message, so test
                    # for the substring rather than an exact match.
                    if "Address already in use" in str(e):
                        port += 1
                        continue
                    raise
                break
            self._training_model = DistributedDataParallel(
                self._training_model, find_unused_parameters=True)

    def save_model(self, path: Path, checkpoint: Checkpoint):
        real_path = path
        state_dict = checkpoint.state_dict()
        torch.save(state_dict, real_path)

    def calc_indicate(self, train_rate) -> float:
        # Weighted geometric mean of the train and test losses; the checkpoint
        # with the lowest indicate value is kept as the best one.
        test_rate = 1 - train_rate
        test_loss = self.metrics.test_metrics[0]
        return (self.loss_metric.get_value() ** train_rate
                * test_loss.get_value() ** test_rate)

    def load_checkpoint(self, state_dict: dict):
        self.checkpoint.load(state_dict)
        self.start_epoch = self.checkpoint.epoch + 1
        self.model.load_state_dict(self.checkpoint.model_state)
        self.optimizer.load_state_dict(self.checkpoint.optimizer_state)

    def train_epoch(self, epoch: int):
        self._training_model.train()
        self.reset_all_metrics_states(self.metrics.train_metrics)
        for i, (data, target) in enumerate(self.data_loaders.train_loader):
            data = data.to(self.device)
            target = target.to(self.device)
            pred = self._training_model(data)
            evaluated = self.evaluate_metrics(pred, target,
                                              self.metrics.train_metrics)
            loss = self.loss_metric.get_tensor()
            if self.half:
                from apex import amp

                # Scale the loss through the amp-initialized optimizer so the
                # gradients are unscaled correctly before the step.
                with amp.scale_loss(loss, self._training_optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()
            self._training_optimizer.step()
            self._training_optimizer.zero_grad()

            mem_info = ""
            if self.device.type != "cpu":
                torch.cuda.synchronize()
                mem_info = (f" / mem(allocated={torch.cuda.memory_allocated()},"
                            f" cached={torch.cuda.memory_cached()})")
            print("\r"
                  f"epoch {epoch + 1}/{self.max_epochs} / "
                  f"batch {i + 1}/{len(self.data_loaders.train_loader)} / "
                  f"train ({self.get_evaluated_metrics_info(evaluated)})" +
                  mem_info,
                  end="")

    def test(self,
             no_grad: bool = True,
             loader: DataLoader = None,
             metrics: List[MetricBase] = None):
        self._training_model.eval()
        grad_policy = torch.no_grad if no_grad else torch.enable_grad
        loader = loader or self.data_loaders.test_loader
        metrics = metrics or self.metrics.test_metrics
        self.reset_all_metrics_states(metrics)
        with grad_policy():
            for i, (data, target) in enumerate(loader):
                data = data.to(self.device)
                target = target.to(self.device)
                pred = self._training_model(data)
                self.evaluate_metrics(pred, target, metrics)

    def reset_all_metrics_states(self, metrics: List[MetricBase]):
        for metric in metrics:
            metric.reset_states()

    def evaluate_metrics(self, pred: torch.Tensor, target: torch.Tensor,
                         metrics: List[MetricBase]) -> Dict[str, float]:
        # The loss metric should calculate the gradient when gradients are enabled.
        loss_metric = metrics[0]
        result = {loss_metric.get_name(): loss_metric.evaluate(pred, target)}
        # The other metrics shouldn't calculate gradients, so evaluate them on
        # detached tensors.
        pred = pred.detach()
        target = target.detach()
        for metric in metrics[1:]:
            result[metric.get_name()] = metric.evaluate(pred, target)
        return result

    def get_evaluated_metrics_info(self,
                                   evaluated_infos: Dict[str, float]) -> str:
        return ', '.join(
            [f'{name}={info:.6f}' for name, info in evaluated_infos.items()])
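# --- Usage sketch (not part of the original source) ---
# A minimal, hedged example of driving the Trainer above, kept as comments
# because the helper types it needs (DataLoaderCollection, MetricsCollection,
# Checkpoint, Loss, MetricBase) come from this project and their constructors
# are assumptions here; only the torch pieces are standard.
#
#   import torch
#   import torch.nn.functional as F
#   from pathlib import Path
#
#   model = torch.nn.Linear(16, 2)
#   trainer = Trainer(
#       out_prefix="demo",
#       data_loaders=DataLoaderCollection(train_loader, test_loader),  # assumed ctor
#       metrics=MetricsCollection(train_metrics=[], test_metrics=[]),  # assumed ctor
#       loss_function=F.cross_entropy,
#       model=model,
#       optimizer=torch.optim.SGD(model.parameters(), lr=1e-2),
#       out=Path("runs"),
#       train_rate=0.1,            # weight of the train loss in calc_indicate()
#       max_epochs=10,
#       save_checkpoint_epochs=5,
#       device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
#       half=False,                # requires apex when True
#       ddp=False,
#   )
#   trainer.train()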
def train():
    config_file = "configs/train_daily_dialog_topic_config.json"
    config = Config.from_json_file(config_file)

    # logging is set to INFO (resp. WARN) for the main (resp. auxiliary) process.
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if config.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Running process %d", config.local_rank
    )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(config))

    # Initialize distributed training if needed
    config.distributed = (config.local_rank != -1)
    if config.distributed:
        torch.cuda.set_device(config.local_rank)
        config.device = torch.device("cuda", config.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info(
        "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning"
    )
    tokenizer_class = GPT2Tokenizer if "gpt2" in config.model_checkpoint else OpenAIGPTTokenizer
    tokenizer = tokenizer_class.from_pretrained(config.model_checkpoint)
    model_class = GPT2DoubleHeadsModel if "gpt2" in config.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(config.model_checkpoint)
    tokenizer.set_special_tokens(SPECIAL_TOKENS)
    model.set_num_special_tokens(len(SPECIAL_TOKENS))
    model.to(config.device)
    optimizer = OpenAIAdam(model.parameters(), lr=config.lr)

    # Prepare model for FP16 and distributed training if needed
    # (order is important: the distributed wrapper should come last)
    if config.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=config.fp16)
    if config.distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[config.local_rank],
                                        output_device=config.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(
        config, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        (input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids,
         token_emotion_ids, token_action_ids) = tuple(
            input_tensor.to(config.device) for input_tensor in batch)
        lm_loss, mc_loss = model(input_ids, mc_token_ids, lm_labels, mc_labels,
                                 token_type_ids, token_emotion_ids,
                                 token_action_ids)
        loss = (lm_loss * config.lm_coef +
                mc_loss * config.mc_coef) / config.gradient_accumulation_steps
        if config.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                           config.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.max_norm)
        if engine.state.iteration % config.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(
                input_tensor.to(config.device) for input_tensor in batch)
            (input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids,
             token_emotion_ids, token_action_ids) = batch
            # logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            model_outputs = model(input_ids,
                                  mc_token_ids,
                                  token_type_ids=token_type_ids,
                                  token_emotion_ids=token_emotion_ids,
                                  token_action_ids=token_action_ids)
            lm_logits, mc_logits = model_outputs[0], model_outputs[
                1]  # So we can also use GPT2 outputs
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(
                -1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted,
                                                         mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and
    # at the end of each epoch
    trainer.add_event_handler(Events.EPOCH_COMPLETED,
                              lambda _: evaluator.run(val_loader))
    if config.n_epochs < 1:
        trainer.add_event_handler(Events.COMPLETED,
                                  lambda _: evaluator.run(val_loader))
    if config.eval_before_start:
        trainer.add_event_handler(Events.STARTED,
                                  lambda _: evaluator.run(val_loader))

    # Make sure distributed data samplers split the dataset nicely between the
    # distributed processes
    if config.distributed:
        trainer.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: train_sampler.set_epoch(engine.state.epoch))
        evaluator.add_event_handler(
            Events.EPOCH_STARTED,
            lambda engine: valid_sampler.set_epoch(engine.state.epoch))

    # Linearly decrease the learning rate from lr to zero
    scheduler = PiecewiseLinear(optimizer, "lr",
                                [(0, config.lr),
                                 (config.n_epochs * len(train_loader), 0.0)])
    trainer.add_event_handler(Events.ITERATION_STARTED, scheduler)

    # Prepare metrics - note how we compute distributed metrics
    RunningAverage(output_transform=lambda x: x).attach(trainer, "loss")
    metrics = {
        "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1),
                    output_transform=lambda x: (x[0][0], x[1][0])),
        "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1]))
    }
    metrics.update({
        "average_nll":
        MetricsLambda(average_distributed_scalar, metrics["nll"], config),
        "average_accuracy":
        MetricsLambda(average_distributed_scalar, metrics["accuracy"], config)
    })
    metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"])
    for name, metric in metrics.items():
        metric.attach(evaluator, name)

    # On the main process: add progress bar, tensorboard, checkpoints, and save
    # the model, configuration and tokenizer before we start to train
    if config.local_rank in [-1, 0]:
        pbar = ProgressBar(persist=True)
        pbar.attach(trainer, metric_names=["loss"])
        evaluator.add_event_handler(
            Events.COMPLETED, lambda _: pbar.log_message(
                "Validation: %s" % pformat(evaluator.state.metrics)))

        tb_logger = TensorboardLogger(log_dir=config.log_dir)
        tb_logger.attach(trainer,
                         log_handler=OutputHandler(tag="training",
                                                   metric_names=["loss"]),
                         event_name=Events.ITERATION_COMPLETED)
        tb_logger.attach(trainer,
                         log_handler=OptimizerParamsHandler(optimizer),
                         event_name=Events.ITERATION_STARTED)
        tb_logger.attach(evaluator,
                         log_handler=OutputHandler(tag="validation",
                                                   metric_names=list(
                                                       metrics.keys()),
                                                   another_engine=trainer),
                         event_name=Events.EPOCH_COMPLETED)

        checkpoint_handler = ModelCheckpoint(tb_logger.writer.log_dir,
                                             'checkpoint',
                                             save_interval=1,
                                             n_saved=3)
        trainer.add_event_handler(
            Events.EPOCH_COMPLETED, checkpoint_handler,
            {'mymodel': getattr(model, 'module', model)
             })  # "getattr" takes care of the distributed encapsulation

        torch.save(config, tb_logger.writer.log_dir + '/model_training_args.bin')
        getattr(model, 'module', model).config.to_json_file(
            os.path.join(tb_logger.writer.log_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(tb_logger.writer.log_dir)

    # Run the training
    trainer.run(train_loader, max_epochs=config.n_epochs)

    # On the main process: close the tensorboard logger and rename the last
    # checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained)
    if config.local_rank in [-1, 0] and config.n_epochs > 0:
        os.rename(
            checkpoint_handler._saved[-1][1][-1],
            os.path.join(tb_logger.writer.log_dir, WEIGHTS_NAME)
        )  # TODO: PR in ignite to have better access to saved file paths (cleaner)
        tb_logger.close()
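# --- Hedged sketch (not in the original excerpt) ---
# average_distributed_scalar is attached via MetricsLambda above but is not
# defined here. This is one plausible implementation, assuming config carries
# local_rank and device as it does elsewhere in this file: all-reduce the
# scalar across workers and divide by the world size.
import torch

def average_distributed_scalar_sketch(scalar, config):
    """Average a scalar over the nodes if we are in distributed training."""
    if config.local_rank == -1:
        # Not distributed: nothing to average.
        return scalar
    scalar_t = (torch.tensor(scalar, dtype=torch.float, device=config.device)
                / torch.distributed.get_world_size())
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()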
def lops(self, rank, data_dict, optimizer, scheduler, percent_thresh=0.5):
    print("LOPS running..")
    print("Before LOPS: Input Ids", data_dict["input_ids"].shape)
    print("Before LOPS: Mask", data_dict["attention_masks"].shape)
    print("Before LOPS: Labels", data_dict["labels"].shape)
    model = LOTClassModel.from_pretrained(self.pretrained_lm,
                                          output_attentions=False,
                                          output_hidden_states=False,
                                          num_labels=self.num_class)
    model = model.to(rank)
    model = DDP(model, device_ids=[rank], find_unused_parameters=True)
    model.train()

    y_pseudo = data_dict["all_target_pred"].numpy()
    dataset = TensorDataset(data_dict["input_ids"],
                            data_dict["attention_masks"],
                            data_dict["labels"])
    sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=rank)
    train_dataloader = DataLoader(dataset,
                                  sampler=sampler,
                                  batch_size=self.train_batch_size,
                                  shuffle=False)
    sampler = DistributedSampler(dataset, num_replicas=self.world_size,
                                 rank=rank, shuffle=False)
    prediction_dataloader = DataLoader(dataset,
                                       sampler=sampler,
                                       batch_size=self.eval_batch_size,
                                       shuffle=False)

    # Map each pseudo-label to the indices of the documents assigned to it.
    inds_map = {}
    for i, j in enumerate(y_pseudo):
        inds_map.setdefault(j, []).append(i)

    thresh_map = dict(Counter(y_pseudo))
    print("Counts of pseudo-labels ", thresh_map, flush=True)
    for i in thresh_map:
        thresh_map[i] = int(thresh_map[i] * percent_thresh)
    print("Threshold map ", thresh_map, flush=True)

    filter_flag_map = {}
    train_inds_map = {}
    non_train_inds_map = {}
    for i in thresh_map:
        filter_flag_map[i] = False
        train_inds_map[i] = []
        non_train_inds_map[i] = []

    total_train_loss = 0
    wrap_train_dataset_loader = tqdm(train_dataloader) if rank == 0 else train_dataloader
    model.zero_grad()
    try:
        for j, batch in enumerate(wrap_train_dataset_loader):
            input_ids = batch[0].to(rank)
            input_mask = batch[1].to(rank)
            target_dist = batch[2].to(rank)
            logits = model(input_ids,
                           pred_mode="classification",
                           token_type_ids=None,
                           attention_mask=input_mask)
            logits = logits[:, 0, :]
            preds = nn.LogSoftmax(dim=-1)(logits)
            loss = self.st_loss(preds.view(-1, self.num_class),
                                target_dist.view(-1, self.num_class)) / self.accum_steps
            total_train_loss += loss.item()
            loss.backward()
            if (j + 1) % self.accum_steps == 0:
                # Clip the norm of the gradients to 1.0.
                nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                scheduler.step()
                model.zero_grad()
        avg_train_loss = torch.tensor(
            [total_train_loss / len(wrap_train_dataset_loader) * self.accum_steps]).to(rank)
        gather_list = [torch.ones_like(avg_train_loss) for _ in range(self.world_size)]
        dist.all_gather(gather_list, avg_train_loss)
        avg_train_loss = torch.tensor(gather_list)
        if rank == 0:
            print(f"lr: {optimizer.param_groups[0]['lr']:.4g}")
            print(f"Average training loss: {avg_train_loss.mean().item()}")
    except RuntimeError as err:
        self.cuda_mem_error(err, "train", rank)

    input_ids, input_mask, preds = self.inference(model, prediction_dataloader,
                                                  rank, return_type="data")
    gather_input_ids = [torch.ones_like(input_ids) for _ in range(self.world_size)]
    gather_input_mask = [torch.ones_like(input_mask) for _ in range(self.world_size)]
    gather_preds = [torch.ones_like(preds) for _ in range(self.world_size)]
    dist.all_gather(gather_input_ids, input_ids)
    dist.all_gather(gather_input_mask, input_mask)
    dist.all_gather(gather_preds, preds)
    all_preds = torch.cat(gather_preds, dim=0).cpu()
    pred_inds = all_preds.argmax(dim=-1)

    count = 0
    for i in filter_flag_map:
        if not filter_flag_map[i]:
            train_inds, non_train_inds = self.compute_train_non_train_inds(
                pred_inds, y_pseudo, inds_map, i)
            train_inds_map[i] = train_inds
            non_train_inds_map[i] = non_train_inds
            if len(train_inds) >= thresh_map[i]:
                filter_flag_map[i] = True
                count += 1
        else:
            count += 1
    print("Number of labels that reached the 50 percent threshold", count)
    for i in filter_flag_map:
        if not filter_flag_map[i]:
            print("For label ", i, " Number expected ", thresh_map[i],
                  " Found ", len(train_inds_map[i]))
    for i in filter_flag_map:
        if not filter_flag_map[i]:
            print("Resetting train, non-train inds for label ", i)
            train_inds_map[i] = inds_map[i]
            non_train_inds_map[i] = []

    ret_input_ids = []
    ret_masks = []
    ret_labels = []
    for lbl in train_inds_map:
        for loop_ind in train_inds_map[lbl]:
            ret_input_ids.append(data_dict["input_ids"][loop_ind].unsqueeze(0))
            ret_masks.append(data_dict["attention_masks"][loop_ind].unsqueeze(0))
            ret_labels.append(data_dict["labels"][loop_ind].unsqueeze(0))
    all_input_ids = torch.cat(ret_input_ids, dim=0)
    all_input_mask = torch.cat(ret_masks, dim=0)
    all_labels = torch.cat(ret_labels, dim=0)
    print("After LOPS: Input Ids", all_input_ids.shape)
    print("After LOPS: Mask", all_input_mask.shape)
    print("After LOPS: Labels", all_labels.shape)
    self_train_dict = {"input_ids": all_input_ids,
                       "attention_masks": all_input_mask,
                       "labels": all_labels}
    return self_train_dict
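# --- Hedged sketch (not in the original excerpt) ---
# compute_train_non_train_inds is used above but not defined here. A plausible
# reading, given how its outputs are consumed: for one pseudo-label, keep the
# documents whose classifier prediction agrees with that label as "train"
# indices and the rest as "non-train". Treat this as an assumption about the
# helper, not its actual implementation.
def compute_train_non_train_inds_sketch(pred_inds, y_pseudo, inds_map, label):
    train_inds, non_train_inds = [], []
    for idx in inds_map[label]:
        # pred_inds[idx] is the argmax class predicted for document idx.
        if int(pred_inds[idx]) == int(y_pseudo[idx]):
            train_inds.append(idx)
        else:
            non_train_inds.append(idx)
    return train_inds, non_train_inds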
def setup_experiment(self, config):
    """
    Configure the experiment for training

    :param config: Dictionary containing the configuration parameters

        - distributed: Whether or not to use Pytorch Distributed training
        - backend: Pytorch Distributed backend ("nccl", "gloo").
                   Default: nccl
        - world_size: Total number of processes participating
        - rank: Rank of the current process
        - data: Dataset path
        - train_dir: Dataset training data relative path
        - batch_size: Training batch size
        - val_dir: Dataset validation data relative path
        - val_batch_size: Validation batch size
        - workers: How many data loading processes to use
        - train_loader_drop_last: Whether to skip the last batch if it is
                                  smaller than the batch size
        - num_classes: Limit the dataset size to the given number of classes
        - model_class: Model class. Must inherit from "torch.nn.Module"
        - model_args: Model class arguments passed to the constructor
        - init_batch_norm: Whether or not to initialize the running batch norm
                           mean to 0.
        - optimizer_class: Optimizer class. Must inherit from
                           "torch.optim.Optimizer"
        - optimizer_args: Optimizer class arguments passed to the constructor
        - batch_norm_weight_decay: Whether or not to apply weight decay to
                                   batch norm modules parameters.
                                   See https://arxiv.org/abs/1807.11205
        - bias_weight_decay: Whether or not to apply weight decay to bias
                             parameters
        - lr_scheduler_class: Learning rate scheduler class. Must inherit from
                              "_LRScheduler"
        - lr_scheduler_args: Learning rate scheduler class arguments passed to
                             the constructor
        - loss_function: Loss function. See "torch.nn.functional"
        - local_dir: Results path
        - logdir: Directory generated by Ray Tune for this Trial
        - epochs: Number of epochs to train
        - batches_in_epoch: Number of batches per epoch. Useful for debugging
        - log_timestep_freq: Configures mixins and subclasses that log every
                             timestep to only log every nth timestep (in
                             addition to the final timestep of each epoch).
                             Set to 0 to log only at the end of each epoch.
        - progress: Show progress during training
        - name: Experiment name. Used as logger name
        - log_level: Python Logging level
        - log_format: Python Logging format
        - seed: The seed to be used for pytorch, python, and numpy
        - mixed_precision: Whether or not to enable apex mixed precision
        - mixed_precision_args: Apex mixed precision arguments.
                                See "amp.initialize"
        - sample_transform: Transform acting on the training samples. To be
                            used additively after the default transform or
                            auto-augment.
        - target_transform: Transform acting on the training targets.
        - replicas_per_sample: Number of replicas to create per sample in the
                               batch (each replica is transformed
                               independently). Used in maxup.
        - train_model_func: Optional user defined function to train the model,
                            expected to behave similarly to `train_model` in
                            terms of input parameters and return values
        - evaluate_model_func: Optional user defined function to validate the
                               model, expected to behave similarly to
                               `evaluate_model` in terms of input parameters
                               and return values
        - checkpoint_file: If not None, will start from this model. The model
                           must have the same model_args and model_class as
                           the current experiment.
        - resize_buffers_for_checkpoint: If True, this will resize the model
                                         buffers to match those in the
                                         checkpoint. This is helpful for
                                         loading buffers with sparse levels
                                         not matching the model_args
        - checkpoint_at_init: Boolean argument for whether to create a
                              checkpoint of the initialized model. This differs
                              from `checkpoint_at_start`, for which the
                              checkpoint occurs after the first epoch of
                              training, as opposed to before it
        - epochs_to_validate: List of epochs to run validate(). A -1 asks to
                              run validate before any training occurs.
                              Default: last three epochs.
        - extra_validations_per_epoch: Number of additional validations to
                                       perform mid-epoch. Additional
                                       validations are distributed evenly
                                       across training batches.
        - launch_time: Time the config was created (via time.time). Used to
                       report wall clock time until the first batch is done.
                       Default: time.time() in this setup_experiment().
    """
    # Configure logging related stuff
    log_format = config.get("log_format", logging.BASIC_FORMAT)
    log_level = getattr(logging, config.get("log_level", "INFO").upper())
    console = logging.StreamHandler()
    console.setFormatter(logging.Formatter(log_format))
    self.logger = logging.getLogger(config.get("name", type(self).__name__))
    self.logger.setLevel(log_level)
    self.logger.addHandler(console)
    self.progress = config.get("progress", False)
    self.launch_time = config.get("launch_time", time.time())
    self.logdir = config.get("logdir", None)

    # Configure seed
    self.seed = config.get("seed", self.seed)
    set_random_seed(self.seed, False)

    # Configure distributed pytorch
    self.distributed = config.get("distributed", False)
    self.rank = config.get("rank", 0)

    if self.rank == 0:
        self.logger.info(
            f"Execution order: {pformat(self.get_execution_order())}")

    if self.distributed:
        dist_url = config.get("dist_url", "tcp://127.0.0.1:54321")
        backend = config.get("backend", "nccl")
        world_size = config.get("world_size", 1)
        dist.init_process_group(
            backend=backend,
            init_method=dist_url,
            rank=self.rank,
            world_size=world_size,
        )
        # Only enable logs from the first process
        self.logger.disabled = self.rank != 0
        self.progress = self.progress and self.rank == 0

    # Configure model
    self.device = config.get("device", self.device)
    self.model = self.create_model(config, self.device)
    if self.rank == 0:
        self.logger.debug(self.model)

    # Configure optimizer
    group_decay, group_no_decay = [], []
    for module in self.model.modules():
        for name, param in module.named_parameters(recurse=False):
            if self.should_decay_parameter(module, name, param, config):
                group_decay.append(param)
            else:
                group_no_decay.append(param)

    optimizer_class = config.get("optimizer_class", torch.optim.SGD)
    optimizer_args = config.get("optimizer_args", {})
    self.optimizer = optimizer_class([
        dict(params=group_decay),
        dict(params=group_no_decay, weight_decay=0.)
    ], **optimizer_args)

    # Validate mixed precision requirements
    self.mixed_precision = config.get("mixed_precision", False)
    if self.mixed_precision and amp is None:
        self.mixed_precision = False
        self.logger.error(
            "Mixed precision requires NVIDIA Apex. "
            "Please install apex from https://www.github.com/nvidia/apex. "
            "Disabling mixed precision training.")

    # Configure mixed precision training
    if self.mixed_precision:
        amp_args = config.get("mixed_precision_args", {})
        self.model, self.optimizer = amp.initialize(
            self.model, self.optimizer, **amp_args)
        self.logger.info("Using mixed precision")

    # Apply DistributedDataParallel after all other model mutations
    if self.distributed:
        self.model = DistributedDataParallel(self.model)
    else:
        self.model = DataParallel(self.model)

    self._loss_function = config.get("loss_function",
                                     torch.nn.functional.cross_entropy)
    self.num_classes = config.get("num_classes", 1000)
    self.epochs = config.get("epochs", 1)
    self.batches_in_epoch = config.get("batches_in_epoch", sys.maxsize)
    self.current_epoch = 0

    # Get initial batch size
    self.batch_size = config.get("batch_size", 1)

    # CUDA runtime does not support the fork start method.
    # See https://pytorch.org/docs/stable/notes/multiprocessing.html
    multiprocessing.set_start_method("spawn", force=True)

    # Configure data loaders
    self.train_loader = self.create_train_dataloader(config)
    self.val_loader = self.create_validation_dataloader(config)
    self.total_batches = len(self.train_loader)

    self.epochs_to_validate = config.get(
        "epochs_to_validate", range(self.epochs - 3, self.epochs + 1))

    extra_validations = config.get("extra_validations_per_epoch", 0)
    batches_to_validate = np.linspace(
        min(self.total_batches, self.batches_in_epoch),
        0,
        1 + extra_validations,
        endpoint=False)[::-1].round().astype("int").tolist()
    self.additional_batches_to_validate = batches_to_validate[:-1]
    if extra_validations > 0:
        self.logger.info(
            f"Extra validations per epoch: {extra_validations}, "
            f"batch indices: {self.additional_batches_to_validate}")

    # Used for logging. Conceptually, it is a version number for the model's
    # parameters. By default, this is the elapsed number of batches that the
    # model has been trained on. Experiments may also increment this on
    # other events like model prunings. When validation is performed after a
    # training batch, the validation results are assigned to the next
    # timestep after that training batch, since it was performed on the
    # subsequent version of the parameters.
    self.current_timestep = 0
    self.log_timestep_freq = config.get("log_timestep_freq", 1)

    # A list of [(timestep, result), ...] for the current epoch.
    self.extra_val_results = []

    # Configure learning rate scheduler
    lr_scheduler_class = config.get("lr_scheduler_class", None)
    if lr_scheduler_class is not None:
        lr_scheduler_args = config.get("lr_scheduler_args", {})
        self.logger.info("LR Scheduler args:")
        self.logger.info(pformat(lr_scheduler_args))
        self.logger.info("steps_per_epoch=%s", self.total_batches)
        self.lr_scheduler = create_lr_scheduler(
            optimizer=self.optimizer,
            lr_scheduler_class=lr_scheduler_class,
            lr_scheduler_args=lr_scheduler_args,
            steps_per_epoch=self.total_batches)

    # Set train and validate methods.
    self.train_model = config.get("train_model_func", train_model)
    self.evaluate_model = config.get("evaluate_model_func", evaluate_model)
class GradientDescentTrainer(Trainer): def __init__( self, model: Model, optimizer: torch.optim.Optimizer, data_loader: torch.utils.data.DataLoader, patience: Optional[int] = None, validation_metric: str = "-loss", validation_data_loader: torch.utils.data.DataLoader = None, num_epochs: int = 20, serialization_dir: Optional[str] = None, checkpointer: Checkpointer = None, model_save_interval: float = None, cuda_device: int = -1, grad_norm: Optional[float] = None, grad_clipping: Optional[float] = None, learning_rate_scheduler: Optional[LearningRateScheduler] = None, momentum_scheduler: Optional[MomentumScheduler] = None, tensorboard_writer: TensorboardWriter = None, log_batch_size_period: Optional[int] = None, moving_average: Optional[MovingAverage] = None, distributed: bool = False, local_rank: int = 0, world_size: int = 1, num_gradient_accumulation_steps: int = 1, opt_level: Optional[str] = None, ) -> None: """ A trainer for doing supervised learning. It just takes a labeled dataset and a `DataLoader`, and uses the supplied `Optimizer` to learn the weights for your model over some fixed number of epochs. You can also pass in a validation dataloader and enable early stopping. There are many other bells and whistles as well. # Parameters model : `Model`, required. An AllenNLP model to be optimized. Pytorch Modules can also be optimized if their `forward` method returns a dictionary with a "loss" key, containing a scalar tensor representing the loss function to be optimized. If you are training your model using GPUs, your model should already be on the correct device. (If you use `Trainer.from_params` this will be handled for you.) optimizer : `torch.nn.Optimizer`, required. An instance of a Pytorch Optimizer, instantiated with the parameters of the model to be optimized. data_loader : `DataLoader`, required. A pytorch `DataLoader` containing your `Dataset`, yielding padded indexed batches. patience : Optional[int] > 0, optional (default=None) Number of epochs to be patient before early stopping: the training is stopped after `patience` epochs with no improvement. If given, it must be `> 0`. If None, early stopping is disabled. validation_metric : str, optional (default="loss") Validation metric to measure for whether to stop training using patience and whether to serialize an `is_best` model each epoch. The metric name must be prepended with either "+" or "-", which specifies whether the metric is an increasing or decreasing function. validation_dataloader : `DataLoader`, optional (default=None) A `DataLoader` to use for the validation set. If `None`, then use the training `DataLoader` with the validation data. num_epochs : int, optional (default = 20) Number of training epochs. serialization_dir : str, optional (default=None) Path to directory for saving and loading model files. Models will not be saved if this parameter is not passed. checkpointer : `Checkpointer`, optional (default=None) A `Checkpointer` is responsible for periodically saving model weights. If none is given here, we will construct one with default parameters. model_save_interval : `float`, optional (default=None) If provided, then serialize models every `model_save_interval` seconds within single epochs. In all cases, models are also saved at the end of every epoch if `serialization_dir` is provided. cuda_device : `int`, optional (default = -1) An integer specifying the CUDA device(s) to use for this process. If -1, the CPU is used. 
Data parallelism is controlled at the allennlp train level, so each trainer will have a single GPU. grad_norm : `float`, optional, (default = None). If provided, gradient norms will be rescaled to have a maximum of this value. grad_clipping : `float`, optional (default = `None`). If provided, gradients will be clipped `during the backward pass` to have an (absolute) maximum of this value. If you are getting `NaNs` in your gradients during training that are not solved by using `grad_norm`, you may need this. learning_rate_scheduler : `LearningRateScheduler`, optional (default = None) If specified, the learning rate will be decayed with respect to this schedule at the end of each epoch (or batch, if the scheduler implements the `step_batch` method). If you use `torch.optim.lr_scheduler.ReduceLROnPlateau`, this will use the `validation_metric` provided to determine if learning has plateaued. To support updating the learning rate on every batch, this can optionally implement `step_batch(batch_num_total)` which updates the learning rate given the batch number. momentum_scheduler : `MomentumScheduler`, optional (default = None) If specified, the momentum will be updated at the end of each batch or epoch according to the schedule. tensorboard_writer : `TensorboardWriter`, optional If this is not provided, we will construct a `TensorboardWriter` with default parameters and use that. log_batch_size_period : `int`, optional, (default = `None`) If defined, how often to log the average batch size. moving_average : `MovingAverage`, optional, (default = None) If provided, we will maintain moving averages for all parameters. During training, we employ a shadow variable for each parameter, which maintains the moving average. During evaluation, we backup the original parameters and assign the moving averages to corresponding parameters. Be careful that when saving the checkpoint, we will save the moving averages of parameters. This is necessary because we want the saved model to perform as well as the validated model if we load it later. But this may cause problems if you restart the training from checkpoint. distributed : `bool`, optional, (default = False) If set, PyTorch's `DistributedDataParallel` is used to train the model in multiple GPUs. This also requires `world_size` to be greater than 1. local_rank : `int`, optional, (default = 0) This is the unique identifier of the `Trainer` in a distributed process group. The GPU device id is used as the rank. world_size : `int`, (default = 1) The number of `Trainer` workers participating in the distributed training. num_gradient_accumulation_steps : `int`, optional, (default = 1) Gradients are accumulated for the given number of steps before doing an optimizer step. This can be useful to accommodate batches that are larger than the RAM size. Refer Thomas Wolf's [post](https://tinyurl.com/y5mv44fw) for details on Gradient Accumulation. opt_level : `str`, optional, (default = `None`) Each opt_level establishes a set of properties that govern Amp’s implementation of pure or mixed precision training. Must be a choice of `"O0"`, `"O1"`, `"O2"`, or `"O3"`. See the Apex [documentation](https://nvidia.github.io/apex/amp.html#opt-levels-and-properties) for more details. If `None`, Amp is not used. Defaults to `None`. """ super().__init__(serialization_dir, cuda_device, distributed, local_rank, world_size) # I am not calling move_to_gpu here, because if the model is # not already on the GPU then the optimizer is going to be wrong. 
self.model = model self.data_loader = data_loader self._validation_data_loader = validation_data_loader self.optimizer = optimizer if patience is None: # no early stopping if validation_data_loader: logger.warning( "You provided a validation dataset but patience was set to None, " "meaning that early stopping is disabled") elif (not isinstance(patience, int)) or patience <= 0: raise ConfigurationError( '{} is an invalid value for "patience": it must be a positive integer ' "or None (if you want to disable early stopping)".format( patience)) # For tracking is_best_so_far and should_stop_early self._metric_tracker = MetricTracker(patience, validation_metric) # Get rid of + or - self._validation_metric = validation_metric[1:] self._num_epochs = num_epochs if checkpointer is not None: self._checkpointer = checkpointer else: self._checkpointer = Checkpointer(serialization_dir) self._model_save_interval = model_save_interval self._grad_norm = grad_norm self._grad_clipping = grad_clipping self._learning_rate_scheduler = learning_rate_scheduler self._momentum_scheduler = momentum_scheduler self._moving_average = moving_average # We keep the total batch number as an instance variable because it # is used inside a closure for the hook which logs activations in # `_enable_activation_logging`. self._batch_num_total = 0 self._tensorboard = tensorboard_writer or TensorboardWriter( serialization_dir) self._tensorboard.get_batch_num_total = lambda: self._batch_num_total self._tensorboard.enable_activation_logging(self.model) self._log_batch_size_period = log_batch_size_period self._last_log = 0.0 # time of last logging self._num_gradient_accumulation_steps = num_gradient_accumulation_steps # Enable automatic mixed precision training with NVIDIA Apex. self._opt_level = opt_level if self._opt_level is not None: if amp is None: raise ConfigurationError(( "Apex not installed but opt_level was provided. Please install NVIDIA's Apex to enable" " automatic mixed precision (AMP) training. See: https://github.com/NVIDIA/apex." )) self.model, self.optimizer = amp.initialize( self.model, self.optimizer, opt_level=self._opt_level) # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model` # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc. # # Hence a reference to Pytorch's object is maintained in the case of distributed training and in the # normal case, reference to `Model` is retained. This reference is only used in # these places: `model.__call__`, `model.train` and `model.eval`. if self._distributed: self._pytorch_model = DistributedDataParallel( self.model, device_ids=[self.cuda_device], find_unused_parameters=True) else: self._pytorch_model = self.model def rescale_gradients(self) -> Optional[float]: """ Performs gradient rescaling. Is a no-op if gradient rescaling is not enabled. """ if self._grad_norm: if self._opt_level is not None: # See: https://nvidia.github.io/apex/advanced.html#gradient-clipping parameters_to_clip = [ p for p in amp.master_params(self.optimizer) if p.grad is not None ] else: parameters_to_clip = [ p for p in self.model.parameters() if p.grad is not None ] return training_util.sparse_clip_norm(parameters_to_clip, self._grad_norm) else: return None def batch_loss(self, batch: TensorDict, for_training: bool) -> torch.Tensor: """ Does a forward pass on the given batches and returns the `loss` value in the result. 
If `for_training` is `True` also applies regularization penalty. """ batch = nn_util.move_to_device(batch, self.cuda_device) output_dict = self._pytorch_model(**batch) try: loss = output_dict["loss"] if for_training: loss += self.model.get_regularization_penalty() except KeyError: if for_training: raise RuntimeError( "The model you are trying to optimize does not contain a" " 'loss' key in the output of model.forward(inputs).") loss = None return loss def _train_epoch(self, epoch: int) -> Dict[str, float]: """ Trains one epoch and returns metrics. """ logger.info("Epoch %d/%d", epoch, self._num_epochs - 1) peak_cpu_usage = common_util.peak_memory_mb() logger.info(f"Peak CPU memory usage MB: {peak_cpu_usage}") gpu_usage = [] for gpu, memory in common_util.gpu_memory_mb().items(): gpu_usage.append((gpu, memory)) logger.info(f"GPU {gpu} memory usage MB: {memory}") train_loss = 0.0 # Set the model to "train" mode. self._pytorch_model.train() # Get tqdm for the training batches batch_generator = iter(self.data_loader) batch_group_generator = common_util.lazy_groups_of( batch_generator, self._num_gradient_accumulation_steps) logger.info("Training") num_training_batches = math.ceil( len(self.data_loader) / self._num_gradient_accumulation_steps) # Having multiple tqdm bars in case of distributed training will be a mess. Hence only the master's # progress is shown if self._master: batch_group_generator_tqdm = Tqdm.tqdm(batch_group_generator, total=num_training_batches) else: batch_group_generator_tqdm = batch_group_generator self._last_log = time.time() last_save_time = time.time() batches_this_epoch = 0 if self._batch_num_total is None: self._batch_num_total = 0 histogram_parameters = set( self.model.get_parameters_for_histogram_tensorboard_logging()) cumulative_batch_group_size = 0 done_early = False for batch_group in batch_group_generator_tqdm: if self._distributed: # Check whether the other workers have stopped already (due to differing amounts of # data in each). If so, we can't proceed because we would hang when we hit the # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor # here because NCCL process groups apparently don't support BoolTensor. done = torch.tensor(0, device=self.cuda_device) torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) if done.item() > 0: done_early = True logger.warning( f"Worker {torch.distributed.get_rank()} finishing training early! " "This implies that there is an imbalance in your training " "data across the workers and that some amount of it will be " "ignored. A small amount of this is fine, but a major imbalance " "should be avoided. Note: This warning will appear unless your " "data is perfectly balanced.") break batches_this_epoch += 1 self._batch_num_total += 1 batch_num_total = self._batch_num_total self.optimizer.zero_grad() for batch in batch_group: loss = self.batch_loss(batch, for_training=True) if torch.isnan(loss): raise ValueError("nan loss encountered") loss = loss / len(batch_group) if self._opt_level is not None: with amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() train_loss += loss.item() batch_grad_norm = self.rescale_gradients() # This does nothing if batch_num_total is None or you are using a # scheduler which doesn't update per batch. 
if self._learning_rate_scheduler: self._learning_rate_scheduler.step_batch(batch_num_total) if self._momentum_scheduler: self._momentum_scheduler.step_batch(batch_num_total) if self._tensorboard.should_log_histograms_this_batch( ) and self._master: # get the magnitude of parameter updates for logging # We need a copy of current parameters to compute magnitude of updates, # and copy them to CPU so large models won't go OOM on the GPU. param_updates = { name: param.detach().cpu().clone() for name, param in self.model.named_parameters() } self.optimizer.step() for name, param in self.model.named_parameters(): param_updates[name].sub_(param.detach().cpu()) update_norm = torch.norm(param_updates[name].view(-1)) param_norm = torch.norm(param.view(-1)).cpu() self._tensorboard.add_train_scalar( "gradient_update/" + name, update_norm / (param_norm + nn_util.tiny_value_of_dtype(param_norm.dtype)), ) else: self.optimizer.step() # Update moving averages if self._moving_average is not None: self._moving_average.apply(batch_num_total) # Update the description with the latest metrics metrics = training_util.get_metrics( self.model, train_loss, batches_this_epoch, world_size=self._world_size, cuda_device=[self.cuda_device], ) # Updating tqdm only for the master as the trainers wouldn't have one if self._master: description = training_util.description_from_metrics(metrics) batch_group_generator_tqdm.set_description(description, refresh=False) # Log parameter values to Tensorboard (only from the master) if self._tensorboard.should_log_this_batch() and self._master: self._tensorboard.log_parameter_and_gradient_statistics( self.model, batch_grad_norm) self._tensorboard.log_learning_rates(self.model, self.optimizer) self._tensorboard.add_train_scalar("loss/loss_train", metrics["loss"]) self._tensorboard.log_metrics( {"epoch_metrics/" + k: v for k, v in metrics.items()}) if self._tensorboard.should_log_histograms_this_batch( ) and self._master: self._tensorboard.log_histograms(self.model, histogram_parameters) if self._log_batch_size_period: batch_group_size = sum( training_util.get_batch_size(batch) for batch in batch_group) cumulative_batch_group_size += batch_group_size if (batches_this_epoch - 1) % self._log_batch_size_period == 0: average = cumulative_batch_group_size / batches_this_epoch logger.info( f"current batch size: {batch_group_size} mean batch size: {average}" ) self._tensorboard.add_train_scalar("current_batch_size", batch_group_size) self._tensorboard.add_train_scalar("mean_batch_size", average) # Save model if needed. if (self._model_save_interval is not None and (time.time() - last_save_time > self._model_save_interval) and self._master): last_save_time = time.time() self._save_checkpoint("{0}.{1}".format( epoch, training_util.time_to_str(int(last_save_time)))) if self._distributed and not done_early: logger.warning( f"Worker {torch.distributed.get_rank()} completed its entire epoch (training)." ) # Indicate that we're done so that any workers that have remaining data stop the epoch early. done = torch.tensor(1, device=self.cuda_device) torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) assert done.item() # Let all workers finish their epoch before computing # the final statistics for the epoch. 
if self._distributed: dist.barrier() metrics = training_util.get_metrics( self.model, train_loss, batches_this_epoch, reset=True, world_size=self._world_size, cuda_device=[self.cuda_device], ) metrics["cpu_memory_MB"] = peak_cpu_usage for (gpu_num, memory) in gpu_usage: metrics["gpu_" + str(gpu_num) + "_memory_MB"] = memory return metrics def _validation_loss(self) -> Tuple[float, int]: """ Computes the validation loss. Returns it and the number of batches. """ logger.info("Validating") self._pytorch_model.eval() # Replace parameter values with the shadow values from the moving averages. if self._moving_average is not None: self._moving_average.assign_average_value() if self._validation_data_loader is not None: validation_data_loader = self._validation_data_loader else: raise ConfigurationError( "Validation results cannot be calculated without a validation_data_loader" ) val_generator_tqdm = Tqdm.tqdm(validation_data_loader) batches_this_epoch = 0 val_loss = 0 done_early = False for batch in val_generator_tqdm: if self._distributed: # Check whether the other workers have stopped already (due to differing amounts of # data in each). If so, we can't proceed because we would hang when we hit the # barrier implicit in Model.forward. We use a IntTensor instead a BoolTensor # here because NCCL process groups apparently don't support BoolTensor. done = torch.tensor(0, device=self.cuda_device) torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) if done.item() > 0: done_early = True logger.warning( f"Worker {torch.distributed.get_rank()} finishing validation early! " "This implies that there is an imbalance in your validation " "data across the workers and that some amount of it will be " "ignored. A small amount of this is fine, but a major imbalance " "should be avoided. Note: This warning will appear unless your " "data is perfectly balanced.") break loss = self.batch_loss(batch, for_training=False) if loss is not None: # You shouldn't necessarily have to compute a loss for validation, so we allow for # `loss` to be None. We need to be careful, though - `batches_this_epoch` is # currently only used as the divisor for the loss function, so we can safely only # count those batches for which we actually have a loss. If this variable ever # gets used for something else, we might need to change things around a bit. batches_this_epoch += 1 val_loss += loss.detach().cpu().numpy() # Update the description with the latest metrics val_metrics = training_util.get_metrics( self.model, val_loss, batches_this_epoch, world_size=self._world_size, cuda_device=[self.cuda_device], ) description = training_util.description_from_metrics(val_metrics) val_generator_tqdm.set_description(description, refresh=False) if self._distributed and not done_early: logger.warning( f"Worker {torch.distributed.get_rank()} completed its entire epoch (validation)." ) # Indicate that we're done so that any workers that have remaining data stop validation early. done = torch.tensor(1, device=self.cuda_device) torch.distributed.all_reduce(done, torch.distributed.ReduceOp.SUM) assert done.item() # Now restore the original parameter values. if self._moving_average is not None: self._moving_average.restore() return val_loss, batches_this_epoch def train(self) -> Dict[str, Any]: """ Trains the supplied model with the supplied parameters. """ try: epoch_counter = self._restore_checkpoint() except RuntimeError: traceback.print_exc() raise ConfigurationError( "Could not recover training from the checkpoint. 
Did you mean to output to " "a different serialization directory or delete the existing serialization " "directory?") training_util.enable_gradient_clipping(self.model, self._grad_clipping) logger.info("Beginning training.") val_metrics: Dict[str, float] = {} this_epoch_val_metric: float = None metrics: Dict[str, Any] = {} epochs_trained = 0 training_start_time = time.time() metrics["best_epoch"] = self._metric_tracker.best_epoch for key, value in self._metric_tracker.best_epoch_metrics.items(): metrics["best_validation_" + key] = value for epoch in range(epoch_counter, self._num_epochs): epoch_start_time = time.time() train_metrics = self._train_epoch(epoch) # get peak of memory usage if "cpu_memory_MB" in train_metrics: metrics["peak_cpu_memory_MB"] = max( metrics.get("peak_cpu_memory_MB", 0), train_metrics["cpu_memory_MB"]) for key, value in train_metrics.items(): if key.startswith("gpu_"): metrics["peak_" + key] = max(metrics.get("peak_" + key, 0), value) if self._validation_data_loader is not None: with torch.no_grad(): # We have a validation set, so compute all the metrics on it. val_loss, num_batches = self._validation_loss() # It is safe again to wait till the validation is done. This is # important to get the metrics right. if self._distributed: dist.barrier() val_metrics = training_util.get_metrics( self.model, val_loss, num_batches, reset=True, world_size=self._world_size, cuda_device=[self.cuda_device], ) # Check validation metric for early stopping this_epoch_val_metric = val_metrics[ self._validation_metric] self._metric_tracker.add_metric(this_epoch_val_metric) if self._metric_tracker.should_stop_early(): logger.info("Ran out of patience. Stopping training.") break if self._master: self._tensorboard.log_metrics( train_metrics, val_metrics=val_metrics, log_to_console=True, epoch=epoch + 1) # +1 because tensorboard doesn't like 0 # Create overall metrics dict training_elapsed_time = time.time() - training_start_time metrics["training_duration"] = str( datetime.timedelta(seconds=training_elapsed_time)) metrics["training_start_epoch"] = epoch_counter metrics["training_epochs"] = epochs_trained metrics["epoch"] = epoch for key, value in train_metrics.items(): metrics["training_" + key] = value for key, value in val_metrics.items(): metrics["validation_" + key] = value if self._metric_tracker.is_best_so_far(): # Update all the best_ metrics. # (Otherwise they just stay the same as they were.) metrics["best_epoch"] = epoch for key, value in val_metrics.items(): metrics["best_validation_" + key] = value self._metric_tracker.best_epoch_metrics = val_metrics if self._serialization_dir and self._master: common_util.dump_metrics( os.path.join(self._serialization_dir, f"metrics_epoch_{epoch}.json"), metrics) # The Scheduler API is agnostic to whether your schedule requires a validation metric - # if it doesn't, the validation metric passed here is ignored. 
if self._learning_rate_scheduler: self._learning_rate_scheduler.step(this_epoch_val_metric) if self._momentum_scheduler: self._momentum_scheduler.step(this_epoch_val_metric) if self._master: self._save_checkpoint(epoch) # Wait for the master to finish saving the checkpoint if self._distributed: dist.barrier() epoch_elapsed_time = time.time() - epoch_start_time logger.info("Epoch duration: %s", datetime.timedelta(seconds=epoch_elapsed_time)) if epoch < self._num_epochs - 1: training_elapsed_time = time.time() - training_start_time estimated_time_remaining = training_elapsed_time * ( (self._num_epochs - epoch_counter) / float(epoch - epoch_counter + 1) - 1) formatted_time = str( datetime.timedelta(seconds=int(estimated_time_remaining))) logger.info("Estimated training time remaining: %s", formatted_time) epochs_trained += 1 # make sure pending events are flushed to disk and files are closed properly self._tensorboard.close() # Load the best model state before returning best_model_state = self._checkpointer.best_model_state() if best_model_state: self.model.load_state_dict(best_model_state) return metrics def _save_checkpoint(self, epoch: Union[int, str]) -> None: """ Saves a checkpoint of the model to self._serialization_dir. Is a no-op if self._serialization_dir is None. # Parameters epoch : Union[int, str], required. The epoch of training. If the checkpoint is saved in the middle of an epoch, the parameter is a string with the epoch and timestamp. """ # If moving averages are used for parameters, we save # the moving average values into checkpoint, instead of the current values. if self._moving_average is not None: self._moving_average.assign_average_value() # These are the training states we need to persist. training_states = { "metric_tracker": self._metric_tracker.state_dict(), "optimizer": self.optimizer.state_dict(), "batch_num_total": self._batch_num_total, } # If we have a learning rate or momentum scheduler, we should persist them too. if self._learning_rate_scheduler is not None: training_states[ "learning_rate_scheduler"] = self._learning_rate_scheduler.state_dict( ) if self._momentum_scheduler is not None: training_states[ "momentum_scheduler"] = self._momentum_scheduler.state_dict() self._checkpointer.save_checkpoint( model_state=self.model.state_dict(), epoch=epoch, training_states=training_states, is_best_so_far=self._metric_tracker.is_best_so_far(), ) # Restore the original values for parameters so that training will not be affected. if self._moving_average is not None: self._moving_average.restore() def _restore_checkpoint(self) -> int: """ Restores the model and training state from the last saved checkpoint. This includes an epoch count and optimizer state, which is serialized separately from model parameters. This function should only be used to continue training - if you wish to load a model for inference/load parts of a model into a new computation graph, you should use the native Pytorch functions: ` model.load_state_dict(torch.load("/path/to/model/weights.th"))` If `self._serialization_dir` does not exist or does not contain any checkpointed weights, this function will do nothing and return 0. # Returns epoch: int The epoch at which to resume training, which should be one after the epoch in the saved training state. 
""" model_state, training_state = self._checkpointer.restore_checkpoint() if not training_state: # No checkpoint to restore, start at 0 return 0 self.model.load_state_dict(model_state) self.optimizer.load_state_dict(training_state["optimizer"]) if (self._learning_rate_scheduler is not None and "learning_rate_scheduler" in training_state): self._learning_rate_scheduler.load_state_dict( training_state["learning_rate_scheduler"]) if self._momentum_scheduler is not None and "momentum_scheduler" in training_state: self._momentum_scheduler.load_state_dict( training_state["momentum_scheduler"]) training_util.move_optimizer_to_cuda(self.optimizer) # Currently the `training_state` contains a serialized `MetricTracker`. if "metric_tracker" in training_state: self._metric_tracker.load_state_dict( training_state["metric_tracker"]) # It used to be the case that we tracked `val_metric_per_epoch`. elif "val_metric_per_epoch" in training_state: self._metric_tracker.clear() self._metric_tracker.add_metrics( training_state["val_metric_per_epoch"]) # And before that we didn't track anything. else: self._metric_tracker.clear() if isinstance(training_state["epoch"], int): epoch_to_return = training_state["epoch"] + 1 else: epoch_to_return = int(training_state["epoch"].split(".")[0]) + 1 # For older checkpoints with batch_num_total missing, default to old behavior where # it is unchanged. batch_num_total = training_state.get("batch_num_total") if batch_num_total is not None: self._batch_num_total = batch_num_total return epoch_to_return @classmethod def from_partial_objects( cls, model: Model, serialization_dir: str, data_loader: DataLoader, validation_data_loader: DataLoader = None, local_rank: int = 0, patience: int = None, validation_metric: str = "-loss", num_epochs: int = 20, cuda_device: int = -1, grad_norm: float = None, grad_clipping: float = None, model_save_interval: float = None, log_batch_size_period: int = None, distributed: bool = None, world_size: int = 1, num_gradient_accumulation_steps: int = 1, opt_level: Optional[str] = None, no_grad: List[str] = None, optimizer: Lazy[Optimizer] = None, learning_rate_scheduler: Lazy[LearningRateScheduler] = None, momentum_scheduler: Lazy[MomentumScheduler] = None, tensorboard_writer: Lazy[TensorboardWriter] = None, moving_average: Lazy[MovingAverage] = None, checkpointer: Lazy[Checkpointer] = None, ) -> "Trainer": """ This method exists so that we can have a documented method to construct this class using `FromParams`. If you are not using `FromParams` or config files, you can safely ignore this method. The reason we can't just use `__init__` with `FromParams` here is because there are sequential dependencies to this class's arguments. Anything that has a `Lazy[]` type annotation needs something from one of the non-`Lazy` arguments. The `Optimizer` needs to have the parameters from the `Model` before it's constructed, and the `Schedulers` need to have the `Optimizer`. Because of this, the typical way we construct things `FromParams` doesn't work, so we use `Lazy` to allow for constructing the objects sequentially. If you're not using `FromParams`, you can just construct these arguments in the right order yourself in your code and call the constructor directly. """ check_for_gpu(cuda_device) if cuda_device >= 0: # Moving model to GPU here so that the optimizer state gets constructed on # the right device. 
model = model.cuda(cuda_device) if no_grad: for name, parameter in model.named_parameters(): if any(re.search(regex, name) for regex in no_grad): parameter.requires_grad_(False) common_util.log_frozen_and_tunable_parameter_names(model) parameters = [[n, p] for n, p in model.named_parameters() if p.requires_grad] optimizer_ = optimizer.construct(model_parameters=parameters) if not optimizer_: optimizer_ = Optimizer.default(parameters) try: batches_per_epoch = len(data_loader) except TypeError: # If the dataset is lazy, it won't have a length. batches_per_epoch = None moving_average_ = moving_average.construct(parameters=parameters) learning_rate_scheduler_ = learning_rate_scheduler.construct( optimizer=optimizer_, num_epochs=num_epochs, num_steps_per_epoch=batches_per_epoch) momentum_scheduler_ = momentum_scheduler.construct( optimizer=optimizer_) checkpointer_ = checkpointer.construct() or Checkpointer( serialization_dir) tensorboard_writer_ = tensorboard_writer.construct( ) or TensorboardWriter(serialization_dir) return cls( model, optimizer_, data_loader, patience=patience, validation_metric=validation_metric, validation_data_loader=validation_data_loader, num_epochs=num_epochs, serialization_dir=serialization_dir, cuda_device=cuda_device, grad_norm=grad_norm, grad_clipping=grad_clipping, learning_rate_scheduler=learning_rate_scheduler_, momentum_scheduler=momentum_scheduler_, tensorboard_writer=tensorboard_writer_, checkpointer=checkpointer_, model_save_interval=model_save_interval, log_batch_size_period=log_batch_size_period, moving_average=moving_average_, distributed=distributed, local_rank=local_rank, world_size=world_size, num_gradient_accumulation_steps=num_gradient_accumulation_steps, opt_level=opt_level, )
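# A minimal sketch (plain PyTorch, not AllenNLP's `Lazy` API) of the sequential
# dependency that `from_partial_objects` above resolves: the optimizer needs the
# model's parameters before it can be built, and the scheduler needs the
# optimizer, so none of them can be constructed from config in isolation.
import torch

def build_training_objects_sketch():
    model = torch.nn.Linear(10, 2)                                        # stands in for `Model`
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)               # needs the model first
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=10)  # needs the optimizer
    return model, optimizer, scheduler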
def initialize(self, training=True, force_load_plans=False):
    """
    this is a copy of nnUNetTrainerV2's initialize. We only add the regions to the data augmentation
    :param training:
    :param force_load_plans:
    :return:
    """
    if not self.was_initialized:
        maybe_mkdir_p(self.output_folder)

        if force_load_plans or (self.plans is None):
            self.load_plans_file()

        self.process_plans(self.plans)

        self.setup_DA_params()

        self.folder_with_preprocessed_data = join(self.dataset_directory,
                                                  self.plans['data_identifier'] + "_stage%d" % self.stage)
        if training:
            self.dl_tr, self.dl_val = self.get_basic_generators()
            if self.unpack_data:
                if self.local_rank == 0:
                    print("unpacking dataset")
                    unpack_dataset(self.folder_with_preprocessed_data)
                    print("done")
                else:
                    # we need to wait until worker 0 has finished unpacking
                    npz_files = subfiles(self.folder_with_preprocessed_data, suffix=".npz", join=False)
                    case_ids = [i[:-4] for i in npz_files]
                    all_present = all(
                        [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
                    while not all_present:
                        print("worker", self.local_rank, "is waiting for unpacking")
                        sleep(3)
                        all_present = all(
                            [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
                    # there is some slight chance that an error arises because a dataloader reads a file
                    # that is still being written by worker 0. We ignore this for now and address it only
                    # if it becomes relevant
                    # (this can occur because while worker 0 writes, the file is technically present, so
                    # the other workers will proceed and eventually try to read it)
            else:
                print(
                    "INFO: Not unpacking data! Training may be slow due to that. Pray you are not using 2d or you "
                    "will wait all winter for your model to finish!")

            # setting weights for deep supervision losses
            net_numpool = len(self.net_num_pool_op_kernel_sizes)

            # we give each output a weight which decreases exponentially (division by 2) as the resolution decreases
            # this gives higher resolution outputs more weight in the loss
            weights = np.array([1 / (2 ** i) for i in range(net_numpool)])

            # we don't use the lowest-resolution output. Normalize weights so that they sum to 1
            mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
            weights[~mask] = 0
            weights = weights / weights.sum()
            self.ds_loss_weights = weights

            # np.random.random_integers is deprecated; randint draws from the same range
            # (its upper bound is exclusive, hence 100000)
            seeds_train = np.random.randint(0, 100000, self.data_aug_params.get('num_threads'))
            seeds_val = np.random.randint(0, 100000, max(self.data_aug_params.get('num_threads') // 2, 1))
            print("seeds train", seeds_train)
            print("seeds_val", seeds_val)
            self.tr_gen, self.val_gen = get_moreDA_augmentation(self.dl_tr, self.dl_val,
                                                                self.data_aug_params['patch_size_for_spatialtransform'],
                                                                self.data_aug_params,
                                                                deep_supervision_scales=self.deep_supervision_scales,
                                                                seeds_train=seeds_train,
                                                                seeds_val=seeds_val,
                                                                pin_memory=self.pin_memory,
                                                                regions=self.regions)
            self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                   also_print_to_console=False)
            self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                   also_print_to_console=False)
        else:
            pass

        self.initialize_network()
        self.initialize_optimizer_and_scheduler()
        self._maybe_init_amp()
        self.network = DDP(self.network, self.local_rank)
    else:
        self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
    self.was_initialized = True
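# Worked example of the deep-supervision weighting computed above, assuming
# net_numpool = 5: weights halve at each pooling level, the lowest-resolution
# output is masked out, and the remainder are renormalized to sum to 1.
import numpy as np

net_numpool = 5
weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
mask = np.array([i < net_numpool - 1 for i in range(net_numpool)])
weights[~mask] = 0
weights = weights / weights.sum()
# -> approximately [0.533, 0.267, 0.133, 0.067, 0.0]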
def run_training(rank, args, hp, port=None): if args.n_gpus > 1: init_distributed(rank, args.n_gpus, port) torch.cuda.set_device(f'cuda:{rank}') ## NOTE: variable model = TransformerWav2vec2( hp, pretrain_model='facebook/wav2vec2-large-lv60', freeze_feature_extractor=hp.freeze_feature_extractor) ## TODO: change init_weight (maybe initialize all networks) #model.apply(init_weight) model.train() if rank == 0: print(model) model = model.to(rank) if args.n_gpus > 1: model = DDP(torch.nn.SyncBatchNorm.convert_sync_batchnorm(model), device_ids=[rank]) max_lr = hp.init_lr if hp.optimizer_type == 'Noam': ## NOTE: scheduling? ## NOTE: learning rate? optimizer = torch.optim.Adam(model.parameters(), lr=max_lr, betas=(0.9, 0.98), eps=1e-9) else: optimizer = torch.optim.Adam(model.parameters(), lr=max_lr) assert (hp.batch_size is None) != (hp.max_seqlen is None) if args.n_gpus > 1: dist.barrier() # configure map_location properly map_location = {'cuda:%d' % 0: 'cuda:%d' % rank} if hp.loaded_epoch is not None: start_epoch = hp.loaded_epoch load_dir = hp.loaded_dir print('epoch {} loaded'.format(hp.loaded_epoch)) loaded_dict = load_model("{}".format( os.path.join(load_dir, 'network.epoch{}'.format(hp.loaded_epoch))), map_location=map_location) model.load_state_dict(loaded_dict) if hp.is_flat_start: step = 1 start_epoch = 0 print('flat_start') else: loaded_dict = torch.load("{}".format( os.path.join( load_dir, 'network.optimizer.epoch{}'.format(hp.loaded_epoch))), map_location=map_location) optimizer.load_state_dict(loaded_dict) step = loaded_dict['state'][0]['step'] #lr = get_learning_rate(step//hp.accum_grad+1, hp) lr = get_learning_rate_tristage(step // hp.accum_grad + 1) for param_group in optimizer.param_groups: param_group['lr'] = lr del loaded_dict torch.cuda.empty_cache() else: start_epoch = 0 step = 1 pytorch_total_params = sum(p.numel() for p in model.parameters()) print('params = {0:.2f}M'.format(pytorch_total_params / 1000 / 1000)) train_epoch(model, optimizer, args, hp, step=step, start_epoch=start_epoch, rank=rank)
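# A minimal sketch of the `map_location` remapping used above
# (`{'cuda:%d' % 0: 'cuda:%d' % rank}`): a checkpoint saved from rank 0
# (cuda:0) is loaded directly onto this process's GPU without first being
# materialized on device 0. The path and `rank` here are illustrative.
import torch

def load_on_rank_sketch(path: str, rank: int):
    map_location = {'cuda:0': f'cuda:{rank}'}
    return torch.load(path, map_location=map_location)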
def train(): parser = ArgumentParser() parser.add_argument("--basedir", type=str) parser.add_argument("--train_file", type=str, required=True, help='File path to use for train file') parser.add_argument("--valid_file", type=str, required=True, help='File path to use for valid file') parser.add_argument("--dataset_key", default="paired", help="dataset key for basedir") parser.add_argument( "--embed_type", type=str, default='default', choices=["default", "positional", "learned-positional"], help="register label of the embeddings") parser.add_argument("--d_model", type=int, default=512, help="Model dimension (and embedding dsz)") parser.add_argument("--d_ff", type=int, default=2048, help="FFN dimension") parser.add_argument( "--d_k", type=int, default=None, help="Dimension per head. Use if num_heads=1 to reduce dims") parser.add_argument("--num_heads", type=int, default=8, help="Number of heads") parser.add_argument("--num_layers", type=int, default=8, help="Number of layers") parser.add_argument("--windowed_ra", type=str2bool, default=False, help="whether prevent attention beyond rpr_k") parser.add_argument("--num_train_workers", type=int, default=4, help="Number train workers") parser.add_argument("--nctx", type=int, default=256, help="Max input length") parser.add_argument("--tgt_nctx", type=int, help="Max output length, default to args.nctx") parser.add_argument("--file_type", default='json', help="Suffix for data") parser.add_argument("--record_keys", default=['x', 'y'], nargs='+') parser.add_argument("--batch_size", type=int, default=256, help="Batch Size") parser.add_argument("--subword_model_file", type=str, help="The BPE model file", required=True) parser.add_argument("--subword_vocab_file", type=str, help="The BPE subword vocab", required=True) parser.add_argument("--dropout", type=float, default=0.1, help="Dropout") parser.add_argument("--lr_scheduler", type=str, default='cosine', help="The type of learning rate decay scheduler") parser.add_argument("--lr_decay_steps", type=int, help="decay steps of lr scheduler") parser.add_argument("--lr_decay_rate", type=float, help="decay rate of lr scheduler") parser.add_argument("--lr_alpha", type=float, help="parameter alpha for cosine decay scheduler") parser.add_argument("--optim", default="adamw", type=str, help="Optimizer to use (defaults to adamw)") parser.add_argument("--lr", type=float, default=4.0e-4, help="Learning rate") parser.add_argument("--clip", type=float, default=1.0, help="Clipping gradient norm") parser.add_argument("--weight_decay", type=float, default=1.0e-2, help="Weight decay") parser.add_argument("--epochs", type=int, default=32, help="Num training epochs") parser.add_argument( "--restart_from", type=str, help="Option allows you to restart from a previous checkpoint") parser.add_argument( "--restart_tt", type=str, help="Optional param for legacy checkpoints (step|epoch)") parser.add_argument("--warmup_steps", type=int, default=10000, help="Num warmup steps") parser.add_argument("--saves_per_epoch", type=int, default=10, help="The number of checkpoints to save per epoch") parser.add_argument("--reduction_d_k", type=int, default=64, help="Dimensions of Key and Query in the single headed" "reduction layers") parser.add_argument( "--reduction_type", type=str, default="2ha", help="If using a dual encoder, specifies the reduction type") parser.add_argument( "--unfreeze_after_step", default=0, type=int, help= "Unfreeze encoders after step, ignored if we dont have a checkpoint") parser.add_argument( "--stacking_layers", 
type=int, nargs='+', default=[], help="Hidden sizes of the dense stack (ff2 from the convert paper)") parser.add_argument("--layer_drop", type=float, default=0.0, help="LayerDrop to apply") parser.add_argument("--ff_pdrop", type=float, default=0.1, help="Dropout in the dense stack") parser.add_argument("--reader_type", type=str, default='preprocessed', choices=['ntp', 'nsp', 'preprocessed', 'tfrecord']) parser.add_argument( "--model_type", default="dual-encoder", choices=["dual-encoder", "encoder-decoder", "transformer-bow"]) parser.add_argument("--src_begin_tok", type=str, nargs='+', default=[]) parser.add_argument("--src_end_tok", type=str, nargs='+', default=['<EOS>']) parser.add_argument("--tgt_begin_tok", type=str, nargs='+', default=['<GO>']) parser.add_argument("--tgt_end_tok", type=str, nargs='+', default=['<EOS>']) parser.add_argument('--lower', type=baseline.str2bool, default=False) parser.add_argument( "--loss", type=str, default='symmetric', choices=['triplet', 'all', 'all_mean', 'contrastive', 'symmetric']) parser.add_argument( "--learn_temp", type=str2bool, default=True, help= "If 'constrastive' or 'symmetric' loss, should we learn the temperature scaling" ) parser.add_argument( "--init_temp", type=float, help="Initialize the temperature for 'contrastive' or 'symmetric' loss" ) parser.add_argument( '--rpr_k', help= 'Relative attention positional sizes pass 0 if you dont want relative attention', type=int, default=[8], nargs='+') parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)") parser.add_argument("--distributed", type=str2bool, default=False, help="Are we doing distributed training?") parser.add_argument( "--local_rank", type=int, default=-1, help= "Local rank for distributed training (-1 means use the environment variables to find)" ) parser.add_argument("--save_npz", type=str2bool, default=False, help="Whether save npz checkpoint") args = parser.parse_args() if args.basedir is None: args.basedir = '{}-{}-paired-{}-bpe-{}'.format(args.model_type, args.reader_type, args.dataset_key, os.getpid()) logging.basicConfig( level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) num_gpus = get_num_gpus_multiworker() args.distributed = args.distributed or num_gpus > 1 logger.info(f"Using {num_gpus} GPUs in this job.") if args.distributed: args.device, updated_local_rank = init_distributed(args.local_rank) args.local_rank = updated_local_rank if not args.tgt_nctx: args.tgt_nctx = args.nctx reader = MultiFileDatasetReader(args.nctx, args.tgt_nctx, args.src_begin_tok, args.src_end_tok, args.tgt_begin_tok, args.tgt_end_tok, args.subword_model_file, args.subword_vocab_file, args.file_type, reader_type=args.reader_type, record_keys=args.record_keys, lower=args.lower) vocab = reader.build_vocab() # If we are not using chars, then use 'x' for both input and output preproc_data = baseline.embeddings.load_embeddings( 'x', dsz=args.d_model, known_vocab=vocab['x'], preserve_vocab_indices=True, embed_type=args.embed_type) vocabs = preproc_data['vocab'] os.makedirs(args.basedir, exist_ok=True) # We want to make sure to save our input vocab into the basedir for reuse later write_json(vocabs, os.path.join(args.basedir, 'vocabs.json')) embeddings = preproc_data['embeddings'] logger.info("Loaded embeddings") train_set = reader.load(args.train_file, vocabs) valid_set = reader.load(args.valid_file, vocabs, distribute=False, shuffle=False) train_loader = DataLoader(train_set, batch_size=args.batch_size, 
num_workers=args.num_train_workers) valid_loader = DataLoader(valid_set, batch_size=args.batch_size) logger.info("Loaded datasets") logger.info("Using embedding type [%s]", args.embed_type) if len(args.rpr_k) == 0 or args.rpr_k[0] < 1: rpr_k = None elif len(args.rpr_k) == 1: rpr_k = args.rpr_k[0] else: rpr_k = args.rpr_k model = create_model(embeddings, d_model=args.d_model, d_ff=args.d_ff, dropout=args.dropout, num_heads=args.num_heads, num_layers=args.num_layers, model_type=args.model_type, rpr_k=rpr_k, d_k=args.d_k, reduction_d_k=args.reduction_d_k, stacking_layers=args.stacking_layers, ff_pdrop=args.ff_pdrop, windowed_ra=args.windowed_ra, reduction_type=args.reduction_type, layer_drop=args.layer_drop, logger=logger) model.to(args.device) if args.model_type == 'encoder-decoder': run_step = run_step_s2s else: run_step = run_step_dual logger.info( f"Creating {args.loss}, init temperature: {args.init_temp}, learnable: {args.learn_temp}" ) loss_function = model.create_loss(loss_type=args.loss, init_temp=args.init_temp, learn_temp=args.learn_temp) loss_function.to(args.device) logger.info("Created model and loss") steps_per_epoch = len(train_loader) // num_gpus valid_steps = len(valid_loader) update_on = steps_per_epoch // args.saves_per_epoch report_on = max(10, update_on) // 10 logger.info( f"Steps per epoch per GPU: {steps_per_epoch}. Saving checkpoint every {update_on} steps." ) lr_decay = get_lr_decay(args.lr_scheduler, args.lr, steps_per_epoch, args.epochs, logger, decay_steps=args.lr_decay_steps, decay_rate=args.lr_decay_rate, alpha=args.lr_alpha) linear_warmup = WarmupLinearSchedulerPyTorch(args.warmup_steps, lr=args.lr) lr_sched = CompositeLRScheduler(linear_warmup, lr_decay, lr=args.lr) global_step = 0 start_epoch = 0 if args.restart_from: if args.unfreeze_after_step > 0 and args.model_type == "dual-encoder": logger.info(f"Encoders will be frozen until step %d", args.unfreeze_after_step) global_step, start_epoch = reload_from_checkpoint( args.model_type, args.restart_from, args.restart_tt, model, steps_per_epoch) logger.info( "Restarting from a previous checkpoint %s.\n\tStarting at global_step=%d, epoch=%d", args.restart_from, global_step, start_epoch + 1) target = model if args.model_type == 'encoder-decoder' else loss_function optimizer = OptimizerManager(target, global_step, optim=args.optim, lr=args.lr, lr_function=lr_sched, weight_decay=args.weight_decay) logger.info("Model has {:,} parameters".format( sum(p.numel() for p in target.parameters() if p.requires_grad))) # Prepare model for distributed training if needed if args.distributed: model = DistributedDataParallel(model, device_ids=[args.device], output_device=args.device) logger.info("Model located on %d", args.local_rank) model_base = os.path.join(args.basedir, 'checkpoint') steps = global_step timer = Timer() for epoch in range(start_epoch, args.epochs): avg_loss = Average('average_train_loss') metrics = {} optimizer.zero_grad() timer.start() model.train() train_itr = iter(train_loader) for i in range(steps_per_epoch): batch = next(train_itr) if steps > args.unfreeze_after_step and hasattr( model, 'freeze') and model.freeze: logging.info("Unfreezing encoders at step %d", steps) model.freeze = False steps += 1 x, y = batch loss = run_step(x, y, model, loss_function, args.device, args.distributed) loss.backward() avg_loss.update(loss.item()) torch.nn.utils.clip_grad_norm_(model.parameters(), args.clip) optimizer.step() optimizer.zero_grad() if (i + 1) % report_on == 0: logging.info(avg_loss) if (i + 1) % update_on == 0 
and args.local_rank < 1: elapsed = timer.elapsed(True) logging.info('elapsed time this epoch %d min', elapsed) logging.info('elapsed step time %f steps/min', i / elapsed) logging.info('LR: %f', optimizer.current_lr) save_checkpoint(model, model_base, steps, tick_type='step', save_npz=args.save_npz) # How much time elapsed in minutes elapsed = timer.elapsed(True) train_avg_loss = avg_loss.avg # This is the average training token-level loss across all machines # This is the token-level training perplexity metrics['train_elapsed_min'] = elapsed metrics['average_train_loss'] = train_avg_loss if args.local_rank < 1: avg_valid_loss = Average('average_valid_loss') timer.start() model.eval() valid_itr = iter(valid_loader) for j in range(valid_steps): with torch.no_grad(): batch = next(valid_itr) x, y = batch loss = run_step(x, y, model, loss_function, args.device, args.distributed) avg_valid_loss.update(loss.item()) valid_avg_loss = avg_valid_loss.avg elapsed = timer.elapsed(True) metrics['valid_elapsed_min'] = elapsed metrics['average_valid_loss'] = valid_avg_loss logger.info(metrics) save_checkpoint(model, model_base, epoch, tick_type='epoch', save_npz=args.save_npz)
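# A plain-PyTorch sketch of the schedule shape that
# `CompositeLRScheduler(linear_warmup, lr_decay, ...)` expresses above: linear
# warmup for `warmup_steps`, then (here) cosine decay to zero. This is an
# illustration of the composition, not mead-baseline's implementation.
import math
import torch

def make_warmup_cosine(optimizer, warmup_steps: int, total_steps: int):
    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)                  # linear warmup
        progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
        return 0.5 * (1.0 + math.cos(math.pi * progress))       # cosine decay
    return torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)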
def train(): parser = ArgumentParser() parser.add_argument( "--dataset_path", type=str, default="", help="Path or url of the dataset. If empty download from S3.") parser.add_argument("--dataset_cache", type=str, default='./dataset_cache', help="Path or url of the dataset cache") parser.add_argument("--model_checkpoint", type=str, default="openai-gpt", help="Path, url or short name of the model") parser.add_argument("--num_candidates", type=int, default=2, help="Number of candidates for training") parser.add_argument("--max_history", type=int, default=2, help="Number of previous exchanges to keep in history") parser.add_argument("--train_batch_size", type=int, default=4, help="Batch size for training") parser.add_argument("--valid_batch_size", type=int, default=4, help="Batch size for validation") parser.add_argument("--gradient_accumulation_steps", type=int, default=8, help="Accumulate gradients on several steps") parser.add_argument("--lr", type=float, default=6.25e-5, help="Learning rate") parser.add_argument("--lm_coef", type=float, default=1.0, help="LM loss coefficient") parser.add_argument("--mc_coef", type=float, default=1.0, help="Multiple-choice loss coefficient") parser.add_argument("--max_norm", type=float, default=1.0, help="Clipping gradient norm") parser.add_argument("--n_epochs", type=int, default=3, help="Number of training epochs") parser.add_argument("--personality_permutations", type=int, default=1, help="Number of permutations of personality sentences") parser.add_argument( "--eval_before_start", action='store_true', help="If true start with a first evaluation before training") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)") parser.add_argument( "--fp16", type=str, default="", help= "Set to O0, O1, O2 or O3 for fp16 training (see apex documentation)") parser.add_argument( "--local_rank", type=int, default=-1, help="Local rank for distributed training (-1: not distributed)") parser.add_argument( "--data_faiss", type=str, default="data_persona_faiss_fase1_opcion4", help= "list of the personalities selected with faiss according to the strategy selected" ) args = parser.parse_args() # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. 
    # logger.info => log main process only, logger.warning => log all processes
    logging.basicConfig(
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Running process %d", args.local_rank
                   )  # This is a logger.warning: it will be printed by all distributed processes
    logger.info("Arguments: %s", pformat(args))

    # Initialize distributed training if needed
    args.distributed = (args.local_rank != -1)
    if args.distributed:
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    logger.info("Prepare tokenizer, pretrained model and optimizer.")
    tokenizer_class = GPT2Tokenizer if "gpt2" in args.model_checkpoint else OpenAIGPTTokenizer  # can't use AutoTokenizer because checkpoint could be a Path
    tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint)

    model_class = GPT2DoubleHeadsModel if "gpt2" in args.model_checkpoint else OpenAIGPTDoubleHeadsModel
    model = model_class.from_pretrained(args.model_checkpoint)
    model.to(args.device)
    # Add special tokens if they are not already added
    add_special_tokens_(model, tokenizer)
    optimizer = AdamW(model.parameters(), lr=args.lr, correct_bias=True)

    # Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last)
    if args.fp16:
        from apex import amp  # Apex is only required if we use fp16 training
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16)
    if args.distributed:
        model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank)

    logger.info("Prepare datasets")
    train_loader, val_loader, train_sampler, valid_sampler = get_data_loaders(args, tokenizer)

    # Training function and trainer
    def update(engine, batch):
        model.train()
        batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
        input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
        (lm_loss), (mc_loss), *_ = model(input_ids,
                                         token_type_ids=token_type_ids,
                                         mc_token_ids=mc_token_ids,
                                         mc_labels=mc_labels,
                                         lm_labels=lm_labels)
        loss = (lm_loss * args.lm_coef + mc_loss * args.mc_coef) / args.gradient_accumulation_steps
        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
            torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm)
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
        if engine.state.iteration % args.gradient_accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
        return loss.item()

    trainer = Engine(update)

    # Evaluation function and evaluator (evaluator output is the input of the metrics)
    def inference(engine, batch):
        model.eval()
        with torch.no_grad():
            batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels, token_type_ids = batch
            logger.info(tokenizer.decode(input_ids[0, -1, :].tolist()))
            # if we don't send labels to the model, it doesn't return losses
            lm_logits, mc_logits, *_ = model(
                input_ids,
                token_type_ids=token_type_ids,
                mc_token_ids=mc_token_ids,
            )
            lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, lm_logits.size(-1))
            lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
            return (lm_logits_flat_shifted, mc_logits), (lm_labels_flat_shifted, mc_labels)

    evaluator = Engine(inference)

    # Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch
trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if args.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) if args.eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader)) # Make sure distributed data samplers split the dataset nicely between the distributed processes if args.distributed: trainer.add_event_handler( Events.EPOCH_STARTED, lambda engine: train_sampler.set_epoch(engine.state.epoch)) evaluator.add_event_handler( Events.EPOCH_STARTED, lambda engine: valid_sampler.set_epoch(engine.state.epoch)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = { "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-100), output_transform=lambda x: (x[0][0], x[1][0])), "accuracy": Accuracy(output_transform=lambda x: (x[0][1], x[1][1])) } metrics.update({ "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args), "average_accuracy": MetricsLambda(average_distributed_scalar, metrics["accuracy"], args) }) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train if args.local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) log_dir = make_logdir(args.model_checkpoint) tb_logger = TensorboardLogger(log_dir) tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach( evaluator, log_handler=OutputHandler( tag="validation", metric_names=list(metrics.keys()), global_step_transform=global_step_from_engine(trainer)), event_name=Events.EPOCH_COMPLETED) checkpoint_handler = ModelCheckpoint(log_dir, 'checkpoint', save_interval=1, n_saved=3) trainer.add_event_handler( Events.EPOCH_COMPLETED, checkpoint_handler, {'mymodel': getattr(model, 'module', model) }) # "getattr" takes care of distributed encapsulation torch.save(args, log_dir + '/model_training_args.bin') getattr(model, 'module', model).config.to_json_file(os.path.join(log_dir, CONFIG_NAME)) tokenizer.save_pretrained(log_dir) # Run the training trainer.run(train_loader, max_epochs=args.n_epochs) # On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) if args.local_rank in [-1, 0] and args.n_epochs > 0: os.rename( os.path.join(log_dir, checkpoint_handler._saved[-1][1]), os.path.join(log_dir, WEIGHTS_NAME) ) # TODO: PR in ignite to have better access to saved file paths (cleaner) tb_logger.close()
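# `average_distributed_scalar` is used in the metrics above but not defined in
# this excerpt; a plausible reconstruction of that helper all-reduces the
# scalar across processes and divides by the world size, so every rank sees
# the global average:
import torch
import torch.distributed

def average_distributed_scalar(scalar, args):
    """Average a scalar over distributed processes; pass-through otherwise."""
    if args.local_rank == -1:
        return scalar
    scalar_t = torch.tensor(scalar, dtype=torch.float, device=args.device) / torch.distributed.get_world_size()
    torch.distributed.all_reduce(scalar_t, op=torch.distributed.ReduceOp.SUM)
    return scalar_t.item()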
class Trainer: def __init__(self, args, train_loader=None, val_loader=None, logger=None, num_answers=0, train=True): self.args = args self.max_text_length = args.max_text_length self.train_loader = train_loader self.val_loader = val_loader self.num_answers = num_answers self.logger = logger # Model self.model = VQAModel.from_pretrained("bert-base-uncased", args=args, num_answers=self.num_answers) self.verbose = True if self.args.distributed: if self.args.gpu != 0: self.verbose = False # Load Checkpoint self.start_epoch = None if args.load is not None: path = args.load + '.pth' self.load(path, verbose=self.verbose) elif args.load_lxmert_qa is not None: path = args.load_lxmert_qa + '_LXRT.pth' load_lxmert_qa( args, path, self.model, label2ans=self.train_loader.dataset.raw_dataset.label2ans, verbose=self.verbose) # GPU Options print(f'Model Launching at GPU {self.args.gpu}') from time import time start = time() self.model.cuda(args.gpu) # Optimizer if train: self.optim, self.lr_scheduler = self.create_optimizer_and_scheduler( ) self.bce_loss = nn.BCEWithLogitsLoss() if args.multiGPU: assert args.distributed self.model = DDP(self.model, device_ids=[args.gpu], find_unused_parameters=True) if args.gpu == 0: print(f'It took {time() - start:.1f}s') # Output Directory self.output = args.output os.makedirs(self.output, exist_ok=True) def create_optimizer_and_scheduler(self): if self.verbose: print('Building Optimizer') from transformers.optimization import AdamW, get_linear_schedule_with_warmup batch_per_epoch = len(self.train_loader) t_total = int(batch_per_epoch * self.args.epochs) warmup_ratio = self.args.warmp_ratio warmup_iters = int(t_total * warmup_ratio) if self.verbose: print("Batch per epoch: %d" % batch_per_epoch) print("Total Iters: %d" % t_total) print('Warmup ratio:', warmup_ratio) print("Warm up Iters: %d" % warmup_iters) no_decay = ["bias", "LayerNorm.weight"] optimizer_grouped_parameters = [ { "params": [ p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay) ], "weight_decay": self.args.weight_decay, }, { "params": [ p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay) ], "weight_decay": 0.0, }, ] optim = AdamW(optimizer_grouped_parameters, self.args.lr) lr_scheduler = get_linear_schedule_with_warmup(optim, warmup_iters, t_total) return optim, lr_scheduler def train(self): if self.verbose: loss_meter = LossMeter() quesid2ans = {} best_valid = 0. 
print("Valid Oracle: %0.2f" % (self.oracle_score(self.val_loader) * 100)) from torch.utils.tensorboard import SummaryWriter self.writer = SummaryWriter(log_dir=self.args.log_dir) hparam_dict = {} for k, v in self.args.__dict__.items(): if type(v) in [int, float, str, bool, torch.Tensor]: hparam_dict[k] = v metric_dict = {} self.writer.add_hparams(hparam_dict, metric_dict) if self.args.distributed: dist.barrier() for epoch in range(self.args.epochs): if self.start_epoch is not None: epoch += self.start_epoch self.model.train() if self.args.distributed: self.train_loader.sampler.set_epoch(epoch) if self.verbose: pbar = tqdm(total=len(self.train_loader), ncols=150) quesid2ans = {} for step_i, batch in enumerate(self.train_loader): update = True if self.args.update_freq > 1: if step_i == 0: update = False elif step_i % self.args.update_freq == 0 or step_i == len( self.train_loader) - 1: update = True else: update = False if self.args.distributed: results = self.model.module.train_step(batch) else: results = self.model.train_step(batch) vis_feats = batch['vis_feats'].cuda() boxes = batch['boxes'].cuda() input_ids = batch['word_ids'].cuda() target = batch['targets'].cuda() ques_id = batch['question_ids'] B = len(batch['word_ids']) results = self.model( input_ids=input_ids, visual_feats=vis_feats, visual_pos=boxes, attention_mask=input_ids > 0, ) logit = results['logit'] assert logit.size() == target.size() assert logit.size() == (B, self.num_answers) loss = self.bce_loss(logit, target) loss.backward() if update: if not self.args.no_clip_grad: nn.utils.clip_grad_norm_(self.model.parameters(), self.args.clip_grad_norm) self.optim.step() self.lr_scheduler.step() for param in self.model.parameters(): param.grad = None try: lr = self.scheduler.get_last_lr()[0] except AttributeError: lr = self.args.lr if self.verbose: loss_meter.update(loss.item()) desc_str = f'Epoch {epoch} | LR {lr:.6f} | ' desc_str += f'Loss {loss_meter.val:4f} |' score, predict = logit.max(1) predict = predict.cpu().numpy() target = target.cpu().numpy() for qid, pred in zip(ques_id, predict): pred_ans = self.train_loader.dataset.raw_dataset.label2ans[ pred] quesid2ans[qid] = pred_ans pbar.set_description(desc_str) pbar.update(1) if self.args.distributed: dist.barrier() # score, label = logit.max(1) # for qid, l in zip(ques_id, label.cpu().numpy()): # ans = dset.label2ans[l] # quesid2ans[qid.item()] = ans if self.verbose: pbar.close() score = self.train_loader.evaluator.evaluate(quesid2ans) * 100. log_str = "\nEpoch %d: Train %0.2f" % (epoch, score) if not self.args.dry: self.writer.add_scalar(f'VQA/Train/score', score, epoch) # Validation valid_score = self.evaluate(self.val_loader) * 100. if valid_score > best_valid: best_valid = valid_score self.save("BEST") log_str += "\nEpoch %d: Valid %0.2f" % (epoch, valid_score) log_str += "\nEpoch %d: Best %0.2f\n" % (epoch, best_valid) if not self.args.dry: self.writer.add_scalar(f'VQA/Valid/score', valid_score, epoch) print(log_str) self.logger.info(log_str) if self.args.distributed: dist.barrier() if self.verbose: self.save("LAST") def predict(self, loader, dump_path=None): """ Predict the answers to questions in a data split. :param eval_tuple: The data tuple to be evaluated. :param dump: The path of saved file to dump results. :return: A dict of question_id to answer. 
""" self.model.eval() with torch.no_grad(): quesid2ans = {} for i, batch in enumerate( tqdm(loader, ncols=150, desc="Prediction")): vis_feats = batch['vis_feats'].cuda() boxes = batch['boxes'].cuda() input_ids = batch['word_ids'].cuda() ques_id = batch['question_ids'] results = self.model( input_ids=input_ids, visual_feats=vis_feats, visual_pos=boxes, attention_mask=input_ids > 0, ) logit = results['logit'] score, predict = logit.max(1) predict = predict.cpu().numpy() for qid, pred in zip(ques_id, predict): pred_ans = loader.dataset.raw_dataset.label2ans[pred] quesid2ans[qid] = pred_ans if dump_path is not None: loader.evaluator.dump_result(quesid2ans, dump_path) return quesid2ans def evaluate(self, loader, dump_path=None): evaluator = loader.evaluator quesid2ans = self.predict(loader, dump_path) return evaluator.evaluate(quesid2ans) @staticmethod def oracle_score(loader): evaluator = loader.evaluator quesid2ans = {} for i, batch in enumerate(loader): ques_id = batch['question_ids'] label = batch['targets'] _, label = label.max(1) for qid, l in zip(ques_id, label.cpu().numpy()): ans = loader.dataset.raw_dataset.label2ans[l] quesid2ans[qid] = ans return evaluator.evaluate(quesid2ans) def save(self, name): torch.save(self.model.state_dict(), os.path.join(self.output, "%s.pth" % name)) def load(self, path, loc='cpu', verbose=False): state_dict = load_state_dict(path, loc) results = self.model.load_state_dict(state_dict, strict=False) if verbose: print('Loaded from ', path) print(results)
class FastSCNNOperator(BaseOperator): def __init__(self, cfg): super(FastSCNNOperator, self).__init__(cfg) # prepare model for train self.model = Fast_SCNN(input_channel=3, num_classes=self.cfg.Train.num_classes).cuda( self.cfg.Distributed.gpu_id) if self.cfg.Distributed.dist: self.model = DistributedDataParallel( self.model, find_unused_parameters=True, device_ids=[self.cfg.Distributed.gpu_id]) self.optimizer = torch.optim.Adam( params=self.model.parameters(), lr=self.cfg.Train.learning_rate, weight_decay=self.cfg.Train.weight_decay) # self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.9, momentum=0.9) self.loss = torch.nn.CrossEntropyLoss(ignore_index=255) if self.cfg.is_training: self.loader = get_train_loader(name=self.cfg.Dataset.name, cfg=self.cfg) else: self.loader = get_val_loader(name=self.cfg.Dataset.name, cfg=cfg) self.cfg.Train.max_iter_num = self.cfg.Train.epochs * len(self.loader) self.main_flag = self.cfg.Distributed.gpu_id == 0 if self.model.training: self.logger = Logger(self.cfg, self.main_flag) def adjust_lr(self, itr, max_itr): now_lr = self.cfg.Train.learning_rate * ( 1 - itr / (max_itr + 1))**self.cfg.Train.power self.optimizer.param_groups[0]['lr'] = now_lr # self.optimizer.param_groups[1]['lr'] = 10 * now_lr return now_lr def criterion(self, outs, labels): return self.loss(outs, labels) @staticmethod def save_ckp(models, step, path): """ Save checkpoint of the model. :param models: nn.Module :param step: step of the checkpoint. :param path: save path. """ torch.save(models.state_dict(), os.path.join(path, 'ckp-{}.pth'.format(step))) def training_process(self): logger = Logger(self.cfg, self.main_flag) self.model.train() colormap = color_map[self.cfg.Dataset.name] itr = 0 for epoch in range(self.cfg.Train.epochs): for idx, sample in enumerate(self.loader): newlr = self.adjust_lr(itr=itr, max_itr=self.cfg.Train.max_iter_num) xs, ys, names = sample ys = torch.squeeze(ys, dim=1) xs = xs.cuda(self.cfg.Distributed.gpu_id) ys = ys.cuda(self.cfg.Distributed.gpu_id) preds = self.model(xs) loss = self.criterion(preds, ys.long()) self.optimizer.zero_grad() loss.backward() self.optimizer.step() pred_img = seg_vis(preds, colormap) gt_img = seg_vis(ys, colormap) pred_img = torch.from_numpy(pred_img).permute( 2, 0, 1).unsqueeze(0).float() / 255. gt_img = torch.from_numpy(gt_img).permute( 2, 0, 1).unsqueeze(0).float() / 255. 
                if self.main_flag:
                    if itr % self.cfg.Train.print_steps == self.cfg.Train.print_steps - 1:
                        log_data = {'scalar': {'Loss': loss.item()},
                                    'imgs': {'Pred': [pred_img, gt_img]}}
                        logger.log(log_data, n_iter=itr)
                itr += 1
                if itr % self.cfg.Train.save_model == self.cfg.Train.save_model - 1:
                    ckp_dir = os.path.join(self.cfg.Train.ckp_dir, self.cfg.Train.model_name)
                    os.makedirs(ckp_dir, exist_ok=True)
                    # unwrap DDP before saving so the checkpoint carries no 'module.' prefixes
                    net = self.model.module if self.cfg.Distributed.dist else self.model
                    self.save_ckp(net, itr, ckp_dir)

    def eval_process(self):
        self.model.eval()
        net = self.model.module if self.cfg.Distributed.dist else self.model
        net.load_state_dict(torch.load(self.cfg.Val.model_file))
        epoch = 0
        step = 0
        class_converter = np.array([7, 8, 11, 12, 13, 17, 19, 20, 21, 22, 23, 24,
                                    25, 26, 27, 28, 31, 32, 33]).astype(np.uint8)
        self.loader.sampler.set_epoch(epoch)
        with torch.no_grad():
            for sample in self.loader:
                xs, ys, names = sample
                xs = xs.cuda(self.cfg.Distributed.gpu_id)
                outs = self.model(xs)
                pred_img = torch.max(outs, dim=1)[1].cpu()
                pred_img = pred_img.numpy().astype(np.uint8)
                print(pred_img.max())
                print(pred_img.shape)
                for i in range(pred_img.shape[0]):
                    img = class_converter[pred_img[i]]
                    im = Image.fromarray(img)
                    img_dir = names[i].split(sep='/')[-2]
                    name = '_'.join(names[i].split(sep='/')[-1].split(sep='_')[:-1]) + '.png'
                    os.makedirs(os.path.join(self.cfg.Val.result_dir, img_dir), exist_ok=True)
                    im.save(os.path.join(self.cfg.Val.result_dir, img_dir, name))
                if self.main_flag:
                    step += 1
                    print('Step : %d / %d' % (step, len(self.loader)))
        print('Done !!')
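# Worked example of the polynomial ("poly") decay in `adjust_lr` above:
# lr(itr) = base_lr * (1 - itr / (max_itr + 1)) ** power, a common schedule in
# semantic segmentation. Values below assume base_lr=0.01, power=0.9.
base_lr, power, max_itr = 0.01, 0.9, 1000
for itr in (0, 500, 1000):
    now_lr = base_lr * (1 - itr / (max_itr + 1)) ** power
    print(itr, now_lr)   # 0 -> 0.01, 500 -> ~0.00536, 1000 -> ~1.99e-05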
def __init__( self, model: Model, optimizer: torch.optim.Optimizer, data_loader: torch.utils.data.DataLoader, patience: Optional[int] = None, validation_metric: str = "-loss", validation_data_loader: torch.utils.data.DataLoader = None, num_epochs: int = 20, serialization_dir: Optional[str] = None, checkpointer: Checkpointer = None, cuda_device: int = -1, grad_norm: Optional[float] = None, grad_clipping: Optional[float] = None, learning_rate_scheduler: Optional[LearningRateScheduler] = None, momentum_scheduler: Optional[MomentumScheduler] = None, tensorboard_writer: TensorboardWriter = None, moving_average: Optional[MovingAverage] = None, batch_callbacks: List[BatchCallback] = None, epoch_callbacks: List[EpochCallback] = None, distributed: bool = False, local_rank: int = 0, world_size: int = 1, num_gradient_accumulation_steps: int = 1, opt_level: Optional[str] = None, ) -> None: super().__init__(serialization_dir, cuda_device, distributed, local_rank, world_size) # I am not calling move_to_gpu here, because if the model is # not already on the GPU then the optimizer is going to be wrong. self.model = model self.data_loader = data_loader self._validation_data_loader = validation_data_loader self.optimizer = optimizer if patience is None: # no early stopping if validation_data_loader: logger.warning( "You provided a validation dataset but patience was set to None, " "meaning that early stopping is disabled") elif (not isinstance(patience, int)) or patience <= 0: raise ConfigurationError( '{} is an invalid value for "patience": it must be a positive integer ' "or None (if you want to disable early stopping)".format( patience)) # For tracking is_best_so_far and should_stop_early self._metric_tracker = MetricTracker(patience, validation_metric) # Get rid of + or - self._validation_metric = validation_metric[1:] self._num_epochs = num_epochs if checkpointer is not None: self._checkpointer = checkpointer else: self._checkpointer = Checkpointer(serialization_dir) self._grad_norm = grad_norm self._grad_clipping = grad_clipping self._learning_rate_scheduler = learning_rate_scheduler self._momentum_scheduler = momentum_scheduler self._moving_average = moving_average self._batch_callbacks = batch_callbacks or [] self._epoch_callbacks = epoch_callbacks or [] # We keep the total batch number as an instance variable because it # is used inside a closure for the hook which logs activations in # `_enable_activation_logging`. self._batch_num_total = 0 self._tensorboard = tensorboard_writer or TensorboardWriter( serialization_dir) self._tensorboard.get_batch_num_total = lambda: self._batch_num_total self._tensorboard.enable_activation_logging(self.model) self._last_log = 0.0 # time of last logging self._num_gradient_accumulation_steps = num_gradient_accumulation_steps # Enable automatic mixed precision training with NVIDIA Apex. self._opt_level = opt_level if self._opt_level is not None: if amp is None: raise ConfigurationError(( "Apex not installed but opt_level was provided. Please install NVIDIA's Apex to enable" " automatic mixed precision (AMP) training. See: https://github.com/NVIDIA/apex." )) self.model, self.optimizer = amp.initialize( self.model, self.optimizer, opt_level=self._opt_level) # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model` # will break the usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc. 
# # Hence a reference to Pytorch's object is maintained in the case of distributed training and in the # normal case, reference to `Model` is retained. This reference is only used in # these places: `model.__call__`, `model.train` and `model.eval`. if self._distributed: self._pytorch_model = DistributedDataParallel( self.model, device_ids=[self.cuda_device], find_unused_parameters=True) else: self._pytorch_model = self.model
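# A minimal sketch of the pattern described in the comment above: training
# code routes forward/train()/eval() through the (possibly DDP-wrapped)
# `_pytorch_model`, while framework-facing helpers such as metrics still go
# through the unwrapped `Model`. `trainer` and `batch` are placeholders.
def batch_outputs_sketch(trainer, batch):
    trainer._pytorch_model.train()            # DDP wrapper when distributed
    output_dict = trainer._pytorch_model(**batch)
    metrics = trainer.model.get_metrics()     # always the raw Model
    return output_dict, metrics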
def train_ai2thor(model, args, rank=0, b=None):
    seed = args.seed + 10000 * rank
    torch.manual_seed(seed)
    np.random.seed(seed)
    # torch.cuda.set_device(rank)
    # device = torch.device(f'cuda:{rank}')
    device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
    # if torch.cuda.is_available():
    #     os.environ['DISPLAY'] = f':{rank}'
    model = model.to(device)
    model.share_memory()

    # Experience buffer
    storage = PPOBuffer(model.obs_shape, args.steps, args.num_workers, args.state_size, args.gamma, device=device)
    storage.share_memory()

    # torch.multiprocessing.set_start_method('spawn')
    # start multiple processes
    ready_to_works = [Event() for _ in range(args.num_workers)]
    exit_flag = Value('i', 0)
    queue = SimpleQueue()

    processes = []
    task_config_file = "config_files/NavTaskTrain.json"
    # start workers
    for worker_id in range(args.num_workers):
        p = Process(target=worker,
                    args=(worker_id, model, storage, ready_to_works[worker_id], queue, exit_flag, task_config_file))
        p.start()
        processes.append(p)

    # start trainer
    train_params = {"epochs": args.epochs,
                    "steps": args.steps,
                    "world_size": args.world_size,
                    "num_workers": args.num_workers}
    ppo_params = {"clip_param": args.clip_param,
                  "train_iters": args.train_iters,
                  "mini_batch_size": args.mini_batch_size,
                  "value_loss_coef": args.value_loss_coef,
                  "entropy_coef": args.entropy_coef,
                  "rnn_steps": args.rnn_steps,
                  "lr": args.lr,
                  "max_kl": args.max_kl}

    # NOTE: distributed training is hard-wired off here: `distributed` is set to
    # False immediately before the check, so the DDP branch below is unreachable
    # as written.
    distributed = False
    if args.world_size > 1:
        if distributed:
            # Initialize Process Group, distributed backend type
            dist_backend = 'nccl'
            # Url used to set up distributed training
            dist_url = "tcp://127.0.0.1:23456"
            print("Initialize Process Group... pid:", os.getpid())
            dist.init_process_group(backend=dist_backend, init_method=dist_url,
                                    rank=rank, world_size=args.world_size)
            # Make model DistributedDataParallel
            model = DistributedDataParallel(model, device_ids=[rank], output_device=rank)
        else:
            print('Distribution Forbidden ~')

    learner(model, storage, train_params, ppo_params, ready_to_works, queue, exit_flag, rank, distributed, b)

    for p in processes:
        p.join()
        print("process ", p.pid, " joined")
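# A minimal sketch (illustrative bodies, names mirroring the code above) of the
# Event/SimpleQueue handshake: the learner sets a worker's Event to request a
# rollout, the worker pushes its result through the queue, and `exit_flag`
# (a multiprocessing Value) tells workers to shut down.
def worker_loop_sketch(ready_to_work, queue, exit_flag):
    while exit_flag.value == 0:
        ready_to_work.wait()          # block until the learner requests a rollout
        ready_to_work.clear()
        queue.put("rollout-result")   # hand experience back to the learner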
def check_parity(amp: bool, manual_reduction: bool):
    # The API should be exactly the same between the sharded and non-sharded variants; generic closure
    def closure(model, scaler, input_tensor, should_accumulate, _manual_reduction=False):
        accumulate_steps = 3 if should_accumulate else 1
        model.zero_grad()

        def step():
            if scaler is not None:
                with torch.cuda.amp.autocast():
                    loss = model(input_tensor).abs().sum()
                    scaler.scale(loss).backward()
            else:
                loss = model(input_tensor).abs().sum()
                loss.backward()

        with model.no_sync() if should_accumulate else suppress():
            for _ in range(accumulate_steps - 1):
                step()

        if not _manual_reduction:
            step()
        else:
            with model.no_sync():
                step()
            model.reduce()

    # Any model works. Add one different buffer per rank
    model = _get_mlp()
    model.register_buffer("test_buffer", torch.ones((1)) * rank)
    model.to(device)

    # Make sure that the model starts with non-trainable, so that we check for the buckets to be
    # properly reassigned when/if this changes
    next(model.parameters()).requires_grad = False

    sharded_optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-4, momentum=0.99)
    sharded_ddp_model = ShardedDataParallel(
        module=model,
        sharded_optimizer=sharded_optimizer,
        broadcast_buffers=True,
        reduce_buffer_size=reduce_buffer_size,
        reduce_fp16=fp16_reduction,
    )

    ddp_model_single = copy.deepcopy(model)
    ddp_optimizer = torch.optim.SGD(ddp_model_single.parameters(), lr=1e-4, momentum=0.99)
    ddp_model = DDP(ddp_model_single, device_ids=[rank], broadcast_buffers=True, find_unused_parameters=True)

    if fp16_reduction:
        from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook

        ddp_model.register_comm_hook(state=None, hook=fp16_compress_hook)  # type: ignore

    ddp_scaler = TorchGradScaler() if amp else None
    sharded_ddp_scaler = ShardedGradScaler() if amp else None

    # The model should be synchronized in between the ranks at construction time, check that
    check_same_model_params(sharded_ddp_model, ddp_model)

    # Typical training loop, check that we get the exact same results as DDP
    for i in range(NUMBER_BATCHS):
        input_tensor = torch.rand((BATCH_SIZE, 2)).to(device)

        def closure_ddp(input_tensor=input_tensor):
            return closure(ddp_model, ddp_scaler, input_tensor, grad_accumulation)

        def closure_sharded(input_tensor=input_tensor):
            return closure(
                sharded_ddp_model,
                sharded_ddp_scaler,
                input_tensor,
                grad_accumulation,
                _manual_reduction=manual_reduction,
            )

        # Step/scale both
        if ddp_scaler is not None:
            _ = closure_ddp(input_tensor)
            ddp_scaler.step(ddp_optimizer)
            ddp_scaler.update()
        else:
            ddp_optimizer.step(closure=closure_ddp)

        if sharded_ddp_scaler is not None:
            _ = closure_sharded(input_tensor)
            sharded_ddp_scaler.step(sharded_optimizer)
            sharded_ddp_scaler.update()
        else:
            sharded_optimizer.step(closure=closure_sharded)

        check_same_model_params(sharded_ddp_model, ddp_model, f"Rank: {rank} - Step {i} broke")

        # Flip the trainability of the first parameter back and forth
        if i == 0 and change_train_graph:
            next(sharded_ddp_model.parameters()).requires_grad = not next(sharded_ddp_model.parameters()).requires_grad
            next(ddp_model.parameters()).requires_grad = not next(ddp_model.parameters()).requires_grad
            check_same_model_params(sharded_ddp_model, ddp_model,
                                    f"Rank: {rank} - Trainability refresh {i} broke")
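# A condensed sketch of the gradient-accumulation idiom exercised by the
# closure above: skip gradient synchronization under `no_sync()` for all but
# the last micro-batch, so the allreduce happens once per effective batch.
# `ddp_model`, `optimizer`, and `micro_batches` are assumed to exist.
from contextlib import suppress

def accumulate_step(ddp_model, optimizer, micro_batches):
    optimizer.zero_grad()
    for i, batch in enumerate(micro_batches):
        is_last = i == len(micro_batches) - 1
        with (suppress() if is_last else ddp_model.no_sync()):
            ddp_model(batch).abs().sum().backward()  # grads sync only on the last step
    optimizer.step()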
def main():
    # Views the training images and displays the distance on anchor-negative and anchor-positive
    # test_display_triplet_distance = False

    # print the experiment configuration
    print('\nCurrent time is \33[91m{}\33[0m.'.format(str(time.asctime())))
    opts = vars(args)
    keys = list(opts.keys())
    keys.sort()
    options = []
    for k in keys:
        options.append("\'%s\': \'%s\'" % (str(k), str(opts[k])))
    print('Parsed options: \n{ %s }' % (', '.join(options)))
    print('Number of Speakers: {}.\n'.format(train_dir.num_spks))

    # instantiate model and initialize weights
    kernel_size = args.kernel_size.split(',')
    kernel_size = [int(x) for x in kernel_size]
    context = args.context.split(',')
    context = [int(x) for x in context]
    if args.padding == '':
        padding = [int((x - 1) / 2) for x in kernel_size]
    else:
        padding = args.padding.split(',')
        padding = [int(x) for x in padding]

    kernel_size = tuple(kernel_size)
    padding = tuple(padding)
    stride = args.stride.split(',')
    stride = [int(x) for x in stride]
    channels = args.channels.split(',')
    channels = [int(x) for x in channels]
    dilation = args.dilation.split(',')
    dilation = [int(x) for x in dilation]

    model_kwargs = {
        'input_dim': args.input_dim, 'feat_dim': args.feat_dim, 'kernel_size': kernel_size,
        'context': context, 'filter_fix': args.filter_fix, 'dilation': dilation,
        'first_2d': args.first_2d, 'mask': args.mask_layer, 'mask_len': args.mask_len,
        'block_type': args.block_type, 'filter': args.filter, 'exp': args.exp,
        'inst_norm': args.inst_norm, 'input_norm': args.input_norm, 'stride': stride,
        'fast': args.fast, 'avg_size': args.avg_size, 'time_dim': args.time_dim,
        'padding': padding, 'encoder_type': args.encoder_type, 'vad': args.vad,
        'transform': args.transform, 'embedding_size': args.embedding_size,
        'ince': args.inception, 'resnet_size': args.resnet_size,
        'num_classes': train_dir.num_spks, 'num_classes_b': train_dir.num_doms,
        'channels': channels, 'alpha': args.alpha, 'dropout_p': args.dropout_p,
        'loss_type': args.loss_type, 'm': args.m, 'margin': args.margin, 's': args.s,
        'iteraion': 0, 'all_iteraion': args.all_iteraion  # key names kept as defined by the model API
    }
    print('Model options: {}'.format(model_kwargs))
    dist_type = 'cos' if args.cos_sim else 'l2'
    print('Testing with %s distance.' % dist_type)

    model = create_model(args.model, **model_kwargs)

    start_epoch = 0
    if args.save_init and not args.finetune:
        check_path = '{}/checkpoint_{}.pth'.format(args.check_path, start_epoch)
        torch.save(model, check_path)

    iteration = 0  # if args.resume else 0
    if args.finetune and args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            checkpoint_state_dict = checkpoint['state_dict']
            if isinstance(checkpoint_state_dict, tuple):
                checkpoint_state_dict = checkpoint_state_dict[0]
            filtered = {k: v for k, v in checkpoint_state_dict.items() if 'num_batches_tracked' not in k}
            if list(filtered.keys())[0].startswith('module'):
                new_state_dict = OrderedDict()
                for k, v in filtered.items():
                    name = k[7:]  # strip the 'module.' prefix: keep characters from index 7 onward
                    new_state_dict[name] = v  # the new dict maps each stripped key to its original value
                model.load_state_dict(new_state_dict)
            else:
                model_dict = model.state_dict()
                model_dict.update(filtered)
                model.load_state_dict(model_dict)
            # model.dropout.p = args.dropout_p
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    ce_criterion = nn.CrossEntropyLoss()
    if args.loss_type == 'soft':
        xe_criterion = None
    elif args.loss_type == 'asoft':
        ce_criterion = None
        xe_criterion = AngleSoftmaxLoss(lambda_min=args.lambda_min, lambda_max=args.lambda_max)
    elif args.loss_type == 'center':
        xe_criterion = CenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'variance':
        xe_criterion = VarianceLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'gaussian':
        xe_criterion = GaussianLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'coscenter':
        xe_criterion = CenterCosLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size)
    elif args.loss_type == 'mulcenter':
        xe_criterion = MultiCenterLoss(num_classes=train_dir.num_spks, feat_dim=args.embedding_size,
                                       num_center=args.num_center)
    elif args.loss_type == 'amsoft':
        ce_criterion = None
        xe_criterion = AMSoftmaxLoss(margin=args.margin, s=args.s)
    elif args.loss_type == 'arcsoft':
        ce_criterion = None
        xe_criterion = ArcSoftmaxLoss(margin=args.margin, s=args.s, iteraion=iteration,
                                      all_iteraion=args.all_iteraion)
    elif args.loss_type == 'wasse':
        xe_criterion = Wasserstein_Loss(source_cls=args.source_cls)
    elif args.loss_type == 'ring':
        xe_criterion = RingLoss(ring=args.ring)
        args.alpha = 0.0

    model_para = [{'params': model.parameters()}]
    if args.loss_type in ['center', 'variance', 'mulcenter', 'gaussian', 'coscenter', 'ring']:
        assert args.lr_ratio > 0
        model_para.append({'params': xe_criterion.parameters(), 'lr': args.lr * args.lr_ratio})

    if args.finetune or args.second_wd > 0:
        # if args.loss_type in ['asoft', 'amsoft']:
        classifier_params = list(map(id, model.classifier.parameters()))
        rest_params = filter(lambda p: id(p) not in classifier_params, model.parameters())
        init_lr = args.lr * args.lr_ratio if args.lr_ratio > 0 else args.lr
        init_wd = args.second_wd if args.second_wd > 0 else args.weight_decay
        print('Set the lr and weight_decay of classifier to %f and %f' % (init_lr, init_wd))
        model_para = [{'params': rest_params},
                      {'params': model.classifier.parameters(), 'lr': init_lr, 'weight_decay': init_wd}]

    if args.filter in ['fDLR', 'fBLayer', 'fLLayer', 'fBPLayer']:
        filter_params = list(map(id, model.filter_layer.parameters()))
        rest_params = filter(lambda p: id(p) not in filter_params, model_para[0]['params'])
        init_wd = args.filter_wd if args.filter_wd > 0 else args.weight_decay
        init_lr = args.lr * args.lr_ratio if args.lr_ratio > 0 else args.lr
        print('Set the lr and weight_decay of filter layer to %f and %f' % (init_lr, init_wd))
        model_para[0]['params'] = rest_params
        model_para.append({'params': model.filter_layer.parameters(), 'lr': init_lr, 'weight_decay': init_wd})

    optimizer = create_optimizer(model_para, args.optimizer, **opt_kwargs)

    if not args.finetune and args.resume:
        if os.path.isfile(args.resume):
            print('=> loading checkpoint {}'.format(args.resume))
            checkpoint = torch.load(args.resume)
            start_epoch = checkpoint['epoch']

            checkpoint_state_dict = checkpoint['state_dict']
            if isinstance(checkpoint_state_dict, tuple):
                checkpoint_state_dict = checkpoint_state_dict[0]
            filtered = {k: v for k, v in checkpoint_state_dict.items() if 'num_batches_tracked' not in k}
            if list(filtered.keys())[0].startswith('module'):
                new_state_dict = OrderedDict()
                for k, v in filtered.items():
                    name = k[7:]  # strip the 'module.' prefix
                    new_state_dict[name] = v
                model.load_state_dict(new_state_dict)
            else:
                model_dict = model.state_dict()
                model_dict.update(filtered)
                model.load_state_dict(model_dict)
            # model.dropout.p = args.dropout_p
        else:
            print('=> no checkpoint found at {}'.format(args.resume))

    # Save model config txt
    with open(osp.join(args.check_path,
                       'model.%s.conf' % time.strftime("%Y.%m.%d", time.localtime())), 'w') as f:
        f.write('model: ' + str(model) + '\n')
        f.write('CrossEntropy: ' + str(ce_criterion) + '\n')
        f.write('Other Loss: ' + str(xe_criterion) + '\n')
        f.write('Optimizer: ' + str(optimizer) + '\n')

    milestones = args.milestones.split(',')
    milestones = [int(x) for x in milestones]
    milestones.sort()
    if args.scheduler == 'exp':
        scheduler = lr_scheduler.ExponentialLR(optimizer, gamma=args.gamma)
    elif args.scheduler == 'rop':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, patience=args.patience, min_lr=1e-5)
    elif args.scheduler == 'cyclic':
        cycle_momentum = False if args.optimizer == 'adam' else True
        scheduler = lr_scheduler.CyclicLR(optimizer, base_lr=1e-8, max_lr=args.lr,
                                          step_size_up=13000, cycle_momentum=cycle_momentum)
    else:
        scheduler = lr_scheduler.MultiStepLR(optimizer, milestones=milestones, gamma=0.1)

    ce = [ce_criterion, xe_criterion]

    start = args.start_epoch + start_epoch
    print('Start epoch is : ' + str(start))
    # start = 0
    end = start + args.epochs

    if len(args.random_chunk) == 2 and args.random_chunk[0] < args.random_chunk[1]:
        min_chunk_size = int(args.random_chunk[0])
        max_chunk_size = int(args.random_chunk[1])
        pad_dim = 2 if args.feat_format == 'kaldi' else 3

        train_loader = torch.utils.data.DataLoader(
            train_dir, batch_size=args.batch_size,
            collate_fn=PadCollate(dim=pad_dim,
                                  num_batch=int(np.ceil(len(train_dir) / args.batch_size)),
                                  min_chunk_size=min_chunk_size,
                                  max_chunk_size=max_chunk_size),
            shuffle=args.shuffle, **kwargs)
        valid_loader = torch.utils.data.DataLoader(
            valid_dir, batch_size=int(args.batch_size / 2),
            collate_fn=PadCollate(dim=pad_dim, fix_len=True,
                                  min_chunk_size=args.chunk_size,
                                  max_chunk_size=args.chunk_size + 1),
            shuffle=False, **kwargs)
    else:
        train_loader = torch.utils.data.DataLoader(train_dir, batch_size=args.batch_size,
                                                   shuffle=args.shuffle, **kwargs)
        valid_loader = torch.utils.data.DataLoader(valid_dir, batch_size=int(args.batch_size / 2),
                                                   shuffle=False, **kwargs)

    train_extract_loader = torch.utils.data.DataLoader(train_extract_dir, batch_size=1,
                                                       shuffle=False, **extract_kwargs)

    if args.cuda:
        if len(args.gpu_id) > 1:
            print("Continue with gpu: %s ..."
        torch.distributed.init_process_group(
            backend="nccl",
            # init_method='tcp://localhost:23456',
            init_method='file:///home/ssd2020/yangwenhao/lstm_speaker_verification/data/sharedfile',
            rank=0, world_size=1)
        # if args.gain
        # model = DistributedDataParallel(model.cuda(), find_unused_parameters=True)
        model = DistributedDataParallel(model.cuda())
    else:
        model = model.cuda()

    for i in range(len(ce)):
        if ce[i] is not None:
            ce[i] = ce[i].cuda()

try:
    print('Dropout is {}.'.format(model.dropout_p))
except AttributeError:
    pass

xvector_dir = args.check_path
xvector_dir = xvector_dir.replace('checkpoint', 'xvector')

start_time = time.time()
try:
    for epoch in range(start, end):
        # pdb.set_trace()
        lr_string = '\n\33[1;34m Current \'{}\' learning rate is '.format(args.optimizer)
        for param_group in optimizer.param_groups:
            lr_string += '{:.10f} '.format(param_group['lr'])
        print('%s \33[0m' % lr_string)

        train(train_loader, model, ce, optimizer, epoch, scheduler)
        valid_loss = valid_class(valid_loader, model, ce, epoch)

        if (epoch == 1 or epoch != (end - 2)) and (
                epoch % 4 == 1 or epoch in milestones or epoch == (end - 1)):
            model.eval()
            check_path = '{}/checkpoint_{}.pth'.format(args.check_path, epoch)
            model_state_dict = model.module.state_dict() \
                if isinstance(model, DistributedDataParallel) else model.state_dict()
            torch.save({'epoch': epoch,
                        'state_dict': model_state_dict,
                        'criterion': ce}, check_path)

            valid_test(train_extract_loader, model, epoch, xvector_dir)
            test(model, epoch, writer, xvector_dir)
            if epoch != (end - 1):
                try:
                    shutil.rmtree("%s/train/epoch_%s" % (xvector_dir, epoch))
                    shutil.rmtree("%s/test/epoch_%s" % (xvector_dir, epoch))
                except Exception as e:
                    print('rm dir xvectors error:', e)

        if args.scheduler == 'rop':
            scheduler.step(valid_loss)
        elif args.scheduler == 'cyclic':
            continue
        else:
            scheduler.step()

except KeyboardInterrupt:
    end = epoch

writer.close()
stop_time = time.time()
t = float(stop_time - start_time)
print("Running %.4f minutes for each epoch.\n" % (t / 60 / (max(end - start, 1))))
exit(0)
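# A minimal, standalone sketch of the checkpoint-loading pattern used above:
# DataParallel/DistributedDataParallel save parameters under a "module." prefix,
# so a checkpoint written from a wrapped model has to be renamed before it can be
# loaded into a bare model. The helper name and signature below are assumptions,
# not part of the original script.
from collections import OrderedDict

import torch


def load_flexible_state_dict(model, checkpoint_path):
    """Load a checkpoint whether or not it was saved from a DDP-wrapped model."""
    checkpoint = torch.load(checkpoint_path, map_location='cpu')
    state_dict = checkpoint.get('state_dict', checkpoint)
    # drop BatchNorm bookkeeping entries, mirroring the filtering above
    filtered = {k: v for k, v in state_dict.items() if 'num_batches_tracked' not in k}
    if next(iter(filtered)).startswith('module.'):
        # strip the 7-character "module." prefix added by the wrapper
        filtered = OrderedDict((k[7:], v) for k, v in filtered.items())
    model_dict = model.state_dict()
    model_dict.update(filtered)
    model.load_state_dict(model_dict)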
class nnUNetTrainerV2BraTSRegions_DDP(nnUNetTrainerV2_DDP):
    def __init__(self, plans_file, fold, local_rank, output_folder=None, dataset_directory=None,
                 batch_dice=True, stage=None, unpack_data=True, deterministic=True,
                 distribute_batch_size=False, fp16=False):
        super().__init__(plans_file, fold, local_rank, output_folder, dataset_directory, batch_dice,
                         stage, unpack_data, deterministic, distribute_batch_size, fp16)
        self.regions = get_brats_regions()
        self.regions_class_order = (1, 2, 3)
        self.loss = None
        self.ce_loss = nn.BCEWithLogitsLoss()

    def process_plans(self, plans):
        super().process_plans(plans)
        """The network has as many outputs as we have regions"""
        self.num_classes = len(self.regions)

    def initialize_network(self):
        """inference_apply_nonlin to sigmoid"""
        super().initialize_network()
        self.network.inference_apply_nonlin = nn.Sigmoid()

    def initialize(self, training=True, force_load_plans=False):
        """
        this is a copy of nnUNetTrainerV2's initialize. We only add the regions to the data augmentation
        :param training:
        :param force_load_plans:
        :return:
        """
        if not self.was_initialized:
            maybe_mkdir_p(self.output_folder)

            if force_load_plans or (self.plans is None):
                self.load_plans_file()

            self.process_plans(self.plans)

            self.setup_DA_params()

            self.folder_with_preprocessed_data = join(self.dataset_directory,
                                                      self.plans['data_identifier'] + "_stage%d" % self.stage)
            if training:
                self.dl_tr, self.dl_val = self.get_basic_generators()
                if self.unpack_data:
                    if self.local_rank == 0:
                        print("unpacking dataset")
                        unpack_dataset(self.folder_with_preprocessed_data)
                        print("done")
                    else:
                        # we need to wait until worker 0 has finished unpacking
                        npz_files = subfiles(self.folder_with_preprocessed_data, suffix=".npz", join=False)
                        case_ids = [i[:-4] for i in npz_files]
                        all_present = all(
                            [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
                        while not all_present:
                            print("worker", self.local_rank, "is waiting for unpacking")
                            sleep(3)
                            all_present = all(
                                [isfile(join(self.folder_with_preprocessed_data, i + ".npy")) for i in case_ids])
                        # there is a slight chance that an error arises because a dataloader reads a file
                        # that is still being written by worker 0. We ignore this for now and address it
                        # only if it becomes relevant
                        # (this can occur because while worker 0 is writing, the file is technically
                        # already present, so the other workers will proceed and eventually try to read it)
                else:
                    print("INFO: Not unpacking data! Training may be slow due to that. Pray you are not "
                          "using 2d or you will wait all winter for your model to finish!")

                # setting weights for deep supervision losses
                net_numpool = len(self.net_num_pool_op_kernel_sizes)

                # we give each output a weight which decreases exponentially (division by 2) as the
                # resolution decreases; this gives higher resolution outputs more weight in the loss
                weights = np.array([1 / (2 ** i) for i in range(net_numpool)])

                # we don't use the lowest-resolution output (masked out below).
                # Normalize weights so that they sum to 1
                mask = np.array([True if i < net_numpool - 1 else False for i in range(net_numpool)])
                weights[~mask] = 0
                weights = weights / weights.sum()
                self.ds_loss_weights = weights

                # seeds in [0, 99999]; randint's upper bound is exclusive
                seeds_train = np.random.randint(0, 100000, self.data_aug_params.get('num_threads'))
                seeds_val = np.random.randint(0, 100000,
                                              max(self.data_aug_params.get('num_threads') // 2, 1))
                print("seeds train", seeds_train)
                print("seeds_val", seeds_val)
                self.tr_gen, self.val_gen = get_moreDA_augmentation(
                    self.dl_tr, self.dl_val,
                    self.data_aug_params['patch_size_for_spatialtransform'],
                    self.data_aug_params,
                    deep_supervision_scales=self.deep_supervision_scales,
                    seeds_train=seeds_train,
                    seeds_val=seeds_val,
                    pin_memory=self.pin_memory,
                    regions=self.regions)
                self.print_to_log_file("TRAINING KEYS:\n %s" % (str(self.dataset_tr.keys())),
                                       also_print_to_console=False)
                self.print_to_log_file("VALIDATION KEYS:\n %s" % (str(self.dataset_val.keys())),
                                       also_print_to_console=False)
            else:
                pass

            self.initialize_network()
            self.initialize_optimizer_and_scheduler()

            self._maybe_init_amp()
            self.network = DDP(self.network, self.local_rank)
        else:
            self.print_to_log_file('self.was_initialized is True, not running self.initialize again')
        self.was_initialized = True

    def validate(self, do_mirroring: bool = True, use_sliding_window: bool = True,
                 step_size: float = 0.5, save_softmax: bool = True, use_gaussian: bool = True,
                 overwrite: bool = True, validation_folder_name: str = 'validation_raw',
                 debug: bool = False, all_in_gpu: bool = False,
                 segmentation_export_kwargs: dict = None):
        super().validate(do_mirroring=do_mirroring, use_sliding_window=use_sliding_window,
                         step_size=step_size, save_softmax=save_softmax, use_gaussian=use_gaussian,
                         overwrite=overwrite, validation_folder_name=validation_folder_name,
                         debug=debug, all_in_gpu=all_in_gpu,
                         segmentation_export_kwargs=segmentation_export_kwargs)
        # run brats specific validation
        output_folder = join(self.output_folder, validation_folder_name)
        evaluate_regions(output_folder, self.gt_niftis_folder, self.regions)

    def run_iteration(self, data_generator, do_backprop=True, run_online_evaluation=False):
        raise NotImplementedError("this class has not been changed to work with pytorch amp yet!")
        data_dict = next(data_generator)
        data = data_dict['data']
        target = data_dict['target']

        data = maybe_to_torch(data)
        target = maybe_to_torch(target)

        if torch.cuda.is_available():
            data = to_cuda(data, gpu_id=None)
            target = to_cuda(target, gpu_id=None)

        self.optimizer.zero_grad()

        output = self.network(data)
        del data

        total_loss = None
        for i in range(len(output)):
            # Starting here it gets spicy!
            axes = tuple(range(2, len(output[i].size())))

            # the network does not apply its final nonlinearity itself;
            # we apply a sigmoid here (rather than a softmax) because the BraTS region targets overlap
            output_softmax = torch.sigmoid(output[i])

            # get the tp, fp and fn terms we need
            tp, fp, fn, _ = get_tp_fp_fn_tn(output_softmax, target[i], axes, mask=None)
            # for dice, compute nominator and denominator so that we have to accumulate only 2
            # instead of 3 variables
            # do_bg=False in nnUNetTrainer -> [:, 1:]
            nominator = 2 * tp[:, 1:]
            denominator = 2 * tp[:, 1:] + fp[:, 1:] + fn[:, 1:]

            if self.batch_dice:
                # for DDP we need to gather all nominator and denominator terms from all GPUs
                # to do proper batch dice
                nominator = awesome_allgather_function.apply(nominator)
                denominator = awesome_allgather_function.apply(denominator)
                nominator = nominator.sum(0)
                denominator = denominator.sum(0)
            else:
                pass

            ce_loss = self.ce_loss(output[i], target[i])

            # we smooth by 1e-5 to penalize false positives if tp is 0
            dice_loss = (- (nominator + 1e-5) / (denominator + 1e-5)).mean()

            if total_loss is None:
                total_loss = self.ds_loss_weights[i] * (ce_loss + dice_loss)
            else:
                total_loss += self.ds_loss_weights[i] * (ce_loss + dice_loss)

        if run_online_evaluation:
            with torch.no_grad():
                output = output[0]
                target = target[0]
                out_sigmoid = torch.sigmoid(output)
                out_sigmoid = (out_sigmoid > 0.5).float()

                if self.threeD:
                    axes = (2, 3, 4)
                else:
                    axes = (2, 3)

                tp, fp, fn, _ = get_tp_fp_fn_tn(out_sigmoid, target, axes=axes)

                tp_hard = awesome_allgather_function.apply(tp)
                fp_hard = awesome_allgather_function.apply(fp)
                fn_hard = awesome_allgather_function.apply(fn)
                # print_if_rank0("after allgather", tp_hard.shape)
                # print_if_rank0("after sum", tp_hard.shape)

                self.run_online_evaluation(tp_hard.detach().cpu().numpy().sum(0),
                                           fp_hard.detach().cpu().numpy().sum(0),
                                           fn_hard.detach().cpu().numpy().sum(0))
        del target

        if do_backprop:
            if not self.fp16 or amp is None or not torch.cuda.is_available():
                total_loss.backward()
            else:
                with amp.scale_loss(total_loss, self.optimizer) as scaled_loss:
                    scaled_loss.backward()
            _ = clip_grad_norm_(self.network.parameters(), 12)
            self.optimizer.step()

        return total_loss.detach().cpu().numpy()

    def run_online_evaluation(self, tp, fp, fn):
        self.online_eval_foreground_dc.append(list((2 * tp) / (2 * tp + fp + fn + 1e-8)))
        self.online_eval_tp.append(list(tp))
        self.online_eval_fp.append(list(fp))
        self.online_eval_fn.append(list(fn))
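# The deep-supervision weighting set up in initialize() above is easy to check in
# isolation: each decoder output gets half the weight of the next-higher resolution,
# the lowest-resolution output is masked out, and the weights are renormalized to
# sum to 1. This is a sketch of that computation, not nnU-Net code.
import numpy as np


def deep_supervision_weights(net_numpool):
    weights = np.array([1 / (2 ** i) for i in range(net_numpool)])
    weights[-1] = 0.0  # the lowest-resolution output does not contribute
    return weights / weights.sum()


print(deep_supervision_weights(5))  # approx. [0.5333, 0.2667, 0.1333, 0.0667, 0.0]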
def train(): args = get_args() '''Setup''' if not os.path.exists(args.log_path): os.makedirs(args.log_path, exist_ok=True) # logging is set to INFO (resp. WARN) for main (resp. auxiliary) process. logger.info => log main process only, logger.warning => log all processes logging.basicConfig( level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN) # This is a logger.warning: it will be printed by all distributed processes logger.warning("Running process %d", args.local_rank) logger.info("Arguments: %s", pformat(args)) '''Initialize distributed training if needed''' args.distributed = (args.local_rank != -1) if args.distributed: torch.cuda.set_device(args.local_rank) args.device = torch.device("cuda", args.local_rank) torch.distributed.init_process_group(backend='nccl', init_method='env://') logger.info( "Prepare tokenizer, pretrained model and optimizer - add special tokens for fine-tuning" ) tokenizer_class = GPT2Tokenizer tokenizer = tokenizer_class.from_pretrained(args.model_checkpoint) model_class = VideoGPT2LMHeadModel model = model_class.from_pretrained(args.model_checkpoint) tokenizer.add_special_tokens(SPECIAL_TOKENS_DICT) model.resize_token_embeddings(len(tokenizer)) model.to(args.device) optimizer = AdamW(model.parameters(), lr=args.lr) ''' Prepare model for FP16 and distributed training if needed (order is important, distributed should be the last) ''' if args.fp16: from apex import amp # Apex is only required if we use fp16 training model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16) if args.distributed: model = DistributedDataParallel(model, device_ids=[args.local_rank], output_device=args.local_rank) model = model.module logger.info("Prepare datasets") train_loader, val_loader = get_data_loaders_new(args, tokenizer) '''Training function and trainer''' def update(engine, batch): model.train() batch = tuple(input_tensor.to(args.device) for input_tensor in batch) input_ids, token_type_ids, labels, input_mask, i3d, video_mask, reply_mask = batch input_embs = model.transformer.wte(input_ids) video_embs = model.video_ff(i3d) input_embs = torch.cat([video_embs, input_embs], dim=1) token_type_ids = torch.cat([ torch.ones((i3d.size(0), i3d.size(1))).long().cuda() * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]), token_type_ids ], dim=1) video_loss = model(input_embs, token_type_ids=token_type_ids, labels=(labels, i3d), attention_mask=[video_mask, input_mask], mode="video")[0] reply_loss = model(input_embs, token_type_ids=token_type_ids, labels=(labels, i3d), attention_mask=[reply_mask, input_mask], mode="reply")[0] loss = (video_loss + reply_loss) / args.gradient_accumulation_steps if args.fp16: with amp.scale_loss(loss, optimizer) as scaled_loss: scaled_loss.backward() torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_norm) else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm) if engine.state.iteration % args.gradient_accumulation_steps == 0: optimizer.step() optimizer.zero_grad() return loss.item() '''Evaluation function and evaluator (evaluator output is the input of the metrics)''' def inference(engine, batch): model.eval() with torch.no_grad(): batch = tuple( input_tensor.to(args.device) for input_tensor in batch) input_ids, token_type_ids, lm_labels, input_mask, i3d, video_mask, reply_mask = batch input_embs = model.transformer.wte(input_ids) video_embs = model.video_ff(i3d) input_embs = torch.cat([video_embs, input_embs], dim=1) token_type_ids = torch.cat([ torch.ones((i3d.size(0), 
i3d.size(1))).long().cuda() * tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS[-2]), token_type_ids ], dim=1) model_outputs = model(input_embs, token_type_ids=token_type_ids, attention_mask=[reply_mask, input_mask])[0] lm_logits = model_outputs # So we can also use GPT2 outputs lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view( -1, lm_logits.size(-1)) lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1) return lm_logits_flat_shifted, lm_labels_flat_shifted '''Engines''' trainer = Engine(update) evaluator = Engine(inference) ''' Attach evaluation to trainer: we evaluate when we start the training and at the end of each epoch ''' trainer.add_event_handler(Events.EPOCH_COMPLETED, lambda _: evaluator.run(val_loader)) if args.n_epochs < 1: trainer.add_event_handler(Events.COMPLETED, lambda _: evaluator.run(val_loader)) if args.eval_before_start: trainer.add_event_handler(Events.STARTED, lambda _: evaluator.run(val_loader)) # Linearly decrease the learning rate from lr to zero scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)]) trainer.add_event_handler(Events.ITERATION_STARTED, scheduler) # Prepare metrics - note how we compute distributed metrics RunningAverage(output_transform=lambda x: x).attach(trainer, "loss") metrics = { "nll": Loss(torch.nn.CrossEntropyLoss(ignore_index=-1), output_transform=lambda x: (x[0], x[1])) } metrics.update({ "average_nll": MetricsLambda(average_distributed_scalar, metrics["nll"], args) }) metrics["average_ppl"] = MetricsLambda(math.exp, metrics["average_nll"]) for name, metric in metrics.items(): metric.attach(evaluator, name) ''' On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train ''' if args.local_rank in [-1, 0]: pbar = ProgressBar(persist=True) pbar.attach(trainer, metric_names=["loss"]) evaluator.add_event_handler( Events.COMPLETED, lambda _: pbar.log_message( "Validation: %s" % pformat(evaluator.state.metrics))) tb_logger = TensorboardLogger(log_dir="./tb_logs") tb_logger.attach(trainer, log_handler=OutputHandler(tag="training", metric_names=["loss"]), event_name=Events.ITERATION_COMPLETED) tb_logger.attach(trainer, log_handler=OptimizerParamsHandler(optimizer), event_name=Events.ITERATION_STARTED) tb_logger.attach(evaluator, log_handler=OutputHandler(tag="validation", metric_names=list( metrics.keys()), another_engine=trainer), event_name=Events.EPOCH_COMPLETED) checkpoint_handler = ModelCheckpoint(args.log_path, 'checkpoint', n_saved=8, require_empty=False) trainer.add_event_handler(Events.EPOCH_COMPLETED(every=1), checkpoint_handler, {'mymodel': getattr(model, 'module', model)}) # "getattr" take care of distributed encapsulation torch.save(args, args.log_path + 'model_training_args.bin') getattr(model, 'module', model).config.to_json_file( os.path.join(args.log_path, CONFIG_NAME)) tokenizer.save_vocabulary(args.log_path) '''Run the training''' trainer.run(train_loader, max_epochs=args.n_epochs) ''' On the main process: close tensorboard logger and rename the last checkpoint (for easy re-loading with OpenAIGPTModel.from_pretrained method) ''' if args.local_rank in [-1, 0] and args.n_epochs > 0: # TODO: PR in ignite to have better access to saved file paths (cleaner) os.rename(checkpoint_handler._saved[-1][1][-1], os.path.join(args.log_path, WEIGHTS_NAME)) tb_logger.close()
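# inference() above flattens "shifted" logits and labels so that the prediction at
# position t is scored against the token at position t+1. The same alignment in a
# self-contained toy example (sizes and tensor names are made up for illustration):
import torch
import torch.nn as nn

batch, seq_len, vocab = 2, 6, 100
lm_logits = torch.randn(batch, seq_len, vocab)
lm_labels = torch.randint(0, vocab, (batch, seq_len))

lm_logits_flat_shifted = lm_logits[..., :-1, :].contiguous().view(-1, vocab)
lm_labels_flat_shifted = lm_labels[..., 1:].contiguous().view(-1)
loss = nn.CrossEntropyLoss(ignore_index=-1)(lm_logits_flat_shifted, lm_labels_flat_shifted)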
def train(local_rank, args): torch.backends.cudnn.benchmark = True import os # torch.multiprocessing.set_sharing_strategy('file_system') # too many barriers / one node data parallel and multiple node DDP os.environ['MASTER_ADDR'] = args["master_addr"] os.environ['MASTER_PORT'] = args["master_port"] os.environ["NCCL_DEBUG"] = "WARN" # os.environ["CUDA_VISIBLE_DEVICES"] = str(local_rank) # gpu_device = 0 gpu_device = local_rank os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" if args["wandb_dryrun"]: os.environ["WANDB_MODE"] = "dryrun" os.environ["WANDB_SILENT"] = "true" os.environ['TOKENIZERS_PARALLELISM'] = "true" torch.backends.cudnn.benchmark = True rank = args["nr"] if args["cpu"] else (args["nr"] * args["gpus_per_node"] + local_rank) nr = args["nr"] if args["cpu"]: assert local_rank == 0 device = torch.device("cpu") args["dist_backend"] = "gloo" # init_method = "tcp://%s:%s" % ("127.0.0.1", "9999") else: device = torch.device(f'cuda:{gpu_device}') # Unique only on individual node. torch.cuda.set_device(device) if args["init_method"] == "tcp": if args["nr"] == 0: args["master_addr"] = "0.0.0.0" init_method="tcp://%s:%s" % (args["master_addr"], args["master_port"]) elif args["init_method"] == "file": init_method = 'file://%s/%s' % (args["master_addr"], args["master_port"]) else: raise ValueError rnd = torch.tensor(0.0, device="cpu") if args["world_size"] > 1: dist.init_process_group(args["dist_backend"], rank=rank, world_size=args["world_size"], init_method=init_method) rnd = torch.tensor(int(time.time())).to(device) dist.broadcast(rnd, 0) barrier = get_barrier(args["world_size"] > 1) format = "%Y-%m-%d %H-%M %Z" # + timedelta(hours=5, minutes=30) time_string = (datetime.fromtimestamp(time.mktime(time.gmtime(rnd.cpu().item())))).astimezone(timezone('Asia/Kolkata')).strftime(format) ds_name = list(filter(lambda x: len(x.strip()) > 0, args["dataset"].split("/")))[-1].replace("train_fastformer_resampled_", "") set_seeds(args["seed"]) batch_size = 8 optimizer_config.lr = args["lr"] optimizer_config.weight_decay = args["weight_decay"] optimizer_config.gradient_clipping = args["gradient_clipping"] optimizer_config.beta_1 = args["beta_1"] optimizer_config.beta_2 = args["beta_2"] eps = 1e-4 if args["no_autocast"]: optimizer_config.eps = 1e-7 eps = 1e-7 reinit = args["pretrained_model"] is None or "pretrained_model" not in args or args["pretrained_model"] == "" backbone, tokenizer = get_mtt_backbone(args["model_config"], args["cls_tokens"], args["enable_layer_normalizers"], args["sampling_alpha"], reinit, args["enable_layer_normalizers"], args["enable_layer_normalizers_statistics"], dropout_prob=0.01) teacher_backbone, _ = get_mtt_backbone(args["model_config"], args["cls_tokens"], args["enable_layer_normalizers"], None, reinit, args["enable_layer_normalizers"], args["enable_layer_normalizers_statistics"], dropout_prob=0.0) batch_size = args["batch_size"] if "batch_size" in args and isinstance(args["batch_size"], int) else batch_size generator_w = args["generator_w"] if "generator_w" in args else 0.0 discriminator_w = args["discriminator_w"] if "discriminator_w" in args else 0.0 dino_w = args["dino_w"] if "dino_w" in args else 0.0 sentence_order_prediction_w = args["sentence_order_prediction_w"] if "sentence_order_prediction_w" in args else 0.0 attention_penalty_w = args["attention_penalty_w"] if "attention_penalty_w" in args else 0.0 student = MTTModel(backbone, tokenizer, args["cls_tokens"], generator_w=generator_w, discriminator_w=discriminator_w, dino_w=dino_w, 
sentence_order_prediction_w=sentence_order_prediction_w, attention_penalty_w=attention_penalty_w, lm_layers=args["lm_layers"], electra_layers=args["electra_layers"], lm_layers_total=args["lm_layers_total"], electra_layers_total=args["electra_layers_total"], drop_unused_layers=args["drop_unused_layers"], approximate_unused_layers=args["consecutive_layers"], exclude_layers=args["exclude_layers"], keep_last_layer=args["keep_last_layer"], lm_temperature=args["lm_temperature"]) teacher = MTTModel(teacher_backbone, tokenizer, args["cls_tokens"], generator_w=generator_w, discriminator_w=discriminator_w, dino_w=1.0, sentence_order_prediction_w=sentence_order_prediction_w, attention_penalty_w=0.0, lm_layers=None, electra_layers=None, lm_layers_total=args["lm_layers_total"], electra_layers_total=args["electra_layers_total"], lm_temperature=args["lm_temperature"]) teacher = teacher.eval() model = MultiTaskHighwayCLSPretraining(student, teacher, eps, device if args["move_unused_to_cpu"] else None).to(device) trainable_model = student if dino_w == 0: model.teacher = None teacher = None clean_memory() del teacher if local_rank == 0 and rank == 0: print("[Train]: Time = %s, Trainable Params = %s" % (get_time_string(), numel(trainable_model) / 1_000_000)) if args["pretrained_model"] is not None and os.path.exists(args["pretrained_model"]): state_dict = torch.load(args["pretrained_model"], map_location='cpu' if args['cpu'] else 'cuda:%d' % gpu_device) try: trainable_model.load_state_dict(state_dict, strict=True) load_type = "strict" except: try: try: trainable_model.load_state_dict(state_dict, strict=False) load_type = "not_strict" except: state_dict = {k: v for k, v in state_dict.items() if k.startswith("backbone.")} trainable_model.load_state_dict(state_dict, strict=False) load_type = "not_strict_no_ffn" except: try: try: state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} trainable_model.load_state_dict(state_dict, strict=True) load_type = "strict-from-ddp" except: state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} state_dict = {k: v for k, v in state_dict.items() if not k.startswith("backbone.")} trainable_model.load_state_dict(state_dict, strict=True) load_type = "strict-from-ddp-no-ffn" except: try: state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} trainable_model.load_state_dict(state_dict, strict=False) load_type = "not_strict-from-ddp" except: state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} state_dict = {k: v for k, v in state_dict.items() if not k.startswith("backbone.")} trainable_model.load_state_dict(state_dict, strict=False) load_type = "not_strict-from-ddp-no-ffn" if dino_w > 0: student_teacher_param_update(model.student, model.teacher, 0.001, device if args["move_unused_to_cpu"] else None) print("[Train]: Time = %s, Loaded Pretrained model with Load type = %s, Torch Version = %s" % (get_time_string(), load_type, torch.__version__)) del state_dict model = model.train() # print("[Train]: Time = %s, Trainable Params = %s" % (get_time_string(), {k for k, v in model.named_parameters() if v.requires_grad})) if args["world_size"] > 1: # model = FSDP(model, **fsdp_params) # find_unused_parameters=True trainable_model = DDP(trainable_model, device_ids=None if args["cpu"] else [gpu_device], find_unused_parameters=True, bucket_cap_mb=50) # find_unused_parameters=True model.student = trainable_model if dino_w > 0: model.teacher = model.teacher.eval() student_teacher_param_update(model.student, model.teacher, 
0.01, device if args["move_unused_to_cpu"] else None) try: from torch.distributed.algorithms.ddp_comm_hooks.default_hooks import fp16_compress_hook trainable_model.register_comm_hook(state=None, hook=fp16_compress_hook) except: print("[Train]: Time = %s, No fp16_compress_hook present, Torch Version = %s" % (get_time_string(), torch.__version__)) del backbone del teacher_backbone del student clean_memory() barrier() optc = optimizer_config.to_dict() trainable_params = list(filter(lambda p: p.requires_grad, trainable_model.parameters())) if args["optimizer"] == "adamw": optimizer = torch.optim.AdamW(trainable_params, lr=optc["lr"], eps=optc["eps"], weight_decay=optc["weight_decay"], betas=(optc["beta_1"], optc["beta_2"])) elif args["optimizer"] == "sgd": optimizer = torch.optim.SGD(trainable_params, lr=optc["lr"], momentum=0.9, weight_decay=optc["weight_decay"], nesterov=True) elif args["optimizer"] == "novograd": optimizer = Novograd(trainable_params, lr=optc["lr"], eps=optc["eps"], betas=(optc["beta_1"], optc["beta_2"]), weight_decay=optc["weight_decay"],) elif args["optimizer"] == "rangerlars": optimizer = RangerLars(trainable_params, lr=optc["lr"], eps=optc["eps"], betas=(optc["beta_1"], optc["beta_2"]), weight_decay=optc["weight_decay"],) else: raise ValueError # print("[Train]: Time = %s, Trainable Params = %s" % (get_time_string(), {k for k, v in trainable_model.named_parameters() if v.requires_grad})) del trainable_params optimizer.zero_grad(set_to_none=True) model_save_dir = args["model_save_dir"] model_save_name = args["model_save_name"] set_seeds(args["seed"] + rank) if local_rank == 0: if not os.path.exists(model_save_dir): os.makedirs(model_save_dir) assert os.path.exists(model_save_dir) try: dataloader = build_dataloader(os.path.join(args["dataset"], "all_512_only"), args["shuffle_dataset"], batch_size, tokenizer, args["cls_tokens"], world_size=args["world_size"], num_workers=args["num_workers"], max_length=512) dataloader128 = build_dataloader(os.path.join(args["dataset"], "all_128_only"), args["shuffle_dataset"], batch_size * 6, tokenizer, args["cls_tokens"], world_size=args["world_size"], num_workers=args["num_workers"], max_length=128) dataloader256 = build_dataloader(os.path.join(args["dataset"], "all_256_only"), args["shuffle_dataset"], batch_size * 3, tokenizer, args["cls_tokens"], world_size=args["world_size"], num_workers=args["num_workers"], max_length=256) except: print("[WARN] [Train]: Time = %s, All dataloaders and datasets are same = %s" % (get_time_string(), args["dataset"])) dataloader = build_dataloader(args["dataset"], args["shuffle_dataset"], batch_size, tokenizer, args["cls_tokens"], world_size=args["world_size"], num_workers=args["num_workers"], max_length=512) dataloader128 = build_dataloader(args["dataset"], args["shuffle_dataset"], batch_size * 4, tokenizer, args["cls_tokens"], world_size=args["world_size"], num_workers=args["num_workers"], max_length=128) dataloader256 = build_dataloader(args["dataset"], args["shuffle_dataset"], batch_size * 2, tokenizer, args["cls_tokens"], world_size=args["world_size"], num_workers=args["num_workers"], max_length=256) iter_size = max(args["accumulation_steps"], 1) no_sync = iter_size > 1 steps_per_epoch = int(np.ceil(len(dataloader.sampler) / (batch_size * iter_size)) if dataloader.sampler is not None else (len(dataloader) / iter_size)) if local_rank == 0: print("[Train]: Time = %s, Optimizer and Scheduler Initialised, max lr = %.5f, steps_per_epoch = %s, batch size = %s, dataloader length = %s, Sampler Present = %s, 
Sampler Length = %s" % (get_time_string(), optc["lr"], steps_per_epoch, batch_size, len(dataloader), dataloader.sampler is not None, len(dataloader.sampler) if dataloader.sampler is not None else -1)) dataloader = get_next(dataloader) dataloader128 = get_next(dataloader128) dataloader256 = get_next(dataloader256) log_every_steps = args["log_every_steps"] * iter_size save_every_steps = args["save_every_steps"] # scheduler = optimization.get_constant_schedule_with_warmup(optimizer, optc["warmup_steps"]) # scheduler = optimization.get_linear_schedule_with_warmup(optimizer, optc["warmup_steps"], args["epochs"] * len(dataloader)) div_factor = optc["lr"]/1e-6 scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, optc["lr"], total_steps=args["total_steps"], div_factor=div_factor, three_phase=False, pct_start=0.06, anneal_strategy="linear", cycle_momentum=False) # scheduler1 = optimization.get_constant_schedule_with_warmup(optimizer, optc["warmup_steps"]) # scheduler2 = torch.optim.lr_scheduler.StepLR(optimizer, step_size=(steps_per_epoch * args["epochs"]) // args["lr_steps"], gamma=0.5) # scheduler = [scheduler1, scheduler2] barrier() gradient_clipping = optc["gradient_clipping"] group = "%s-%s-%s-%sN-%s" % (args["wandb_name"], ds_name, args["model_config"], args["nodes"], time_string) wandb_init_args = dict(project="de_lm", name="%s-%s-%s-%s" % (group, args["nr"], rank, local_rank), group=group, id=f"{group}-worker-{nr}-{rank}-{local_rank}", config={"args":args, "optimizer_config": optc}, settings=wandb.Settings(start_method="fork")) time.sleep(random.random()) wandb.init(**wandb_init_args) full_times = [] batch_times = [] model_times = [] model.zero_grad(set_to_none=True) samples_processed = 0 samples_processed_this_log_iter = 0 if args["detect_anomaly"]: torch.autograd.set_detect_anomaly(True) def hook(grad): is_nan_inf = torch.logical_not(torch.isfinite(grad)) if is_nan_inf.any(): # print("[GRAD-HOOK]: Time = %s, Param Name = %s, Detected Nan/Inf" % (get_time_string(), name_of_param)) grad[is_nan_inf] = 0.0 return grad return None if not args["no_autocast"] and args["backward_hook"]: for name, param in model.named_parameters(): param.register_hook(hook) dino_center = None discriminator_dino_center = None total_steps = args["total_steps"] steps_done = 0 step = 0 start_time = time.time() while steps_done < total_steps: random.seed(step) len_proba = random.random() if len_proba < 0.5: batch = dataloader128() elif len_proba < 0.6: batch = dataloader256() else: batch = dataloader() epoch_128 = dataloader128.epoch epoch_256 = dataloader256.epoch epoch_512 = dataloader.epoch # batch = None # if len_proba < 0.9: # batches = [dataloader128() for _ in range(4)] # elif len_proba < 0.97: # batches = [dataloader256() for _ in range(2)] # else: # batch = dataloader() # # if batch is None: # keys = batches[0].keys() # batch = dict() # for k in keys: # elems = [b[k] for b in batches] # if isinstance(elems[0], (list, tuple)): # new_elems = [i for e in elems for i in e] # elif isinstance(elems[0], torch.Tensor): # new_elems = torch.cat(elems, 0) # else: # raise TypeError("Expected List or Tensor") # batch[k] = new_elems key = list(batch.keys())[0] bs_size = list(batch[key].size()) batch = {k: v.to(device, non_blocking=True) if hasattr(v, "to") else v for k, v in batch.items()} gen_batch_time = time.time() - start_time teacher_update_w = np.interp(steps_done, [0, args["teacher_warmup_steps"]], [0.95, 0.999]) inner_model = getattr(trainable_model, "module", trainable_model) if hasattr(inner_model, 
"start_from_proba"): start_from_proba = np.interp(steps_done, [0, args["warmup_steps"], args["warmup_steps"] * 2], [0.0, 0.0, args["start_from_proba"]]) inner_model.start_from_proba = start_from_proba if hasattr(inner_model.backbone.encoder, "sampling_alpha") and args["sampling_alpha"] is not None and args["sampling_alpha"] != 1.0: sampling_alpha = np.interp(steps_done, [0, args["warmup_steps"], args["warmup_steps"] * 2], [1.0, 1.0, args["sampling_alpha"]]) inner_model.backbone.encoder.sampling_alpha = max(sampling_alpha, 0.01) inner_model.sampling_alpha = max(sampling_alpha, 0.01) if args["dino_w"] > 0: dino_w = np.interp(steps_done, [0, args["teacher_warmup_steps"], args["teacher_warmup_steps"] * 2], [0.0, 0.0, args["dino_w"]]) inner_model.dino_w = dino_w lm_temperature = np.interp(steps_done, [0, args["warmup_steps"], args["warmup_steps"] * 2], [args["lm_temperature"], args["lm_temperature"], args["lm_temperature"] + 1.0]) inner_model.lm_temperature = lm_temperature batch_times.append(gen_batch_time) if (steps_done + 1) % save_every_steps == 0 or (args["total_steps"] is not None and (steps_done + 1) >= args["total_steps"]): state_dict = trainable_model.state_dict() if not isinstance(trainable_model, DDP) else trainable_model.module.state_dict() if local_rank == 0: torch.save(state_dict, os.path.join(model_save_dir, model_save_name)) del state_dict clean_memory() barrier() if args["total_steps"] is not None and (steps_done + 1) >= args["total_steps"]: return samples_processed += int(batch[key].size(0)) samples_processed_this_log_iter += int(batch[key].size(0)) inner_args = dict(no_autocast=args["no_autocast"], cpu=args["cpu"]) validation_iter = (step + 1) % log_every_steps == 0 or step == 0 model_start = time.time() if no_sync and (step + 1) % iter_size != 0 and hasattr(trainable_model, "no_sync"): with trainable_model.no_sync(): output = train_inner_loop(inner_args, model, batch, optimizer, scheduler, gradient_clipping, iter_size=iter_size, no_sync=True, validation_iter=validation_iter, dino_center=dino_center, discriminator_dino_center=discriminator_dino_center, freeze_last_layer=steps_done < args["freeze_last_layer"], step=step + 1) model_times.append(time.time() - model_start) else: output = train_inner_loop(inner_args, model, batch, optimizer, scheduler, gradient_clipping, iter_size=iter_size, no_sync=False, validation_iter=validation_iter, dino_center=dino_center, discriminator_dino_center=discriminator_dino_center, freeze_last_layer=steps_done < args["freeze_last_layer"], step=step + 1) optimizer.zero_grad(set_to_none=True) steps_done += 1 model_times.append(time.time() - model_start) step += 1 del batch if dino_w > 0 and (step + 1) % iter_size: student_teacher_param_update(model.student, model.teacher, teacher_update_w, device if args["move_unused_to_cpu"] else None) dino_center = output.pop("dino_center", None) discriminator_dino_center = output.pop("discriminator_dino_center", None) # if dino_w > 0 and (step + 1) % (1 * iter_size) == 0 and args["world_size"] > 1: # if dino_center is not None: # dtype = dino_center.dtype # dino_center = dino_center.type(torch.float64) / args["world_size"] # torch.distributed.all_reduce(dino_center, torch.distributed.ReduceOp.SUM) # dino_center = dino_center.type(dtype) # if discriminator_dino_center is not None: # dtype = discriminator_dino_center.dtype # discriminator_dino_center = discriminator_dino_center.type(torch.float64) / args["world_size"] # torch.distributed.all_reduce(discriminator_dino_center, torch.distributed.ReduceOp.SUM) # 
discriminator_dino_center = discriminator_dino_center.type(dtype) if (step + 1) % (4 * iter_size) == 0 and hasattr(getattr(trainable_model, "module", trainable_model).backbone, "layer_normalizers") and args["world_size"] > 1: layer_normalizers = getattr(trainable_model, "module", trainable_model).backbone.layer_normalizers if layer_normalizers is not None: dtype = layer_normalizers.dtype layer_normalizers = layer_normalizers.type(torch.float64) torch.distributed.all_reduce(layer_normalizers, torch.distributed.ReduceOp.SUM) layer_normalizers = layer_normalizers / args["world_size"] getattr(trainable_model, "module", trainable_model).backbone.layer_normalizers = layer_normalizers.type(dtype) if (step + 1) % (4 * iter_size) == 0 and hasattr(getattr(trainable_model, "module", trainable_model).backbone, "layer_normalizers_small") and args["world_size"] > 1: layer_normalizers_small = getattr(trainable_model, "module", trainable_model).backbone.layer_normalizers_small if layer_normalizers_small is not None: dtype = layer_normalizers_small.dtype layer_normalizers_small = layer_normalizers_small.type(torch.float64) torch.distributed.all_reduce(layer_normalizers_small, torch.distributed.ReduceOp.SUM) layer_normalizers_small = layer_normalizers_small / args["world_size"] getattr(trainable_model, "module", trainable_model).backbone.layer_normalizers_small = layer_normalizers_small.type(dtype) full_time = time.time() - start_time full_times.append(full_time) if step == 0 and local_rank == 0: print("[Train]: Time = %s, First Batch Training for Rank = %s" % (get_time_string(), rank)) if validation_iter: steps_remaining = total_steps - steps_done # print({k for k, v in output.items() if isinstance(v, torch.Tensor)}) output = {k: float(v) if v else v for k, v in output.items()} samples_per_second = samples_processed_this_log_iter / np.sum(full_times) wandb_log = dict(lr=optimizer.param_groups[0]['lr'], step=step, samples_processed=samples_processed, samples_per_second=samples_per_second, batch_times=np.mean(batch_times), full_times=np.mean(full_times), model_times=np.mean(model_times), steps_remaining=steps_remaining, pct_complete=(100 * steps_done / total_steps), epoch_128=epoch_128, epoch_256=epoch_256, epoch_512=epoch_512, **{k: v for k, v in output.items() if v is not None}) wandb.log(wandb_log) if local_rank == 0: print("[Train]: Time = %s, Rank = %s, steps = %s, samples_processed=%s, batch_size = %s, Details = %s, LR = %s" % (get_time_string(), rank, step, samples_processed, bs_size, output, optimizer.param_groups[0]['lr'])) print("[Train-Timings]: Time = %s, Batch time = %.4f, Full Time = %.4f, Model Time = %.4f, samples_per_second = %s, steps_remaining = %s, pct_complete = %.4f" % ( get_time_string(), np.mean(batch_times), np.mean(full_times), np.mean(model_times), samples_per_second, steps_remaining, (100 * steps_done / total_steps),)) # print("Step = %s, Steps Done = %s, log_every_steps = %s, total_steps = %s, steps_remaining = %s, validation_iter = %s, %s" % (step, steps_done, log_every_steps, total_steps, steps_remaining, validation_iter, (step + 1) % log_every_steps == 0)) batch_times = [] full_times = [] model_times = [] samples_processed_this_log_iter = 0 if args["enable_layer_normalizers_statistics"] and local_rank == 0: backbone = getattr(model.student, "module", model.student).backbone stats = backbone.layer_normalizers inp_stats=backbone.encoder.layer_normalizers_statistics norms = stats[:, 2, 0].tolist() inp_norms = inp_stats[:, 2, 0].tolist() centers = stats[:, 0, 0:8].tolist() 
inp_centers = inp_stats[:, 0, 0:8].tolist() stds = stats[:, 1, 0:8].tolist() inp_stds = inp_stats[:, 1, 0:8].tolist() dist_stats = backbone.encoder.distance_statistics.tolist() print("Branch Norms = \n", tabulate(pd.DataFrame(norms), tablefmt="psql")) print("Skip Norms = \n", tabulate(pd.DataFrame({"norm": inp_norms, "dist": dist_stats}), tablefmt="psql")) print("Branch centers = \n", tabulate(pd.DataFrame(centers), tablefmt="psql")) print("Skip centers = \n", tabulate(pd.DataFrame(inp_centers), tablefmt="psql")) print("Branch stds = \n", tabulate(pd.DataFrame(stds), tablefmt="psql")) print("Skip stds = \n", tabulate(pd.DataFrame(inp_stds), tablefmt="psql")) # clean_memory() # barrier() del output del bs_size start_time = time.time() print("Time = %s, Finished Training for Rank = %s" % (get_time_string(), rank)) state_dict = trainable_model.state_dict() if not isinstance(trainable_model, DDP) else trainable_model.module.state_dict() if local_rank == 0: torch.save(state_dict, os.path.join(model_save_dir, model_save_name)) del model
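# student_teacher_param_update() is not defined in this excerpt. Given how it is
# called (with a weight that warms up from 0.95 to 0.999 over the teacher warmup
# steps), a standard mean-teacher/DINO-style exponential moving average is assumed
# here; this is a hedged sketch, not the project's implementation:
import torch


@torch.no_grad()
def ema_update(student, teacher, w):
    """teacher <- w * teacher + (1 - w) * student, parameter by parameter."""
    for p_t, p_s in zip(teacher.parameters(), student.parameters()):
        p_t.mul_(w).add_(p_s.detach(), alpha=1.0 - w)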
def main(config): train_loader, val_loader = get_loader(config) n_data = len(train_loader.dataset) logger.info(f"length of training dataset: {n_data}") n_data = len(val_loader.dataset) logger.info(f"length of validation dataset: {n_data}") model, criterion = build_scene_segmentation(config) model.cuda() criterion.cuda() if config.optimizer == 'sgd': optimizer = torch.optim.SGD(model.parameters(), lr=config.batch_size * dist.get_world_size() / 8 * config.base_learning_rate, momentum=config.momentum, weight_decay=config.weight_decay) elif config.optimizer == 'adam': optimizer = torch.optim.Adam(model.parameters(), lr=config.batch_size * dist.get_world_size() / 8 * config.base_learning_rate, weight_decay=config.weight_decay) elif config.optimizer == 'adamW': optimizer = torch.optim.AdamW(model.parameters(), lr=config.batch_size * dist.get_world_size() / 8 * config.base_learning_rate, weight_decay=config.weight_decay) else: raise NotImplementedError( f"Optimizer {config.optimizer} not supported") scheduler = get_scheduler(optimizer, len(train_loader), config) model = DistributedDataParallel(model, device_ids=[config.local_rank], broadcast_buffers=False) runing_vote_logits = [ np.zeros((config.num_classes, l.shape[0]), dtype=np.float32) for l in val_loader.dataset.sub_clouds_points_labels ] # optionally resume from a checkpoint if config.load_path: assert os.path.isfile(config.load_path) load_checkpoint(config, model, optimizer, scheduler) logger.info("==> checking loaded ckpt") validate('resume', val_loader, model, criterion, runing_vote_logits, config, num_votes=2) # tensorboard if dist.get_rank() == 0: summary_writer = SummaryWriter(log_dir=config.log_dir) else: summary_writer = None # routine for epoch in range(config.start_epoch, config.epochs + 1): train_loader.sampler.set_epoch(epoch) val_loader.sampler.set_epoch(epoch) train_loader.dataset.epoch = epoch - 1 tic = time.time() loss = train(epoch, train_loader, model, criterion, optimizer, scheduler, config) logger.info('epoch {}, total time {:.2f}, lr {:.5f}'.format( epoch, (time.time() - tic), optimizer.param_groups[0]['lr'])) if epoch % config.val_freq == 0: validate(epoch, val_loader, model, criterion, runing_vote_logits, config, num_votes=2) if dist.get_rank() == 0: # save model save_checkpoint(config, epoch, model, optimizer, scheduler) if summary_writer is not None: # tensorboard logger summary_writer.add_scalar('ins_loss', loss, epoch) summary_writer.add_scalar('learning_rate', optimizer.param_groups[0]['lr'], epoch) validate('Last', val_loader, model, criterion, runing_vote_logits, config, num_votes=20)
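# The optimizer learning rate above follows the linear scaling rule: the base rate
# is defined for a reference global batch of 8 samples and scaled by the actual
# global batch (per-GPU batch size times world size). The same arithmetic in
# isolation (function name is illustrative):
def scaled_lr(base_lr, batch_size, world_size, reference_batch=8):
    return batch_size * world_size / reference_batch * base_lr


assert scaled_lr(0.01, 16, 4) == 0.08  # 16 per GPU on 4 GPUs -> 8x the base rate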
def run(max_epochs=None, device=None, batch_size=24, max_sequence_length=128, random_sequence_length=False, epoch_size=None, seed=None, data_dir='data', real_dataset='webtext', fake_dataset='xl-1542M-nucleus', token_dropout=None, large=False, learning_rate=2e-5, weight_decay=0, **kwargs): args = locals() rank, world_size = setup_distributed() if device is None: device = f'cuda:{rank}' if torch.cuda.is_available() else 'cpu' print('rank:', rank, 'world_size:', world_size, 'device:', device) import torch.distributed as dist if distributed() and rank > 0: dist.barrier() model_name = 'roberta-large' if large else 'roberta-base' tokenization_utils.logger.setLevel('ERROR') tokenizer = RobertaTokenizer.from_pretrained(model_name) model = RobertaForSequenceClassification.from_pretrained(model_name).to( device) if rank == 0: summary(model) if distributed(): dist.barrier() if world_size > 1: model = DistributedDataParallel(model, [rank], output_device=rank, find_unused_parameters=True) train_loader, validation_loader = load_datasets( data_dir, real_dataset, fake_dataset, tokenizer, batch_size, max_sequence_length, random_sequence_length, epoch_size, token_dropout, seed) optimizer = Adam(model.parameters(), lr=learning_rate, weight_decay=weight_decay) epoch_loop = count(1) if max_epochs is None else range(1, max_epochs + 1) logdir = os.environ.get("OPENAI_LOGDIR", "logs") os.makedirs(logdir, exist_ok=True) from torch.utils.tensorboard import SummaryWriter writer = SummaryWriter(logdir) if rank == 0 else None best_validation_accuracy = 0 for epoch in epoch_loop: if world_size > 1: train_loader.sampler.set_epoch(epoch) validation_loader.sampler.set_epoch(epoch) train_metrics = train(model, optimizer, device, train_loader, f'Epoch {epoch}') validation_metrics = validate(model, device, validation_loader) combined_metrics = _all_reduce_dict( { **validation_metrics, **train_metrics }, device) combined_metrics["train/accuracy"] /= combined_metrics[ "train/epoch_size"] combined_metrics["train/loss"] /= combined_metrics["train/epoch_size"] combined_metrics["validation/accuracy"] /= combined_metrics[ "validation/epoch_size"] combined_metrics["validation/loss"] /= combined_metrics[ "validation/epoch_size"] if rank == 0: for key, value in combined_metrics.items(): writer.add_scalar(key, value, global_step=epoch) if combined_metrics[ "validation/accuracy"] > best_validation_accuracy: best_validation_accuracy = combined_metrics[ "validation/accuracy"] model_to_save = model.module if hasattr(model, 'module') else model torch.save( dict(epoch=epoch, model_state_dict=model_to_save.state_dict(), optimizer_state_dict=optimizer.state_dict(), args=args), os.path.join(logdir, "best-model.pt"))
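# _all_reduce_dict() is not shown in this excerpt. The way its result is used
# (per-rank sums divided by the global epoch size afterwards) implies it sums each
# scalar across ranks; a plausible implementation under that assumption:
import torch
import torch.distributed as dist


def all_reduce_dict(metrics, device):
    keys = sorted(metrics.keys())
    values = torch.tensor([metrics[k] for k in keys], dtype=torch.float64, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(values, op=dist.ReduceOp.SUM)
    return dict(zip(keys, values.tolist()))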
def __init__(self, model, regime=None, criterion=None, label_smoothing=0, target_forcing='teacher', print_freq=10, eval_freq=1000, save_freq=1000, grad_clip=None, embedding_grad_clip=None, max_tokens=None, chunk_batch=1, duplicates=1, save_info={}, save_path='.', checkpoint_filename='checkpoint%s.pth', keep_checkpoints=5, avg_loss_time=True, distributed=False, local_rank=0, dtype=torch.float, loss_scale=1, device_ids=None, device="cuda"): super(Seq2SeqTrainer, self).__init__() self.model = model self.criterion = criterion or CrossEntropyLoss( ignore_index=PAD, smooth_eps=label_smoothing, reduction='sum', from_logits=False) self.optimizer = OptimRegime( self.model, regime=regime, use_float_copy=dtype == torch.float16) if target_forcing == 'teacher': self.target_forcing = TeacherForcing(batch_first=self.batch_first) else: self.target_forcing = DecodedInputTargets() self.grad_clip = grad_clip self.embedding_grad_clip = embedding_grad_clip self.epoch = 0 self.training_steps = 0 self.save_info = save_info self.device = device self.dtype = dtype self.loss_scale = loss_scale self.max_tokens = max_tokens self.chunk_batch = chunk_batch self.duplicates = duplicates self.print_freq = print_freq self.eval_freq = eval_freq self.perplexity = float('inf') self.device_ids = device_ids self.avg_loss_time = avg_loss_time self.model_with_loss = AddLossModule(self.model, self.criterion) self.distributed = distributed self.local_rank = local_rank if distributed: self.model_with_loss = DistributedDataParallel( self.model_with_loss, device_ids=[local_rank], output_device=local_rank) else: if isinstance(self.device_ids, tuple): self.model_with_loss = DataParallel(self.model_with_loss, self.device_ids, dim=0 if self.batch_first else 1) self.save_path = save_path self.save_freq = save_freq self.checkpoint_filename = checkpoint_filename self.keep_checkpoints = keep_checkpoints + 1 results_file = os.path.join(save_path, 'results') self.results = ResultsLog(results_file, params=save_info.get('config', None)) self.watcher = None self.streams = {}
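# The CrossEntropyLoss used above is a project-local criterion with a smooth_eps
# argument. Generic label smoothing blends the one-hot target with a uniform
# distribution over the vocabulary; a hedged sketch of that idea (not necessarily
# the project's exact implementation):
import torch.nn.functional as F


def smoothed_cross_entropy(logits, target, smooth_eps, ignore_index):
    logp = F.log_softmax(logits, dim=-1)
    nll = F.nll_loss(logp, target, ignore_index=ignore_index, reduction='sum')
    keep = target != ignore_index
    uniform = -logp[keep].sum() / logits.size(-1)  # expected NLL under a uniform target
    return (1.0 - smooth_eps) * nll + smooth_eps * uniform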
def __init__(self, opt): super(VideoAttentionModel, self).__init__(opt) if opt['dist']: self.rank = torch.distributed.get_rank() else: self.rank = -1 # non dist training # define network and load pretrained models self.netG = networks.define_G(opt).to(self.device) if opt['dist']: self.netG = DistributedDataParallel(self.netG, device_ids=[torch.cuda.current_device()]) else: self.netG = DataParallel(self.netG) # print network self.print_network() self.load() if self.is_train: train_opt = opt['train'] self.patch_size = opt['datasets']['train']['patch_size'] self.patch_repeat = opt['datasets']['train']['patch_repeat'] self.use_diff = opt['datasets']['train']['use_diff'] self.netG.train() #### loss loss_type = train_opt['pixel_criterion'] if loss_type == 'l1': self.cri_pix = nn.L1Loss(reduction='sum').to(self.device) elif loss_type == 'l2': self.cri_pix = nn.MSELoss(reduction='sum').to(self.device) elif loss_type == 'bce': self.cri_pix = nn.BCEWithLogitsLoss(reduction='sum').to(self.device) elif loss_type == 'cb': self.cri_pix = CharbonnierLoss().to(self.device) else: raise NotImplementedError('Loss type [{:s}] is not recognized.'.format(loss_type)) self.l_pix_w = train_opt['pixel_weight'] #### optimizers wd_G = train_opt['weight_decay_G'] if train_opt['weight_decay_G'] else 0 if train_opt['ft_tsa_only']: normal_params = [] tsa_fusion_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: if 'tsa_fusion' in k: tsa_fusion_params.append(v) else: normal_params.append(v) else: if self.rank <= 0: logger.warning('Params [{:s}] will not optimize.'.format(k)) optim_params = [ { # add normal params first 'params': normal_params, 'lr': train_opt['lr_G'] }, { 'params': tsa_fusion_params, 'lr': train_opt['lr_G'] }, ] else: optim_params = [] for k, v in self.netG.named_parameters(): if v.requires_grad: optim_params.append(v) else: if self.rank <= 0: logger.warning('Params [{:s}] will not optimize.'.format(k)) self.optimizer_G = torch.optim.Adam(optim_params, lr=train_opt['lr_G'], weight_decay=wd_G, betas=(train_opt['beta1'], train_opt['beta2'])) self.optimizers.append(self.optimizer_G) #### schedulers if train_opt['lr_scheme'] == 'MultiStepLR': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.MultiStepLR_Restart(optimizer, train_opt['lr_steps'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'], gamma=train_opt['lr_gamma'], clear_state=train_opt['clear_state'])) elif train_opt['lr_scheme'] == 'CosineAnnealingLR_Restart': for optimizer in self.optimizers: self.schedulers.append( lr_scheduler.CosineAnnealingLR_Restart( optimizer, train_opt['T_period'], eta_min=train_opt['eta_min'], restarts=train_opt['restarts'], weights=train_opt['restart_weights'])) else: raise NotImplementedError() self.log_dict = OrderedDict()
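# The ft_tsa_only branch above partitions parameters by a substring of their names
# so the TSA-fusion module can live in its own optimizer group. The same pattern in
# a compact, standalone form (the function name is illustrative):
def split_param_groups(model, key, lr):
    matched, rest = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        (matched if key in name else rest).append(param)
    return [{'params': rest, 'lr': lr}, {'params': matched, 'lr': lr}]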
def main(rank):
    env = gym.make("CartPole-v0")
    observe_dim = 4
    action_num = 2
    max_episodes = 2000
    max_steps = 200
    solved_reward = 190
    solved_repeat = 5

    # initialize the distributed world first
    world = World(world_size=4, rank=rank, name=str(rank), rpc_timeout=20)

    servers = model_server_helper(model_num=1)
    apex_group = world.create_rpc_group("apex", ["0", "1", "2", "3"])

    if rank in (2, 3):
        # learner_group.group is the wrapped torch.distributed.ProcessGroup
        learner_group = world.create_collective_group(ranks=[2, 3])

        # wrap the model with DistributedDataParallel
        # if current process is learner process 2 or 3
        q_net = DistributedDataParallel(module=QNet(observe_dim, action_num),
                                        process_group=learner_group.group)
        q_net_t = DistributedDataParallel(module=QNet(observe_dim, action_num),
                                          process_group=learner_group.group)
    else:
        q_net = QNet(observe_dim, action_num)
        q_net_t = QNet(observe_dim, action_num)

    # we may use a smaller batch size to train if we are using
    # DistributedDataParallel
    dqn_apex = DQNApex(q_net, q_net_t, t.optim.Adam,
                       nn.MSELoss(reduction='sum'),
                       apex_group, servers,
                       batch_size=50)

    # synchronize all processes in the group, make sure the
    # distributed buffer has been created on all processes in apex_group
    apex_group.barrier()

    # manually control syncing to improve performance
    dqn_apex.set_sync(False)
    if rank in (0, 1):
        # processes 0 and 1 are workers (samplers)
        # begin training
        episode, step, reward_fulfilled = 0, 0, 0
        smoothed_total_reward = 0

        while episode < max_episodes:
            # sleep to wait for learners to keep up
            sleep(0.1)
            episode += 1

            total_reward = 0
            terminal = False
            step = 0

            state = t.tensor(env.reset(), dtype=t.float32).view(1, observe_dim)
            # manually pull the newest parameters
            dqn_apex.manual_sync()
            while not terminal and step <= max_steps:
                step += 1
                with t.no_grad():
                    old_state = state
                    # agent model inference
                    action = dqn_apex.act_discrete_with_noise({"state": old_state})
                    state, reward, terminal, _ = env.step(action.item())
                    state = t.tensor(state, dtype=t.float32).view(1, observe_dim)
                    total_reward += reward

                    dqn_apex.store_transition({
                        "state": {"state": old_state},
                        "action": {"action": action},
                        "next_state": {"state": state},
                        "reward": reward,
                        "terminal": terminal or step == max_steps
                    })

            smoothed_total_reward = (smoothed_total_reward * 0.9 + total_reward * 0.1)
            logger.info("Process {} Episode {} total reward={:.2f}".format(
                rank, episode, smoothed_total_reward))

            if smoothed_total_reward > solved_reward:
                reward_fulfilled += 1
                if reward_fulfilled >= solved_repeat:
                    logger.info("Environment solved!")

                    # will cause torch RPC to complain,
                    # since other processes may have not finished yet.
                    # just for demonstration.
                    exit(0)
            else:
                reward_fulfilled = 0

    elif rank in (2, 3):
        # wait for enough samples
        while dqn_apex.replay_buffer.all_size() < 500:
            sleep(0.1)
        while True:
            dqn_apex.update()
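# The sampler processes above judge progress with an exponentially smoothed episode
# reward (0.9 carry-over, 0.1 new observation). The recurrence on its own:
def smooth(prev, new, alpha=0.1):
    return prev * (1.0 - alpha) + new * alpha


r = 0.0
for total_reward in [10.0, 200.0, 150.0]:
    r = smooth(r, total_reward)
print(r)  # ~33.81: recent episodes dominate, older ones decay geometrically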
def __init__(
    self,
    model: Model,
    optimizer: torch.optim.Optimizer,
    data_loader: torch.utils.data.DataLoader,
    patience: Optional[int] = None,
    validation_metric: str = "-loss",
    validation_data_loader: torch.utils.data.DataLoader = None,
    num_epochs: int = 20,
    serialization_dir: Optional[str] = None,
    checkpointer: Checkpointer = None,
    model_save_interval: float = None,
    cuda_device: int = -1,
    grad_norm: Optional[float] = None,
    grad_clipping: Optional[float] = None,
    learning_rate_scheduler: Optional[LearningRateScheduler] = None,
    momentum_scheduler: Optional[MomentumScheduler] = None,
    tensorboard_writer: TensorboardWriter = None,
    log_batch_size_period: Optional[int] = None,
    moving_average: Optional[MovingAverage] = None,
    distributed: bool = False,
    local_rank: int = 0,
    world_size: int = 1,
    num_gradient_accumulation_steps: int = 1,
    opt_level: Optional[str] = None,
) -> None:
    """
    A trainer for doing supervised learning. It just takes a labeled dataset and a
    `DataLoader`, and uses the supplied `Optimizer` to learn the weights for your model
    over some fixed number of epochs. You can also pass in a validation data loader
    and enable early stopping. There are many other bells and whistles as well.

    # Parameters

    model : `Model`, required.
        An AllenNLP model to be optimized. Pytorch Modules can also be optimized if
        their `forward` method returns a dictionary with a "loss" key, containing a
        scalar tensor representing the loss function to be optimized.

        If you are training your model using GPUs, your model should already be on
        the correct device. (If you use `Trainer.from_params` this will be handled
        for you.)
    optimizer : `torch.optim.Optimizer`, required.
        An instance of a Pytorch Optimizer, instantiated with the parameters of the
        model to be optimized.
    data_loader : `DataLoader`, required.
        A pytorch `DataLoader` containing your `Dataset`, yielding padded indexed
        batches.
    patience : `Optional[int] > 0`, optional (default=None)
        Number of epochs to be patient before early stopping: the training is stopped
        after `patience` epochs with no improvement. If given, it must be `> 0`.
        If None, early stopping is disabled.
    validation_metric : `str`, optional (default="-loss")
        Validation metric to measure for whether to stop training using patience
        and whether to serialize an `is_best` model each epoch. The metric name
        must be prepended with either "+" or "-", which specifies whether the metric
        is an increasing or decreasing function.
    validation_data_loader : `DataLoader`, optional (default=None)
        A `DataLoader` to use for the validation set. If `None`, then use the
        training `DataLoader` with the validation data.
    num_epochs : `int`, optional (default = 20)
        Number of training epochs.
    serialization_dir : `str`, optional (default=None)
        Path to directory for saving and loading model files. Models will not be
        saved if this parameter is not passed.
    checkpointer : `Checkpointer`, optional (default=None)
        A `Checkpointer` is responsible for periodically saving model weights.
        If none is given here, we will construct one with default parameters.
    model_save_interval : `float`, optional (default=None)
        If provided, then serialize models every `model_save_interval` seconds within
        single epochs. In all cases, models are also saved at the end of every epoch
        if `serialization_dir` is provided.
    cuda_device : `int`, optional (default = -1)
        An integer specifying the CUDA device(s) to use for this process. If -1, the
        CPU is used. Data parallelism is controlled at the allennlp train level, so
        each trainer will have a single GPU.
    grad_norm : `float`, optional, (default = None).
        If provided, gradient norms will be rescaled to have a maximum of this value.
    grad_clipping : `float`, optional (default = `None`).
        If provided, gradients will be clipped `during the backward pass` to have an
        (absolute) maximum of this value. If you are getting `NaNs` in your gradients
        during training that are not solved by using `grad_norm`, you may need this.
    learning_rate_scheduler : `LearningRateScheduler`, optional (default = None)
        If specified, the learning rate will be decayed with respect to this schedule
        at the end of each epoch (or batch, if the scheduler implements the
        `step_batch` method). If you use `torch.optim.lr_scheduler.ReduceLROnPlateau`,
        this will use the `validation_metric` provided to determine if learning has
        plateaued. To support updating the learning rate on every batch, this can
        optionally implement `step_batch(batch_num_total)` which updates the learning
        rate given the batch number.
    momentum_scheduler : `MomentumScheduler`, optional (default = None)
        If specified, the momentum will be updated at the end of each batch or epoch
        according to the schedule.
    tensorboard_writer : `TensorboardWriter`, optional
        If this is not provided, we will construct a `TensorboardWriter` with default
        parameters and use that.
    log_batch_size_period : `int`, optional, (default = `None`)
        If defined, how often to log the average batch size.
    moving_average : `MovingAverage`, optional, (default = None)
        If provided, we will maintain moving averages for all parameters. During
        training, we employ a shadow variable for each parameter, which maintains
        the moving average. During evaluation, we back up the original parameters and
        assign the moving averages to the corresponding parameters. Be careful that
        when saving the checkpoint, we will save the moving averages of parameters.
        This is necessary because we want the saved model to perform as well as the
        validated model if we load it later. But this may cause problems if you
        restart training from a checkpoint.
    distributed : `bool`, optional, (default = False)
        If set, PyTorch's `DistributedDataParallel` is used to train the model on
        multiple GPUs. This also requires `world_size` to be greater than 1.
    local_rank : `int`, optional, (default = 0)
        This is the unique identifier of the `Trainer` in a distributed process group.
        The GPU device id is used as the rank.
    world_size : `int`, (default = 1)
        The number of `Trainer` workers participating in the distributed training.
    num_gradient_accumulation_steps : `int`, optional, (default = 1)
        Gradients are accumulated for the given number of steps before doing an
        optimizer step. This can be useful to accommodate batches that are larger
        than the RAM size. Refer to Thomas Wolf's
        [post](https://tinyurl.com/y5mv44fw) for details on gradient accumulation.
    opt_level : `str`, optional, (default = `None`)
        Each opt_level establishes a set of properties that govern Amp's
        implementation of pure or mixed precision training. Must be a choice of
        `"O0"`, `"O1"`, `"O2"`, or `"O3"`. See the Apex
        [documentation](https://nvidia.github.io/apex/amp.html#opt-levels-and-properties)
        for more details. If `None`, Amp is not used. Defaults to `None`.
    """
    super().__init__(serialization_dir, cuda_device, distributed, local_rank, world_size)

    # I am not calling move_to_gpu here, because if the model is
    # not already on the GPU then the optimizer is going to be wrong.
    self.model = model

    self.data_loader = data_loader
    self._validation_data_loader = validation_data_loader
    self.optimizer = optimizer

    if patience is None:  # no early stopping
        if validation_data_loader:
            logger.warning(
                "You provided a validation dataset but patience was set to None, "
                "meaning that early stopping is disabled"
            )
    elif (not isinstance(patience, int)) or patience <= 0:
        raise ConfigurationError(
            '{} is an invalid value for "patience": it must be a positive integer '
            "or None (if you want to disable early stopping)".format(patience)
        )

    # For tracking is_best_so_far and should_stop_early
    self._metric_tracker = MetricTracker(patience, validation_metric)
    # Get rid of + or -
    self._validation_metric = validation_metric[1:]

    self._num_epochs = num_epochs

    if checkpointer is not None:
        self._checkpointer = checkpointer
    else:
        self._checkpointer = Checkpointer(serialization_dir)

    self._model_save_interval = model_save_interval

    self._grad_norm = grad_norm
    self._grad_clipping = grad_clipping

    self._learning_rate_scheduler = learning_rate_scheduler
    self._momentum_scheduler = momentum_scheduler
    self._moving_average = moving_average

    # We keep the total batch number as an instance variable because it
    # is used inside a closure for the hook which logs activations in
    # `_enable_activation_logging`.
    self._batch_num_total = 0

    self._tensorboard = tensorboard_writer or TensorboardWriter(serialization_dir)
    self._tensorboard.get_batch_num_total = lambda: self._batch_num_total
    self._tensorboard.enable_activation_logging(self.model)

    self._log_batch_size_period = log_batch_size_period

    self._last_log = 0.0  # time of last logging

    self._num_gradient_accumulation_steps = num_gradient_accumulation_steps

    # Enable automatic mixed precision training with NVIDIA Apex.
    self._opt_level = opt_level
    if self._opt_level is not None:
        if amp is None:
            raise ConfigurationError(
                "Apex not installed but opt_level was provided. Please install NVIDIA's Apex"
                " to enable automatic mixed precision (AMP) training."
                " See: https://github.com/NVIDIA/apex."
            )
        self.model, self.optimizer = amp.initialize(
            self.model, self.optimizer, opt_level=self._opt_level
        )

    # Using `DistributedDataParallel`(ddp) brings in a quirk wrt AllenNLP's `Model` interface and its
    # usage. A `Model` object is wrapped by `ddp`, but assigning the wrapped model to `self.model`
    # will break usages such as `Model.get_regularization_penalty`, `Model.get_metrics`, etc.
    #
    # Hence a reference to Pytorch's object is maintained in the case of distributed training and in the
    # normal case, the reference to `Model` is retained. This reference is only used in
    # these places: `model.__call__`, `model.train` and `model.eval`.
    if self._distributed:
        self._pytorch_model = DistributedDataParallel(
            self.model, device_ids=[self.cuda_device], find_unused_parameters=True
        )
    else:
        self._pytorch_model = self.model
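# ---------------------------------------------------------------------------
# Usage sketch for the trainer above (illustrative, not from the source): the
# docstring notes that a plain PyTorch Module whose forward() returns a dict
# with a "loss" key can be trained as well. This assumes the enclosing class
# is importable as `Trainer`; the toy model and data are hypothetical.
# ---------------------------------------------------------------------------
def _example_trainer_usage():
    import torch
    from torch.utils.data import DataLoader

    class ToyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 1)

        def forward(self, x, y):
            loss = torch.nn.functional.mse_loss(self.linear(x).squeeze(-1), y)
            return {"loss": loss}

    model = ToyModel()
    # a list of dicts collates into a dict of batched tensors, matching the
    # keyword-argument batches the trainer feeds to forward()
    data = [{"x": torch.randn(4), "y": torch.randn(())} for _ in range(32)]
    trainer = Trainer(
        model=model,
        optimizer=torch.optim.Adam(model.parameters()),
        data_loader=DataLoader(data, batch_size=8),
        num_epochs=5,
        validation_metric="-loss",  # "-" marks a metric where lower is better
    )
    return trainer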
def train(lm_dataloader, model, criterion, optimizer, vocab_size, args):
    model.train()
    from functools import reduce
    import operator

    num_params = reduce(operator.add,
                        (reduce(operator.mul, x.size()) for x in model.parameters()))
    if model.group:
        total = torch.Tensor([num_params])
        if torch.cuda.is_available():
            total = total.cuda()
        torch.distributed.all_reduce(total, group=model.group)
        logging.info(
            f"training model, #params = {num_params}, group: {model.group.rank()}, grank:"
            f" {torch.distributed.get_rank()}, sizes {model.group.size()}")
        torch.distributed.barrier()
        if model.group.rank() == 0:
            logging.info(f"total #params = {total.item()}")
    else:
        logging.info(f"training model, #params = {num_params}")

    vocab_size = 10000  # FIXME
    total_loss = 0.0
    start_time = time.time()
    word_counter = 0

    optimizer = optimizer(model)

    def get_first_device(model):
        if isinstance(model, DDP):
            model = model.module
        if not torch.cuda.is_available():
            return torch.device("cpu")
        if model.devices:
            return model.devices[0]
        else:
            return torch.cuda.current_device()

    def get_last_device(model):
        if isinstance(model, DDP):
            model = model.module
        if not torch.cuda.is_available():
            return torch.device("cpu")
        if model.devices:
            return model.devices[-1]
        else:
            return torch.cuda.current_device()

    pipe_group = model.group

    if args.ddp_zero:
        model = DDP(
            model,
            device_ids=[torch.cuda.current_device()],
            process_group=get_data_parallel_group(),
            find_unused_parameters=False,
        )

    # Ranks in the middle of the pipeline neither consume real inputs nor
    # produce the final loss, so feed them dummy batches of the right length.
    if pipe_group and pipe_group.rank() != 0 and pipe_group.rank() != (pipe_group.size() - 1):
        thing = {"input": torch.zeros(args.batch_size)}

        class FakeDataset:
            def __getitem__(self, index):
                return thing

            def __len__(self):
                return len(lm_dataloader)

        lm_dataloader = FakeDataset()

    for i, batch in enumerate(lm_dataloader):
        bi = batch["input"]
        if args.max_batch and i > args.max_batch:
            break
        optimizer.zero_grad()
        try:
            if (pipe_group is None or pipe_group.rank() == 0) and not args.ddp_zero:
                tmp = batch["input"].to(get_first_device(model))
                output = model(tmp)
            else:
                output = model(batch["input"])
        except Exception as e:
            raise RuntimeError(f"training failed on {torch.distributed.get_rank()}") from e

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            target = batch["target"].to(get_last_device(model))
            output = output.to(target.device)

            loss = criterion(output.view(-1, vocab_size), target.view(-1))
            if args.ddp_zero:
                ddp_group = get_data_parallel_group()
                torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM, group=ddp_group)
                loss /= ddp_group.size()
            loss.backward()
            del target
        else:
            if args.ddp_zero:
                model.module.back_helper(output)
            else:
                model.back_helper(output)

        del output

        torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
        optimizer.step()

        if pipe_group is None or pipe_group.rank() == pipe_group.size() - 1:
            total_loss += loss.item()
            log_interval = 1
            word_counter += batch["ntokens"]
            if i % log_interval == 0 and i > 0:
                cur_loss = total_loss / log_interval
                elapsed = time.time() - start_time
                print("| batch {:5d} | wps {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
                    i, word_counter / elapsed, cur_loss, math.exp(cur_loss)))
                word_counter = 0
                total_loss = 0
                start_time = time.time()
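# A small aside on the parameter-counting idiom used above: reducing with
# operator.mul over each parameter's size is equivalent to summing
# Tensor.numel() over all parameters. A minimal, self-contained check
# (illustrative only):
def _example_param_count():
    import operator
    from functools import reduce

    import torch

    m = torch.nn.Linear(4, 2)
    via_reduce = reduce(operator.add,
                        (reduce(operator.mul, p.size()) for p in m.parameters()))
    via_numel = sum(p.numel() for p in m.parameters())
    assert via_reduce == via_numel == 4 * 2 + 2  # weight + bias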
def train(args):
    # disable logging for processes except 0 on every node
    if args.local_rank != 0:
        f = open(os.devnull, "w")
        sys.stdout = sys.stderr = f
    elif not os.path.exists(args.dir):
        # create 40 random image/mask pairs for training
        print(f"generating synthetic data to {args.dir} (this may take a while)")
        os.makedirs(args.dir)
        # set random seed to generate same random data for every node
        np.random.seed(seed=0)
        for i in range(40):
            im, seg = create_test_image_3d(128, 128, 128, num_seg_classes=1, channel_dim=-1)
            n = nib.Nifti1Image(im, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"img{i:d}.nii.gz"))
            n = nib.Nifti1Image(seg, np.eye(4))
            nib.save(n, os.path.join(args.dir, f"seg{i:d}.nii.gz"))

    # initialize the distributed training process, every GPU runs in a process
    dist.init_process_group(backend="nccl", init_method="env://")

    images = sorted(glob(os.path.join(args.dir, "img*.nii.gz")))
    segs = sorted(glob(os.path.join(args.dir, "seg*.nii.gz")))
    train_files = [{"img": img, "seg": seg} for img, seg in zip(images, segs)]

    # define transforms for image and segmentation
    train_transforms = Compose([
        LoadImaged(keys=["img", "seg"]),
        AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
        ScaleIntensityd(keys="img"),
        RandCropByPosNegLabeld(keys=["img", "seg"], label_key="seg",
                               spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4),
        RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
        ToTensord(keys=["img", "seg"]),
    ])

    # partition the dataset based on the current rank, so every rank trains on its
    # own shard; this avoids duplicated caching content in each rank, but does not
    # perform a global shuffle before every epoch
    data_part = partition_dataset(
        data=train_files,
        num_partitions=dist.get_world_size(),
        shuffle=True,
        even_divisible=True,
    )[dist.get_rank()]

    train_ds = SmartCacheDataset(
        data=data_part,
        transform=train_transforms,
        replace_rate=0.2,
        cache_num=15,  # this example assumes 2 ranks, so every rank holds 20 training images
        num_init_workers=2,
        num_replace_workers=2,
    )
    # use batch_size=2 to load images and use RandCropByPosNegLabeld to generate 2 x 4 images for network training
    train_loader = DataLoader(train_ds, batch_size=2, shuffle=True, num_workers=2, pin_memory=True)

    # create UNet, DiceLoss and Adam optimizer
    device = torch.device(f"cuda:{args.local_rank}")
    torch.cuda.set_device(device)
    model = monai.networks.nets.UNet(
        dimensions=3,
        in_channels=1,
        out_channels=1,
        channels=(16, 32, 64, 128, 256),
        strides=(2, 2, 2, 2),
        num_res_units=2,
    ).to(device)
    loss_function = monai.losses.DiceLoss(sigmoid=True).to(device)
    optimizer = torch.optim.Adam(model.parameters(), 1e-3)
    # wrap the model with DistributedDataParallel module
    model = DistributedDataParallel(model, device_ids=[device])

    # start a typical PyTorch training
    epoch_loss_values = list()
    # start the replacement thread of SmartCache
    train_ds.start()

    for epoch in range(5):
        print("-" * 10)
        print(f"epoch {epoch + 1}/{5}")
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data["img"].to(device), batch_data["seg"].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
            epoch_len = math.ceil(len(train_ds) / train_loader.batch_size)
            print(f"{step}/{epoch_len}, train_loss: {loss.item():.4f}")
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        # replace 20% of cache content for the next epoch
        train_ds.update_cache()
        print(f"epoch {epoch + 1} average loss: {epoch_loss:.4f}")

    # stop the replacement thread of SmartCache
    train_ds.shutdown()
    print(f"train completed, epoch losses: {epoch_loss_values}")
    if dist.get_rank() == 0:
        # all processes should see same parameters as they all start from same
        # random parameters and gradients are synchronized in backward passes,
        # therefore, saving it in one process is sufficient
        torch.save(model.state_dict(), "final_model.pth")
    dist.destroy_process_group()
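# How the data partitioning above behaves, in isolation: partition_dataset
# splits the file list into one shard per rank, so each process caches only
# its own subset. A minimal sketch (assumes MONAI is installed; file names
# are placeholders):
def _example_partition_dataset():
    from monai.data import partition_dataset

    files = [f"img{i}.nii.gz" for i in range(6)]
    shards = partition_dataset(data=files, num_partitions=2,
                               shuffle=False, even_divisible=True)
    assert len(shards) == 2 and len(shards[0]) == 3  # rank 0 would train on shards[0]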
def train(cfg):
    # Set seeds for determinism
    torch.manual_seed(cfg.training.seed)
    torch.cuda.manual_seed_all(cfg.training.seed)
    np.random.seed(cfg.training.seed)
    random.seed(cfg.training.seed)

    main_proc = True
    device = torch.device("cpu" if cfg.training.no_cuda else "cuda")

    is_distributed = os.environ.get("LOCAL_RANK")  # If local rank exists, distributed env
    if is_distributed:
        # when using NCCL, on failures, surviving nodes will deadlock on NCCL ops
        # because NCCL uses a spin-lock on the device. Set this env var to enable a
        # watchdog thread that will destroy stale NCCL communicators
        os.environ["NCCL_BLOCKING_WAIT"] = "1"
        device_id = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(device_id)
        print(f"Setting CUDA Device to {device_id}")
        dist.init_process_group(backend=cfg.training.dist_backend.value)
        main_proc = device_id == 0  # Main process handles saving of models and reporting

    if OmegaConf.get_type(cfg.checkpointing) == FileCheckpointConfig:
        checkpoint_handler = FileCheckpointHandler(cfg=cfg.checkpointing)
    elif OmegaConf.get_type(cfg.checkpointing) == GCSCheckpointConfig:
        checkpoint_handler = GCSCheckpointHandler(cfg=cfg.checkpointing)
    else:
        raise ValueError("Checkpoint Config has not been specified correctly.")

    if main_proc and cfg.visualization.visdom:
        visdom_logger = VisdomLogger(id=cfg.visualization.id,
                                     num_epochs=cfg.training.epochs)
    if main_proc and cfg.visualization.tensorboard:
        tensorboard_logger = TensorBoardLogger(
            id=cfg.visualization.id,
            log_dir=to_absolute_path(cfg.visualization.log_dir),
            log_params=cfg.visualization.log_params)

    if cfg.checkpointing.load_auto_checkpoint:
        latest_checkpoint = checkpoint_handler.find_latest_checkpoint()
        if latest_checkpoint:
            cfg.checkpointing.continue_from = latest_checkpoint

    if cfg.checkpointing.continue_from:  # Starting from previous model
        state = TrainingState.load_state(
            state_path=to_absolute_path(cfg.checkpointing.continue_from))
        model = state.model
        if cfg.training.finetune:
            state.init_finetune_states(cfg.training.epochs)

        if main_proc and cfg.visualization.visdom:  # Add previous scores to visdom graph
            visdom_logger.load_previous_values(state.epoch, state.results)
        if main_proc and cfg.visualization.tensorboard:  # Previous scores to tensorboard logs
            tensorboard_logger.load_previous_values(state.epoch, state.results)
    else:  # Initialise new model training
        with open(to_absolute_path(cfg.data.labels_path)) as label_file:
            labels = json.load(label_file)

        if OmegaConf.get_type(cfg.model) is BiDirectionalConfig:
            model = DeepSpeech(
                rnn_hidden_size=cfg.model.hidden_size,
                nb_layers=cfg.model.hidden_layers,
                labels=labels,
                rnn_type=supported_rnns[cfg.model.rnn_type.value],
                audio_conf=cfg.data.spect,
                bidirectional=True)
        elif OmegaConf.get_type(cfg.model) is UniDirectionalConfig:
            model = DeepSpeech(
                rnn_hidden_size=cfg.model.hidden_size,
                nb_layers=cfg.model.hidden_layers,
                labels=labels,
                rnn_type=supported_rnns[cfg.model.rnn_type.value],
                audio_conf=cfg.data.spect,
                bidirectional=False,
                context=cfg.model.lookahead_context)
        else:
            raise ValueError("Model Config has not been specified correctly.")

        state = TrainingState(model=model)
        state.init_results_tracking(epochs=cfg.training.epochs)

    # Data setup
    evaluation_decoder = GreedyDecoder(model.labels)  # Decoder used for validation
    train_dataset = SpectrogramDataset(
        audio_conf=model.audio_conf,
        manifest_filepath=to_absolute_path(cfg.data.train_manifest),
        labels=model.labels,
        normalize=True,
        augmentation_conf=cfg.data.augmentation)
    test_dataset = SpectrogramDataset(
        audio_conf=model.audio_conf,
        manifest_filepath=to_absolute_path(cfg.data.val_manifest),
        labels=model.labels,
        normalize=True)
    if not is_distributed:
        train_sampler = DSRandomSampler(
            dataset=train_dataset,
            batch_size=cfg.data.batch_size,
            start_index=state.training_step)
    else:
        train_sampler = DSElasticDistributedSampler(
            dataset=train_dataset,
            batch_size=cfg.data.batch_size,
            start_index=state.training_step)
    train_loader = AudioDataLoader(
        dataset=train_dataset,
        num_workers=cfg.data.num_workers,
        batch_sampler=train_sampler)
    test_loader = AudioDataLoader(
        dataset=test_dataset,
        num_workers=cfg.data.num_workers,
        batch_size=cfg.data.batch_size)

    model = model.to(device)
    parameters = model.parameters()
    if OmegaConf.get_type(cfg.optim) is SGDConfig:
        optimizer = torch.optim.SGD(
            parameters,
            lr=cfg.optim.learning_rate,
            momentum=cfg.optim.momentum,
            nesterov=True,
            weight_decay=cfg.optim.weight_decay)
    elif OmegaConf.get_type(cfg.optim) is AdamConfig:
        optimizer = torch.optim.AdamW(
            parameters,
            lr=cfg.optim.learning_rate,
            betas=cfg.optim.betas,
            eps=cfg.optim.eps,
            weight_decay=cfg.optim.weight_decay)
    else:
        raise ValueError("Optimizer has not been specified correctly.")

    model, optimizer = amp.initialize(
        model, optimizer,
        enabled=not cfg.training.no_cuda,
        opt_level=cfg.apex.opt_level,
        loss_scale=cfg.apex.loss_scale)
    if state.optim_state is not None:
        optimizer.load_state_dict(state.optim_state)
    if state.amp_state is not None:
        amp.load_state_dict(state.amp_state)

    # Track states for optimizer/amp
    state.track_optim_state(optimizer)
    if not cfg.training.no_cuda:
        state.track_amp_state(amp)

    if is_distributed:
        model = DistributedDataParallel(model, device_ids=[device_id])
    print(model)
    print("Number of parameters: %d" % DeepSpeech.get_param_size(model))

    criterion = CTCLoss()
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses = AverageMeter()

    for epoch in range(state.epoch, cfg.training.epochs):
        model.train()
        end = time.time()
        start_epoch_time = time.time()
        state.set_epoch(epoch=epoch)
        train_sampler.set_epoch(epoch=epoch)
        train_sampler.reset_training_step(training_step=state.training_step)
        for i, (data) in enumerate(train_loader, start=state.training_step):
            state.set_training_step(training_step=i)
            inputs, targets, input_percentages, target_sizes = data
            input_sizes = input_percentages.mul_(int(inputs.size(3))).int()
            # measure data loading time
            data_time.update(time.time() - end)
            inputs = inputs.to(device)

            out, output_sizes = model(inputs, input_sizes)
            out = out.transpose(0, 1)  # TxNxH

            float_out = out.float()  # ensure float32 for loss
            loss = criterion(float_out, targets, output_sizes, target_sizes).to(device)
            # loss = loss / inputs.size(0)  # average the loss by minibatch
            loss_value = loss.item()

            # Check to ensure valid loss was calculated
            valid_loss, error = check_loss(loss, loss_value)
            if valid_loss:
                optimizer.zero_grad()

                # compute gradient
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), cfg.optim.max_norm)
                optimizer.step()
            else:
                print(error)
                print('Skipping grad update')
                loss_value = 0

            state.avg_loss += loss_value
            losses.update(loss_value, inputs.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()
            print('Epoch: [{0}][{1}/{2}]\t'
                  'Time {batch_time.val:.3f} ({batch_time.avg:.3f})\t'
                  'Data {data_time.val:.3f} ({data_time.avg:.3f})\t'
                  'Loss {loss.val:.4f} ({loss.avg:.4f})\t'.format(
                      (epoch + 1), (i + 1), len(train_loader),
                      batch_time=batch_time, data_time=data_time, loss=losses))

            if main_proc and cfg.checkpointing.checkpoint_per_iteration:
                checkpoint_handler.save_iter_checkpoint_model(epoch=epoch, i=i, state=state)
            del loss, out, float_out

        state.avg_loss = state.avg_loss / len(train_dataset) * cfg.data.batch_size

        epoch_time = time.time() - start_epoch_time
        print('Training Summary Epoch: [{0}]\t'
              'Time taken (s): {epoch_time:.0f}\t'
              'Average Loss {loss:.3f}\t'.format(
                  epoch + 1, epoch_time=epoch_time, loss=state.avg_loss))

        with torch.no_grad():
            wer, cer, output_data = run_evaluation(
                test_loader=test_loader,
                device=device,
                model=model,
                decoder=evaluation_decoder,
                target_decoder=evaluation_decoder)

        print('Validation Summary Epoch: [{0}]\t'
              'Average WER {wer:.3f}\t'
              'Average CER {cer:.3f}\t'.format(epoch + 1, wer=wer, cer=cer))

        state.add_results(epoch=epoch,
                          loss_result=state.avg_loss,
                          wer_result=wer,
                          cer_result=cer)

        if main_proc and cfg.visualization.visdom:
            visdom_logger.update(epoch, state.result_state)
        if main_proc and cfg.visualization.tensorboard:
            tensorboard_logger.update(epoch, state.result_state, model.named_parameters())

        if main_proc and cfg.checkpointing.checkpoint:  # Save epoch checkpoint
            checkpoint_handler.save_checkpoint_model(epoch=epoch, state=state)
        # anneal lr
        for g in optimizer.param_groups:
            g['lr'] = g['lr'] / cfg.optim.learning_anneal
        print('Learning rate annealed to: {lr:.6f}'.format(lr=g['lr']))

        if main_proc and (state.best_wer is None or state.best_wer > wer):
            checkpoint_handler.save_best_model(epoch=epoch, state=state)
            state.set_best_wer(wer)
            state.reset_avg_loss()
        state.reset_training_step()  # Reset training step for next epoch
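# The per-epoch "anneal lr" step above simply divides every param group's
# learning rate by a constant factor. A self-contained sketch of that update
# (the anneal factor here is an illustrative stand-in for cfg.optim.learning_anneal):
def _example_lr_anneal():
    import torch

    model = torch.nn.Linear(2, 2)
    opt = torch.optim.SGD(model.parameters(), lr=0.1)
    learning_anneal = 1.1  # hypothetical config value
    for g in opt.param_groups:
        g["lr"] = g["lr"] / learning_anneal
    assert abs(opt.param_groups[0]["lr"] - 0.1 / 1.1) < 1e-12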
class DistTrainer():
    r"""
    Distributed Trainer that supports distributed training and mixed-precision training.
    For the underlying mechanics, please read the official PyTorch documentation.

    Note: with a distributed Trainer, several processes execute the training code
    simultaneously. Before converting single-process training code to multi-process
    code, check carefully that synchronization and mutual exclusion are handled
    correctly (e.g. model saving, log printing).
    """

    def __init__(self, train_data, model, optimizer=None, loss=None,
                 callbacks_all=None, callbacks_master=None,
                 batch_size_per_gpu=8, n_epochs=1, num_workers=1,
                 drop_last=False, dev_data=None, metrics=None, metric_key=None,
                 update_every=1, print_every=10, validate_every=-1,
                 save_path=None, device='auto', fp16=False, use_tqdm=True, **kwargs):
        r"""
        :param train_data: training set, of type :class:`~fastNLP.DataSet`.
        :param nn.modules model: the model to train
        :param optimizer: a `torch.optim.Optimizer`. If None, the Trainer uses the
            default Adam(model.parameters(), lr=4e-3) optimizer
        :param loss: the :class:`~fastNLP.core.losses.LossBase` object to use. If None,
            defaults to :class:`~fastNLP.LossInForward`
        :param list callbacks_all: callbacks that hook into the training loop, running
            in all training processes. See :mod:`callback module <fastNLP.core.callback>`
            for available callbacks
        :param list callbacks_master: callbacks that hook into the training loop but run
            in only one process (the master process). See
            :mod:`callback module <fastNLP.core.callback>` for available callbacks
        :param int batch_size_per_gpu: the per-process batch size during training.
        :param int n_epochs: number of optimization epochs.
        :param num_workers: int, number of workers used for padding batches.
        :param drop_last: if the last batch holds fewer than batch_size samples, drop it
        :param dev_data: DataSet used for validation, of type :class:`~fastNLP.DataSet`.
        :param metrics: evaluation metrics for validation. Either a single
            :class:`Metric<fastNLP.core.metrics.MetricBase>` or several passed as a
            list. If a better validation result is obtained (judged by the first Metric
            in the list when there are several) and save_path is not None, the current
            model is saved. See :mod:`metrics module <fastNLP.core.metrics>` for the
            Metric types. Only effective when dev_data is passed.
        :param str,None metric_key: a :class:`Metric<fastNLP.core.metrics.MetricBase>`
            may report several values, e.g. :class:`~fastNLP.core.metrics.SpanFPreRecMetric`
            contains 'f', 'pre' and 'rec'; this specifies which value to use. Some
            metrics are better when smaller, e.g. language-model perplexity; prefix the
            key with '-' to indicate that smaller is better (e.g. "-ppl"). Only
            effective when dev_data is passed.
        :param update_every: int, number of steps between gradient updates. Used for
            gradient accumulation: e.g. when a batch_size of 128 does not fit in memory,
            set batch_size=32 and update_every=4 to achieve the same effect. Has no
            effect when optimizer is None.
        :param int print_every: number of backward passes between tqdm loss updates; if
            use_tqdm=False, number of backward passes between printed losses.
        :param int validate_every: validate on the dev set every this many steps; if -1,
            validate once at the end of each epoch. Only effective when dev_data is passed.
        :param str,None save_path: path under which to save the model; the directory is
            created automatically if it does not exist. If None, the model is not saved.
            If dev_data is None, the model of the last iteration is saved. Both the
            parameters and the model structure are saved; even when using DataParallel,
            only the model itself is saved.
        :param str device: the device to use; one of 'gpu', 'cpu' or 'auto'
        :param bool fp16: whether to use half-precision training.
        :param bool use_tqdm: whether to show training progress with tqdm; if False, the
            loss is printed to the terminal instead.
        :param kwargs: supported optional arguments:
            bool test_use_tqdm: whether to use tqdm when validating on dev
            Sampler test_sampler: the sampler to use during evaluation
            int dev_batch_size: the batch size to use during evaluation
            bool test_use_fp16: use fp16 at test time
            bool set_grad_to_none: on zero_grad, set grads to None instead of 0
            GradScaler gradscaler: a custom gradient scaler
        """
        assert device in ['auto', 'cuda', 'cpu'], \
            "Please set correct device in ['auto', 'cuda', 'cpu']"
        if device == 'auto':
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

        # init distributed
        if device == 'cuda':
            torch.cuda.set_device(get_local_rank())
            self.device = torch.device("cuda", get_local_rank())
        else:
            self.device = torch.device(device)

        init_logger_dist()

        self.world_size = dist.get_world_size()
        self.rank = dist.get_rank()  # unique id for each process

        self.train_data = train_data
        self.batch_size_per_gpu = int(batch_size_per_gpu)
        self.n_epochs = int(n_epochs)
        self.num_data_workers = int(num_workers)
        self.drop_last = drop_last
        self.update_every = int(update_every)
        self.print_every = int(print_every)
        self.validate_every = int(validate_every)
        self.save_path = save_path
        self.losser = _prepare_losser(loss)
        self.fp16 = fp16
        self.local_rank = get_local_rank()
        self._forward_func = model.forward
        self.callback_manager = DistCallbackManager(
            env={"trainer": self},
            callbacks_all=callbacks_all,
            callbacks_master=callbacks_master)
        self.test_manager = DistCallbackManager(env={'trainer': self})
        self.metric_key = metric_key
        self.use_tqdm = use_tqdm

        model.to(self.device)

        # init fp16; must be done before wrapping with DDP
        autocast, GradScaler = _build_fp16_env(dummy=not self.fp16)
        self.auto_cast = autocast
        user_grad_scaler = kwargs.get('gradscaler', None)
        if user_grad_scaler is not None:
            assert self.fp16, "must set fp16=True to enable gradscaler"
            grad_scaler = user_grad_scaler
        else:
            grad_scaler = GradScaler()
        self.grad_scaler = grad_scaler

        self.set_grad_to_none = kwargs.get('set_grad_to_none', True)

        # init DDP
        if parse_version(torch.__version__) >= parse_version('1.1'):
            self.ddp_model = DDP(model, device_ids=[self.local_rank],
                                 output_device=self.local_rank,
                                 find_unused_parameters=True)
        else:
            self.ddp_model = DDP(model, device_ids=[self.local_rank],
                                 output_device=self.local_rank)
        self.model = self.ddp_model.module

        optimizer = self._get_optimizer(optimizer)
        self.optimizer = optimizer
        if isinstance(self.train_data, DataSet):
            self.sampler = DistributedSampler(self.train_data)
        self.data_iterator = self._get_data_iter(self.train_data)
        self.batch_size = self.world_size * self.batch_size_per_gpu
        self.n_steps = self._get_n_steps()

        self.dev_data = dev_data
        self.metrics = metrics
        self.kwargs = kwargs
        self.test_use_tqdm = kwargs.get('test_use_tqdm', self.use_tqdm)
        dev_batch_size = kwargs.get('dev_batch_size', batch_size_per_gpu)

        # for evaluation, only run eval on master proc
        if dev_data and metrics:
            cb = _TesterCallback(
                dev_data, model, metrics,
                batch_size=dev_batch_size, num_workers=num_workers,
                sampler=kwargs.get('test_sampler', None),
                use_tqdm=self.test_use_tqdm)
            self.test_manager.add_callback([cb], master=True)

        # Setup logging
        # synchronize start_time across processes
        sync_time = torch.tensor(time.time(), dtype=torch.double).to(self.device)
        dist.broadcast(sync_time, src=0)
        self.start_time = datetime.fromtimestamp(sync_time.item()).strftime('%Y-%m-%d-%H-%M-%S-%f')
        # print('sync_time: {}, start_time: {}'.format(sync_time, self.start_time))

        if self.save_path:
            self.cp_save_path = self.save_path
        else:
            self.cp_save_path = None

        # use INFO in the master, WARN for others
        self.logger = logger
        self.logger.info("Setup Distributed Trainer")
        self.logger.warning("Process pid: {}, rank: {}, local rank: {}, device: {}, fp16: {}".format(
            os.getpid(), self.rank, self.local_rank, self.device, self.fp16))
        self.logger.info("Num of processes: {}".format(self.world_size))
        self.logger.info("Use device: {}".format(device))

    def _maybe_no_sync(self):
        """
        Whenever *samples* contains more than one mini-batch, we want to accumulate
        gradients locally and only call all-reduce in the last backwards pass.
        """
        i = self.step % self.update_every
        if (self.world_size > 1 and hasattr(self.ddp_model, "no_sync") and i != 0):
            return self.ddp_model.no_sync()
        else:
            return contextlib.ExitStack()  # dummy contextmanager

    def _get_n_steps(self):
        return len(self.data_iterator) * self.n_epochs

    def _get_data_iter(self, dataset):
        if isinstance(dataset, DataSet):
            return DataSetIter(dataset=dataset,
                               batch_size=self.batch_size_per_gpu,
                               sampler=self.sampler,
                               num_workers=self.num_data_workers,
                               drop_last=self.drop_last)
        elif isinstance(dataset, BatchIter):
            return dataset
        else:
            raise TypeError("train_data type {} not support".format(type(dataset)))

    def _get_optimizer(self, optimizer):
        if isinstance(optimizer, torch.optim.Optimizer):
            return optimizer
        elif isinstance(optimizer, Optimizer):
            return optimizer.construct_from_pytorch(self.ddp_model.parameters())
        elif optimizer is None:
            return torch.optim.Adam(self.ddp_model.parameters(), lr=4e-3)
        else:
            if not (hasattr(optimizer, 'step') and callable(optimizer.step)):
                raise TypeError("optimizer must have a callable step() function.")
            return optimizer

    @property
    def is_master(self):
        r"""Whether this is the master process."""
        return self.rank == 0

    def train(self, load_best_model=True, on_exception='auto'):
        r"""
        Call this method to start training.

        :param str on_exception: whether to re-raise an exception encountered during
            training after it has been handled by the on_exception() hook of
            :py:class:Callback. Supports 'ignore', 'raise' and 'auto': 'ignore'
            swallows the exception, so code written after Trainer.train() keeps
            running; 'raise' re-raises the exception; 'auto' ignores the two
            exception types CallbackException and KeyboardInterrupt, and raises
            any other exception.
        :return dict: a dict with the following entries::

            seconds: float, the training duration
            The following three entries are only present when dev_data was provided.
            best_eval: Dict of Dict, the evaluation results. The first-level key is
                the Metric name, the second-level key the concrete metric
            best_epoch: int, the epoch in which the best value was obtained
            best_step: int, the step (batch) in which the best value was obtained
        """
        try:
            self.logger.info("###### Training epochs started ######")
            self.logger.info('Total epochs: %d' % self.n_epochs)
            self.logger.info('Total steps: %d' % self.n_steps)
            self.logger.info('Num instances per GPU: %d' % self.batch_size_per_gpu)
            self.logger.info('Num of steps per update: %d' % self.update_every)
            self.logger.info('Total batch_size: %d' %
                             (self.batch_size_per_gpu * dist.get_world_size() * self.update_every))
            self.logger.info('Total num of samples: %d' % len(self.train_data))
            self.logger.info("Num of callbacks for all workers: {}".format(
                len(self.callback_manager.callbacks_all)))
            self.logger.info("Num of callbacks for master workers: {}".format(
                len(self.callback_manager.callbacks_master)))
            self.logger.info("Callbacks for all workers: {}".format(
                [repr(cb) for cb in self.callback_manager.callbacks_all]))
            self.logger.info("Callbacks for master workers: {}".format(
                [repr(cb) for cb in self.callback_manager.callbacks_master]))

            start_time = time.time()
            results = {}
            if self.n_epochs <= 0:
                self.logger.info("Training epoch is {}, nothing was done.".format(self.n_epochs))
                results['seconds'] = 0.
                return results

            try:
                self.callback_manager.on_train_begin()
                self._train()
                self.callback_manager.on_train_end()
            except BaseException as e:
                self.callback_manager.on_exception(e)
                if on_exception == 'auto':
                    if not isinstance(e, (CallbackException, KeyboardInterrupt)):
                        raise e
                    else:
                        self.logger.info('Catch {}, ignored.'.format(e.__class__.__name__))
                elif on_exception == 'raise':
                    raise e

            results['seconds'] = round(time.time() - start_time, 2)
            self.logger.info("###### Train finished ######")
            self.logger.info('Total train time: {} seconds.'.format(results['seconds']))
            if load_best_model and self.cp_save_path and len(self.test_manager.callbacks):
                self.load_check_point(self._best_save_name())
        finally:
            pass
        dist.barrier()
        return results

    def _train(self):
        dist.barrier()
        if not self.use_tqdm:
            from .utils import _pseudo_tqdm as inner_tqdm
        else:
            inner_tqdm = tqdm

        self.step = 0
        self.epoch = 0
        self.pbar = inner_tqdm(total=self.n_steps, postfix='loss:{0:<6.5f}',
                               leave=False, dynamic_ncols=True,
                               disable=not self.is_master)
        pbar = self.pbar
        avg_loss = 0
        data_iterator = self.data_iterator
        self.ddp_model.zero_grad()
        for epoch in range(1, self.n_epochs + 1):
            self.epoch = epoch
            pbar.set_description_str(desc="Epoch {}/{}".format(epoch, self.n_epochs))
            # early stopping
            self.callback_manager.on_epoch_begin()
            for batch_x, batch_y in data_iterator:
                self.step += 1
                self.ddp_model.train()
                _move_dict_value_to_device(batch_x, batch_y, device=self.device)
                indices = data_iterator.get_batch_indices()
                # negative sampling; replace unknown; re-weight batch_y
                self.callback_manager.on_batch_begin(batch_x, batch_y, indices)
                with self.auto_cast():
                    prediction = self._data_forward(self.ddp_model, batch_x)
                    # edit prediction
                    self.callback_manager.on_loss_begin(batch_y, prediction)
                    loss = self._compute_loss(prediction, batch_y)
                avg_loss += loss.detach()

                # Is loss NaN or inf? requires_grad = False
                self.callback_manager.on_backward_begin(loss)
                self.grad_scaler.scale(loss).backward()
                self.callback_manager.on_backward_end()
                if self.step % self.update_every == 0:
                    self._update()
                self.callback_manager.on_step_end()

                if self.step % self.print_every == 0:
                    avg_loss = float(avg_loss) / self.print_every
                    print_output = "loss:{:<6.5f}".format(avg_loss)
                    pbar.update(self.print_every)
                    pbar.set_postfix_str(print_output)
                    avg_loss = 0

                self.callback_manager.on_batch_end()

                if (self.validate_every > 0 and self.step % self.validate_every == 0) \
                        and len(self.test_manager.callbacks):
                    self._do_validation()

            # ================= mini-batch end ==================== #
            if self.validate_every < 0 and len(self.test_manager.callbacks):
                self._do_validation()

            # lr decay; early stopping
            self.callback_manager.on_epoch_end()
        # =============== epochs end =================== #
        pbar.close()
        self.pbar = None
        # ============ tqdm end ============== #

    def _clear_grad_opt(self, optimizer):
        if self.set_grad_to_none:
            for group in optimizer.param_groups:
                for p in group['params']:
                    if p.grad is not None:
                        p.grad = None
        else:
            optimizer.zero_grad()

    def _update(self):
        r"""Perform a weight update on the model."""
        self.grad_scaler.step(self.optimizer)
        self.grad_scaler.update()
        self._clear_grad_opt(self.optimizer)

    def _data_forward(self, network, x):
        x = _build_args(self._forward_func, **x)
        y = network(**x)
        if not isinstance(y, dict):
            raise TypeError(
                f"The return value of {_get_func_signature(self._forward_func)} should be dict, got {type(y)}.")
        return y

    def _compute_loss(self, predict, truth):
        r"""Compute loss given prediction and ground truth.

        :param predict: prediction dict, produced by model.forward
        :param truth: ground truth dict, produced by batch_y
        :return: a scalar
        """
        loss = self.losser(predict, truth)
        if self.update_every > 1:
            loss = loss / self.update_every
        if loss.dim() > 0:
            loss = loss.mean()
        return loss

    def save_check_point(self, name=None, only_params=False):
        r"""Save the current model."""
        # only the master saves models
        if name is None:
            name = 'checkpoint-{}.bin'.format(self.step)
        os.makedirs(self.cp_save_path, exist_ok=True)
        path = os.path.join(self.cp_save_path, name)
        self.logger.info("Save checkpoint to {}".format(path))
        model_to_save = self.ddp_model.module
        if only_params:
            model_to_save = model_to_save.state_dict()
        if self.is_master:
            torch.save(model_to_save, path)

    def load_check_point(self, name):
        path = os.path.join(self.cp_save_path, name)
        self.logger.info('reload best model from %s', path)
        model_load = torch.load(
            path,
            map_location=lambda s, l: default_restore_location(s, "cpu"))
        if not isinstance(model_load, dict):
            model_load = model_load.state_dict()
        self.model.load_state_dict(model_load)

    def _best_save_name(self, auto_fix=True):
        best_name = "best_" + "_".join(
            [self.model.__class__.__name__, str(self.metric_key), self.start_time])
        return best_name

    def _do_validation(self):
        with self.ddp_model.no_sync():
            # parameters are not updated during validation, so gradient sync can be skipped
            self.callback_manager.on_valid_begin()
            eval_res = self.test_manager.on_valid_begin()
            eval_res = list(filter(lambda x: x is not None, eval_res))
            if len(eval_res):
                eval_res, is_better = list(zip(*eval_res))
                eval_res = eval_res[0]
                is_better = is_better[0]
            else:
                eval_res, is_better = None, None
            if self.metric_key is None and eval_res is not None:
                eval_res0 = list(eval_res.values())[0]
                self.metric_key = list(eval_res0.keys())[0]
            # logger.info('{}, {}'.format(eval_res, is_better))
            # save better model on master node
            if is_better is not None and self.cp_save_path:
                if is_better:
                    self.save_check_point(self._best_save_name(), only_params=False)
            dist.barrier()

        if not self.is_master and self.metric_key is None:
            # the master process inferred metric_key automatically; the others did not
            prefix = 'best_' + self.model.__class__.__name__
            suffix = self.start_time
            fn_list = os.listdir(self.cp_save_path)
            fn_list = [fn for fn in fn_list if fn.startswith(prefix) and fn.endswith(suffix)]
            if len(fn_list) == 1:
                best_name = fn_list[0]
                self.metric_key = best_name[len(prefix):-len(suffix)].strip('_')
        # print('RANK {} metric_key {}'.format(self.rank, self.metric_key))
        self.callback_manager.on_valid_end(eval_res, self.metric_key, self.optimizer, is_better)
        self.ddp_model.train()

    def close(self):
        r"""Shut down the Trainer and destroy the process group."""
        dist.destroy_process_group()
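# DistTrainer expects one process per GPU, launched e.g. with (the script name
# is hypothetical):
#
#   python -m torch.distributed.launch --nproc_per_node=2 train_script.py
#
# A minimal construction sketch, assuming fastNLP DataSet/Metric objects are
# passed in; optimizer=None falls back to Adam(lr=4e-3) as documented above,
# and evaluation runs only on the master process.
def _example_dist_trainer(train_data, model, dev_data, metrics):
    trainer = DistTrainer(
        train_data=train_data,
        model=model,
        dev_data=dev_data,
        metrics=metrics,
        batch_size_per_gpu=8,
        n_epochs=3,
        device='auto',
        fp16=False,
    )
    trainer.train()
    trainer.close()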
class DeepSpeedPlugin(DDPPlugin):
    distributed_backend = "deepspeed"
    DEEPSPEED_ENV_VAR = "PL_DEEPSPEED_CONFIG_PATH"

    def __init__(self,
                 zero_optimization: bool = True,
                 stage: int = 2,
                 cpu_offload: bool = False,
                 contiguous_gradients: bool = True,
                 overlap_comm: bool = True,
                 allgather_partitions: bool = True,
                 reduce_scatter: bool = True,
                 allgather_bucket_size: int = 2e8,
                 reduce_bucket_size: int = 2e8,
                 zero_allow_untested_optimizer: bool = True,
                 config: Optional[Union[Path, str, dict]] = None,
                 logging_level: int = logging.WARN,
                 num_nodes: int = 1,
                 parallel_devices: Optional[List[torch.device]] = None,
                 cluster_environment: Optional[ClusterEnvironment] = None,
                 loss_scale: float = 0,
                 initial_scale_power: int = 32,
                 loss_scale_window: int = 1000,
                 hysteresis: int = 2,
                 min_loss_scale: int = 1) -> None:
        """
        Provides capabilities to run training using the DeepSpeed library, with training
        optimizations for large billion-parameter models.
        `For more information: https://www.deepspeed.ai/`.

        .. warning:: ``DeepSpeedPlugin`` is in beta and subject to change.

        Defaults have been set to enable ZeRO-Offload and some have been taken from the link below.
        These defaults have been set generally, but may require tuning for optimum performance based
        on your model size.
        `For more information: https://www.deepspeed.ai/docs/config-json/#zero-optimizations-for-fp16-training`.

        Arguments:

            zero_optimization: Enable ZeRO optimization. This is only compatible with precision=16. (default: True)

            stage: Different stages of the ZeRO Optimizer. 0 is disabled,
                1 is optimizer state partitioning, 2 is optimizer+gradient state partitioning (default: 2)

            cpu_offload: Enable offloading optimizer memory and computation to CPU

            contiguous_gradients: Copies gradients to a continuous buffer as they are produced.
                Avoids memory fragmentation during backwards. Useful when training large models. (default: True)

            overlap_comm: Overlap the reduction (synchronization) of gradients with the backwards computation.
                This is a speed optimization when training across multiple GPUs/machines. (default: True)

            allgather_partitions: All gather updated parameters at the end of the training step,
                instead of using a series of broadcast collectives (default: True)

            reduce_scatter: Use reduce/scatter instead of allreduce to average gradients (default: True)

            allgather_bucket_size: Number of elements to allgather at once.
                Used to limit the memory required for larger model sizes, with a tradeoff with speed. (default: 2e8)

            reduce_bucket_size: Number of elements to reduce at once.
                Used to limit the memory required for larger model sizes, with a tradeoff with speed (default: 2e8)

            zero_allow_untested_optimizer: Allow untested optimizers to be used with ZeRO. Currently only Adam is a
                DeepSpeed-supported optimizer when using ZeRO (default: True)

            config: Pass in a deepspeed formatted config dict, or path to a deepspeed config:
                https://www.deepspeed.ai/docs/config-json.
                All defaults will be ignored if a config is passed in. (Default: ``None``)

            logging_level: Set logging level for deepspeed. (Default: ``logging.WARN``)

            loss_scale: Loss scaling value for FP16 training.
                0.0 results in dynamic loss scaling, otherwise static (Default: 0)

            initial_scale_power: Power of the initial dynamic loss scale value. Loss scale is computed
                by ``2^initial_scale_power`` (Default: 32)

            loss_scale_window: Window in which to raise/lower the dynamic FP16 loss scaling value (Default: 1000)

            hysteresis: FP16 delay shift in dynamic loss scaling (Default: 2)

            min_loss_scale: The minimum FP16 dynamic loss scaling value (Default: 1)
        """
        if not _DEEPSPEED_AVAILABLE:
            raise MisconfigurationException(
                "To use the DeepSpeed plugin, you must have DeepSpeed installed."
                " pip install deepspeed mpi4py")
        super().__init__(parallel_devices=parallel_devices,
                         num_nodes=num_nodes,
                         cluster_environment=cluster_environment)
        self.config = self._load_config(config)
        if self.config is None:
            # User has not overridden config, set defaults
            self.config = self._create_default_config(
                zero_optimization,
                zero_allow_untested_optimizer,
                stage=stage,
                cpu_offload=cpu_offload,
                contiguous_gradients=contiguous_gradients,
                overlap_comm=overlap_comm,
                allgather_partitions=allgather_partitions,
                reduce_scatter=reduce_scatter,
                allgather_bucket_size=allgather_bucket_size,
                reduce_bucket_size=reduce_bucket_size)
        self._config_initialized = False
        deepspeed.utils.logging.logger.setLevel(logging_level)

        # default FP16 parameters.
        self.loss_scale = loss_scale
        self.initial_scale_power = initial_scale_power
        self.loss_scale_window = loss_scale_window
        self.hysteresis = hysteresis
        self.min_loss_scale = min_loss_scale

    def _load_config(self, config):
        if config is None and self.DEEPSPEED_ENV_VAR in os.environ:
            rank_zero_info(f"Loading DeepSpeed config from set {self.DEEPSPEED_ENV_VAR} environment variable")
            config = os.environ[self.DEEPSPEED_ENV_VAR]
        if isinstance(config, str) or isinstance(config, Path):
            if not os.path.isfile(config):
                raise MisconfigurationException(
                    f"You passed in a path to a DeepSpeed config but the path does not exist: {config}")
            with open(config) as f:
                config = json.load(f)
        return config

    def pre_dispatch(self):
        self.set_world_ranks()
        self.init_ddp_connection(self.global_rank, self.world_size)

        self.init_deepspeed()

        # set warning rank
        rank_zero_only.rank = self.global_rank

        # set the ranks and devices
        self.dist.rank = self.global_rank
        self.dist.device = self.root_device
        self.barrier()

    def init_deepspeed(self):
        if not self._config_initialized:
            self._format_config()
            self._config_initialized = True

        precision = self.lightning_module.trainer.accelerator.precision
        model = LightningDeepSpeedModule(pl_module=self.model, precision=precision)

        if self.lightning_module.trainer and self.lightning_module.trainer.training:
            self._initialize_deepspeed_train(model)
        else:
            self._initialize_deepspeed_inference(model)

    def _init_scheduler_optimizer(self):
        optimizers, schedulers, optimizer_frequencies = self.lightning_module.trainer.init_optimizers(
            self.lightning_module)
        if len(optimizers) > 1 or len(schedulers) > 1:
            raise MisconfigurationException(
                "DeepSpeed currently only supports single optimizer, single optional scheduler.")
        scheduler = schedulers[0]['scheduler'] if len(schedulers) == 1 else None
        optimizer = optimizers[0]
        return optimizer, scheduler, optimizer_frequencies

    def _initialize_deepspeed_train(self, model):
        optimizer, lightning_scheduler, optimizer_frequencies = None, None, None
        if "optimizer" not in self.config:
            rank_zero_info(
                "You have not specified an optimizer or scheduler within the DeepSpeed config."
                " Using `configure_optimizers` to define optimizer and scheduler.")
            optimizer, lightning_scheduler, optimizer_frequencies = self._init_scheduler_optimizer()
        model_parameters = filter(lambda p: p.requires_grad, self.model.parameters())
        model, optimizer, _, lr_scheduler = deepspeed.initialize(
            args=SimpleNamespace(local_rank=self.local_rank),
            model=model,
            model_parameters=model_parameters,
            optimizer=optimizer,
            lr_scheduler=lightning_scheduler,
            config_params=self.config,
        )

        # set optimizer for save/load, but deepspeed manages the specific optimizer logic
        self.lightning_module.trainer.optimizers = [optimizer]
        self.model = model

    def _initialize_deepspeed_inference(self, model):
        # move the model to the correct device
        self.model_to_device()

        self.pre_configure_ddp()
        self.model = DistributedDataParallel(
            model,
            device_ids=self.determine_ddp_device_ids(),
            **self._ddp_kwargs,
        )

    def configure_scheduler(self, lr_scheduler):
        scheduler = _get_default_scheduler_config()
        scheduler["scheduler"] = lr_scheduler
        return [scheduler]

    @property
    def lightning_module(self):
        # the model may not be wrapped with DeepEngine & LightningDeepSpeedModule if calling this too early
        module = getattr(self.model, "module", self.model)
        return module.module if isinstance(module, LightningDeepSpeedModule) else module

    @property
    def distributed_sampler_kwargs(self):
        distributed_sampler_kwargs = dict(num_replicas=self.world_size, rank=self.global_rank)
        return distributed_sampler_kwargs

    def init_optimizers(self, trainer, model: LightningModule) -> Tuple[List, List, List]:
        # Skip initializing optimizers here as DeepSpeed handles optimizers via config.
        # User may have specified config options instead in configure_optimizers, but this is handled
        # via `_initialize_deepspeed_train`
        return [], [], []  # empty optimizers, schedulers and frequencies

    def optimizer_step(self, optimizer: torch.optim.Optimizer, lambda_closure: Callable, **kwargs):
        # note: We rely on the deepspeed engine to carry out the step rather than the optimizer.
        # internally, the engine has a reference to the optimizer already.
        self.model.step(**kwargs)

    def _format_config(self):
        if self.config is None:
            raise MisconfigurationException(
                "To use DeepSpeed you must pass in a DeepSpeed config dict, or a path to a JSON config."
                " See: https://pytorch-lightning.readthedocs.io/en/latest/advanced/multi_gpu.html#deepspeed")
        self._format_batch_size_and_grad_accum_config()
        self._format_precision_config()

    def _format_batch_size_and_grad_accum_config(self):
        if "gradient_accumulation_steps" in self.config:
            raise MisconfigurationException(
                "Within the DeepSpeed config, do not set gradient_accumulation_steps"
                " as this will be set via accumulate_grad_batches=x argument passed via the Lightning Trainer.")
        if "train_micro_batch_size_per_gpu" not in self.config:
            # train_micro_batch_size_per_gpu is used for throughput logging purposes
            # by default we use the batch size of the loader which may be incorrect if a batch sampler is passed
            batch_size = self.lightning_module.train_dataloader().batch_size
            self.config["train_micro_batch_size_per_gpu"] = batch_size
        self.config["gradient_accumulation_steps"] = self.lightning_module.trainer.accumulate_grad_batches
        if "gradient_clipping" not in self.config:
            self.config["gradient_clipping"] = self.lightning_module.trainer.gradient_clip_val

    def _format_precision_config(self):
        amp_type = self.lightning_module.trainer.accelerator_connector.amp_type
        amp_level = self.lightning_module.trainer.accelerator_connector.amp_level
        precision = self.lightning_module.trainer.accelerator_connector.precision
        if precision == 16:
            if "fp16" not in self.config and amp_type == AMPType.NATIVE:
                # FP16 is a DeepSpeed standalone AMP implementation
                rank_zero_info("Enabling DeepSpeed FP16.")
                self.config["fp16"] = {
                    "enabled": True,
                    "loss_scale": self.loss_scale,
                    "initial_scale_power": self.initial_scale_power,
                    "loss_scale_window": self.loss_scale_window,
                    "hysteresis": self.hysteresis,
                    "min_loss_scale": self.min_loss_scale
                }
            elif "amp" not in self.config and amp_type == AMPType.APEX:
                rank_zero_info("Enabling DeepSpeed APEX Implementation.")
                self.config["amp"] = {
                    "enabled": True,
                    "opt_level": amp_level,
                }
        if "zero_optimization" in self.config and not ("amp" in self.config or "fp16" in self.config):
            raise MisconfigurationException(
                "To use DeepSpeed ZeRO Optimization, you must set precision=16.")

    def _create_default_config(self, zero_optimization: bool, zero_allow_untested_optimizer: bool,
                               **zero_kwargs) -> Dict:
        if zero_optimization:
            return {
                "zero_allow_untested_optimizer": zero_allow_untested_optimizer,
                "zero_optimization": zero_kwargs
            }
        return {}