def train_loop(
    run_id,
    dataset_dir,
    ckpt_run_dir,
    output_dir,
    validation_only=False,
    use_cuda=False,
    light_target=False,
    seed=42,
):
    """Train loop for the WMT17 en-de Transformer benchmark."""
    train_epochs = 10
    math_mode = "fp16"
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Dataset arguments
    train_global_batch_size = 2 ** 17  # Global batch size (in tokens)
    max_bs = 2 ** 13  # Max batch size for used hardware
    update_freq = int(max(1, train_global_batch_size // (max_bs * world_size)))
    max_tokens = int(train_global_batch_size // (world_size * update_freq))
    max_source_positions, max_target_positions = 80, 80
    seq_len_multiple = 2
    left_pad = (True, False)
    lang = ("en", "de")

    # Specific architecture
    model_args = deepcopy(DEFAULT_TRANSFORMER_ARCH)
    model_args["max_source_positions"] = max_source_positions
    model_args["max_target_positions"] = max_target_positions
    model_args["share_all_embeddings"] = True
    model_args["dropout"] = 0.1
    model_args["softmax_type"] = "fast_fill"

    lr = 1.976e-3
    optimizer_args = {"lr": lr, "eps": 1e-9, "betas": (0.9, 0.98)}
    scheduler_args = {"base_lr": lr, "warmup_init_lr": 0.0, "warmup_steps": 1000}
    loss_scaling_fp16 = {
        "init_scale": 2.0 ** 7,
        "scale_factor": 2,
        "scale_window": 2000,
    }
    criterion_args = {"smoothing": 0.1, "fast_xentropy": True}

    # Horovod handles the fp16 all-reduce when the backend is MPI
    use_horovod = (math_mode == "fp16") and dist.get_backend() == dist.Backend.MPI
    if use_horovod:
        hvd.init()
        logger.info("Using horovod rank={}".format(hvd.rank()))
        # Sanity-check the ring: summing 1 across workers must give world_size
        tensor = torch.tensor([1])
        res = hvd.allreduce(tensor, op=hvd.Sum)
        assert res[0] == world_size

    # Load train and validation datasets
    train_set = WMT17Dataset(
        dataset_dir,
        download=True,
        train=True,
        shuffle=True,
        lang=lang,
        left_pad=left_pad,
        max_positions=(max_source_positions, max_target_positions),
        seq_len_multiple=seq_len_multiple,
    )
    validation_set = WMT17Dataset(
        dataset_dir,
        download=False,
        test=True,
        shuffle=True,
        lang=lang,
        left_pad=left_pad,
        max_positions=(max_source_positions, max_target_positions),
        seq_len_multiple=seq_len_multiple,
    )
    src_dict, trg_dict = train_set.src_dict, train_set.trg_dict

    train_batches = get_batches(
        train_set, max_tokens=max_tokens, bsz_mult=8, shuffle=True, seed=seed
    )
    val_batches = get_batches(
        validation_set, max_tokens=max_tokens, bsz_mult=8, shuffle=False
    )
    train_batches = equalize_batches(train_batches, world_size, seed=seed)

    # Partition by rank
    train_batches = partition_dataset_by_rank(train_batches, rank, world_size)
    val_batches = partition_dataset_by_rank(val_batches, rank, world_size)

    total_train_points = sum(len(b) for b in train_batches)

    # Validate every 30% of an epoch, rounded to a multiple of update_freq
    validate_every = update_freq * round(len(train_batches) * 0.30 / update_freq)
    assert (validate_every % update_freq) == 0

    logger.info(
        "Using {} total train points, {} batches".format(
            total_train_points, len(train_batches)
        )
    )

    train_loader = DataLoader(
        train_set,
        num_workers=1,
        pin_memory=False,
        collate_fn=train_set.collater,
        batch_sampler=train_batches,
    )
    val_loader = DataLoader(
        validation_set,
        num_workers=1,
        pin_memory=False,
        collate_fn=validation_set.collater,
        batch_sampler=val_batches,
    )

    model = TransformerModel(Arguments(model_args), src_dict, trg_dict)
    criterion = LabelSmoothing(padding_idx=src_dict.pad(), **criterion_args)

    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    fp_optimizer, optimizer, model = build_optimizer(
        model,
        optimizer_args,
        math_mode=math_mode,
        scaling_args=loss_scaling_fp16,
        use_horovod=use_horovod,
        use_cuda=use_cuda,
    )
    scheduler = SQRTTimeDecayLRWithWarmup(optimizer, **scheduler_args)

    metrics = [BLEUScore(use_raw=True)]
    checkpointer = Checkpointer(
        ckpt_run_dir=ckpt_run_dir, rank=rank, freq=CheckpointFreq.BEST
    )

    translator = SequenceGenerator(
        model,
        src_dict=deepcopy(src_dict),
        trg_dict=deepcopy(trg_dict),
        beam_size=4,
        stop_early=True,
        normalize_scores=True,
        len_penalty=0.6,
        sampling=False,
        sampling_topk=-1,
        minlen=1,
    )

    if not validation_only:
        if light_target:
            goal = task4_time_to_bleu_goal(20)
        else:
            goal = task4_time_to_bleu_goal(25)

        num_batches_per_device_train = len(train_loader)
        tracker = Tracker(metrics, run_id, rank, goal=goal)

        dist.barrier()
        tracker.start()

        for epoch in range(0, train_epochs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            model.train()
            tracker.train()

            iter_sample_size = 0
            for batch_idx, sample in enumerate(train_loader):
                tracker.batch_start()
                sample = prepare_batch(sample, use_cuda=use_cuda)
                tracker.record_batch_load()

                # Flush accumulated gradients on the last batch even if the
                # accumulation window is not full (enumerate is 0-indexed, so
                # the last batch has index len(train_loader) - 1).
                is_last = batch_idx == len(train_loader) - 1
                update = (batch_idx % update_freq) == update_freq - 1
                init = (batch_idx % update_freq) == 0

                # Clear gradients in the optimizer.
                if init:
                    fp_optimizer.zero_grad()
                    iter_sample_size = 0
                    tracker.record_batch_init()

                # Compute the output
                output = model(**sample["net_input"])
                tracker.record_batch_fwd_pass()

                loss, sample_size = compute_loss(sample, output, criterion)
                loss_per_sample = loss.item() / sample_size
                iter_sample_size += sample_size
                tracker.record_batch_comp_loss()

                # Backprop
                fp_optimizer.backward_loss(loss)
                tracker.record_batch_backprop()

                if update or is_last:
                    # Get batch size over all workers
                    full_bs = get_full_batch_size(
                        iter_sample_size, world_size=world_size, use_cuda=use_cuda
                    )

                    updated = opt_step(
                        fp_optimizer,
                        tracker,
                        full_bs,
                        update_freq,
                        math_mode,
                        world_size,
                    )
                    if updated:
                        scheduler.step()

                tracker.batch_end()

                record_train_batch_stats(
                    batch_idx=batch_idx,
                    loss=loss_per_sample,
                    output=torch.Tensor([0]),
                    metric_results={},
                    tracker=tracker,
                    num_batches_per_device_train=num_batches_per_device_train,
                )

                # Mid-epoch validation round
                if (batch_idx + 1) % validate_every == 0:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    metric_values, loss = validation_round(
                        val_loader,
                        metrics,
                        criterion,
                        translator,
                        tracker=tracker,
                        use_cuda=use_cuda,
                    )
                    record_validation_stats(metric_values, loss, tracker, rank)
                    if tracker.goal_reached:
                        break

                    model.train()
                    tracker.train()

            # End-of-epoch validation round
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            metric_values, loss = validation_round(
                val_loader,
                metrics,
                criterion,
                translator,
                tracker=tracker,
                use_cuda=use_cuda,
            )
            is_best = record_validation_stats(metric_values, loss, tracker, rank)

            checkpointer.save(
                tracker,
                model,
                optimizer,
                scheduler,
                tracker.current_epoch,
                is_best,
            )
            tracker.epoch_end()

            if tracker.goal_reached:
                print("Goal Reached!")
                time.sleep(10)
                return
    else:
        cecf = CheckpointsEvaluationControlFlow(
            ckpt_dir=ckpt_run_dir,
            rank=rank,
            world_size=world_size,
            checkpointer=checkpointer,
            model=model,
            epochs=train_epochs,
            loss_function=criterion,
            metrics=metrics,
            use_cuda=use_cuda,
            dtype="fp32",
            max_batch_per_epoch=None,
        )
        train_stats = cecf.evaluate_by_epochs(train_loader)
        with open(os.path.join(output_dir, "train_stats.json"), "w") as f:
            json.dump(train_stats, f)
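# Worked example of the gradient-accumulation arithmetic above, assuming
# world_size = 8 workers (the numbers below follow directly from the
# constants in train_loop):
#   update_freq = max(1, 2**17 // (2**13 * 8)) = 2
#   max_tokens  = 2**17 // (8 * 2)             = 2**13 tokens per worker per step
# i.e. each worker accumulates two forward/backward passes of at most 2**13
# tokens before an optimizer step, recovering the 2**17-token global batch.
#
# A minimal, hypothetical launch sketch (not part of the benchmark itself):
# it assumes the process group is initialized through torchrun-style
# environment variables; paths, RUN_ID, and the backend choice are
# placeholder assumptions.
def _example_launch():
    """Sketch of a single-node launch of train_loop; paths are placeholders."""
    import os

    import torch
    import torch.distributed as dist

    # NCCL for GPU clusters, gloo otherwise (assumption; pick per deployment).
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    dist.init_process_group(backend=backend)

    train_loop(
        run_id=os.environ.get("RUN_ID", "local-run"),
        dataset_dir="/data/wmt17",  # placeholder path
        ckpt_run_dir="/checkpoints/run",  # placeholder path
        output_dir="/output",  # placeholder path
        use_cuda=torch.cuda.is_available(),
        light_target=True,  # 20-BLEU goal instead of 25 for short runs
    )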
def test_tracker():
    tracker = Tracker([TopKAccuracy(5)], 1, 0)

    assert tracker is not None
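# A slightly fuller lifecycle sketch (hypothetical test, but it only calls
# methods and attributes that train_loop and TrainValidation above already
# rely on: start/train/batch_start/batch_end/epoch_end/current_epoch).
def test_tracker_lifecycle():
    tracker = Tracker([TopKAccuracy(5)], 1, 0)

    tracker.start()
    tracker.train()
    tracker.batch_start()
    tracker.batch_end()
    tracker.epoch_end()

    assert tracker.current_epoch >= 0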
class TrainValidation(object):
    r"""Train and validate a model.

    Args:
        model (:obj:`torch.nn.Module`): a pytorch model to be trained and
            validated.
        optimizer (:obj:`torch.optim.Optimizer`): an optimizer for the given
            model.
        loss_function (:obj:`torch.nn.modules.loss._Loss`): loss function.
        metrics (:obj:`list` of :obj:`mlbench_core.evaluation.pytorch.*`):
            metrics like TopKAccuracy.
        scheduler (:obj:`mlbench_core.lr_scheduler.pytorch.lr.*`): a scheduler
            for hyperparameters.
        batch_size (int): The size of batches provided by the dataloader.
        train_epochs (int): The number of epochs to train for.
        rank (int): The rank of the current worker.
        world_size (int): The total number of workers.
        run_id (str): The id of the current run.
        dtype (str): The datatype to use for the dataloader data.
        validate (bool): Whether to run validation on the val dataset.
            Default: `True`
        schedule_per (str): When to perform a step for the lr scheduler, one
            of `epoch` or `batch`. Default: `epoch`
        checkpoint (:obj:`Checkpointer`): Class that handles checkpointing.
            Default: `None`
        transform_target_type (str): dtype to transform the target to.
            Not used. Default: `None`
        average_models (bool): Whether to average models together.
            Default: `False`
        use_cuda (bool): Whether to train on GPU or not. Default: `False`
        max_batch_per_epoch (int): Maximum number of batches per epoch.
            Whole dataset is used if not specified. Default: `None`
        tracker (:obj:`mlbench_core.utils.Tracker`): Tracker for the control
            flow. Default: `None`
    """

    def __init__(self, model, optimizer, loss_function, metrics, scheduler,
                 batch_size, train_epochs, rank, world_size, run_id, dtype,
                 validate=True, schedule_per='epoch', checkpoint=None,
                 transform_target_type=None, average_models=False,
                 use_cuda=False, max_batch_per_epoch=None, tracker=None):
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.metrics = metrics
        self.scheduler = scheduler
        self.batch_size = batch_size
        self.train_epochs = train_epochs
        self.rank = rank
        self.world_size = world_size
        self.run_id = run_id
        self.dtype = dtype
        self.perform_validation = validate
        self.schedule_per = schedule_per
        self.checkpoint = checkpoint
        self.transform_target_type = transform_target_type
        self.average_models = average_models
        self.use_cuda = use_cuda
        self.max_batch_per_epoch = max_batch_per_epoch

        if tracker:
            self.tracker = tracker
        else:
            self.tracker = Tracker(metrics, run_id, rank)

    def _get_dataloader_stats(self, dataloader_train, dataloader_val):
        """Sets the stats for the supplied dataloaders.

        Args:
            dataloader_train (:obj:`torch.utils.data.DataLoader`): The train set
            dataloader_val (:obj:`torch.utils.data.DataLoader`): The validation set
        """
        self.num_batches_per_device_train = len(dataloader_train)
        self.num_batches_per_device_val = len(dataloader_val)

    def run(self, dataloader_train=None, dataloader_val=None,
            dataloader_train_fn=None, dataloader_val_fn=None, resume=False,
            repartition_per_epoch=False):
        """Execute training and (possibly) validation.

        `dataloader_train` and `dataloader_train_fn` are mutually exclusive.
        `dataloader_val` and `dataloader_val_fn` are mutually exclusive.

        Args:
            dataloader_train (:obj:`torch.utils.data.DataLoader`): A dataloader
                for the train set. Default: `None`
            dataloader_val (:obj:`torch.utils.data.DataLoader`): A dataloader
                for the val set. Default: `None`
            dataloader_train_fn (:func:`Function`): A function returning a
                :obj:`torch.utils.data.DataLoader` for the train set.
                Default: `None`
            dataloader_val_fn (:func:`Function`): A function returning a
                :obj:`torch.utils.data.DataLoader` for the val set.
                Default: `None`
            resume (bool): Whether this is a resume of a previous run or not.
                Default: `False`
            repartition_per_epoch (bool): Whether to repartition the dataset
                again every epoch. Requires dataloader_train_fn and/or
                dataloader_val_fn to be set. Default: `False`
        """
        if not dataloader_train_fn and not dataloader_train:
            raise ValueError(
                "One of dataloader_train_fn or dataloader_train must be set")

        if not dataloader_val_fn and not dataloader_val:
            raise ValueError(
                "One of dataloader_val_fn or dataloader_val must be set")

        if dataloader_train_fn:
            dataloader_train = dataloader_train_fn()

        if dataloader_val_fn:
            dataloader_val = dataloader_val_fn()

        self._get_dataloader_stats(dataloader_train, dataloader_val)

        # Log some parameters for the training run.
        logger.info("There are {train_epochs} epochs, {num_batches} "
                    "mini-batches per epoch (batch size: {batch_size})."
                    .format(
                        train_epochs=self.train_epochs,
                        num_batches=self.num_batches_per_device_train,
                        batch_size=self.batch_size))

        # Initialize the Tracker or resume from checkpoint
        if resume:
            start_epoch = self.tracker.current_epoch + 1
        else:
            start_epoch = 0

        dist.barrier()
        for epoch in range(start_epoch, self.train_epochs):
            # Per-epoch information.
            logger.info("Current epoch : {} : lr={}"
                        .format(epoch, self.scheduler.get_lr()))

            train_round(dataloader_train, self.model, self.optimizer,
                        self.loss_function, self.metrics, self.scheduler,
                        self.dtype, self.schedule_per,
                        self.transform_target_type, self.use_cuda,
                        self.max_batch_per_epoch, self.tracker)

            is_best = False
            if self.perform_validation:
                is_best = validation_round(dataloader_val, self.model,
                                           self.loss_function, self.metrics,
                                           self.run_id, self.rank, self.dtype,
                                           self.transform_target_type,
                                           self.use_cuda,
                                           self.max_batch_per_epoch,
                                           self.tracker)

            if self.checkpoint:
                self.checkpoint.save(self.tracker, self.model,
                                     self.optimizer, self.scheduler,
                                     self.tracker.current_epoch, is_best)

            # Shuffle the dataset across nodes
            if repartition_per_epoch:
                if dataloader_train_fn:
                    dataloader_train = dataloader_train_fn()

                if dataloader_val_fn:
                    dataloader_val = dataloader_val_fn()

                self._get_dataloader_stats(dataloader_train, dataloader_val)

            self.tracker.epoch_end()
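# Minimal usage sketch for TrainValidation (assumptions: the process group is
# already initialized, and `model`, `optimizer`, `scheduler`, and the two
# dataloaders are built elsewhere; the loss function, epoch count, and run id
# below are placeholder choices, not values mandated by the class).
def _example_train_validation(model, optimizer, scheduler,
                              dataloader_train, dataloader_val):
    import torch
    import torch.distributed as dist
    import torch.nn as nn

    controlflow = TrainValidation(
        model=model,
        optimizer=optimizer,
        loss_function=nn.CrossEntropyLoss(),  # placeholder criterion
        metrics=[TopKAccuracy(5)],
        scheduler=scheduler,
        batch_size=dataloader_train.batch_size,
        train_epochs=10,  # placeholder
        rank=dist.get_rank(),
        world_size=dist.get_world_size(),
        run_id="example-run",  # placeholder
        dtype="fp32",
        validate=True,
        schedule_per="epoch",
        use_cuda=torch.cuda.is_available(),
    )
    controlflow.run(dataloader_train=dataloader_train,
                    dataloader_val=dataloader_val)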