Ejemplo n.º 1
0
def train_loop(
    run_id,
    dataset_dir,
    ckpt_run_dir,
    output_dir,
    validation_only=False,
    use_cuda=False,
    light_target=False,
    seed=42,
):
    """Train a Transformer on WMT17 en->de, or evaluate its saved checkpoints.

    Builds the datasets and token-based batches, the model, the fp16
    optimizer, the LR scheduler, the checkpointer and a beam-search
    translator. Then it either runs the training loop (validating
    periodically and once per epoch) or, when ``validation_only`` is set,
    evaluates the checkpoints in ``ckpt_run_dir`` on the train loader.

    Args:
        run_id: Identifier of the current run, forwarded to the ``Tracker``.
        dataset_dir (str): Directory holding (or receiving) the WMT17 data.
        ckpt_run_dir (str): Directory checkpoints are written to / read from.
        output_dir (str): Directory where ``train_stats.json`` is written in
            ``validation_only`` mode.
        validation_only (bool): Skip training and evaluate checkpoints
            instead. Default: False.
        use_cuda (bool): Move model and criterion to the GPU. Default: False.
        light_target (bool): Use the BLEU goal 20 instead of 25.
            Default: False.
        seed (int): Seed for shuffling/equalizing the training batches.
            Default: 42.
    """
    train_epochs = 10

    math_mode = "fp16"
    rank = dist.get_rank()
    world_size = dist.get_world_size()

    # Dataset arguments
    train_global_batch_size = 2**17  # Global batch size
    max_bs = 2**13  # Max batch size for used hardware
    # Number of batches whose gradients are accumulated per optimizer step,
    # chosen so that world_size * update_freq * max_tokens is roughly
    # train_global_batch_size.
    update_freq = int(max(1, train_global_batch_size // (max_bs * world_size)))
    max_tokens = int(train_global_batch_size // (world_size * update_freq))

    max_source_positions, max_target_positions = 80, 80
    seq_len_multiple = 2
    left_pad = (True, False)
    lang = ("en", "de")

    # specific arch
    model_args = deepcopy(DEFAULT_TRANSFORMER_ARCH)
    model_args["max_source_positions"] = max_source_positions
    model_args["max_target_positions"] = max_target_positions
    model_args["share_all_embeddings"] = True
    model_args["dropout"] = 0.1
    model_args["softmax_type"] = "fast_fill"

    lr = 1.976e-3
    optimizer_args = {
        "lr": lr,
        "eps": 1e-9,
        "betas": (0.9, 0.98),
    }
    scheduler_args = {
        "base_lr": lr,
        "warmup_init_lr": 0.0,
        "warmup_steps": 1000
    }

    loss_scaling_fp16 = {
        "init_scale": 2.0**7,
        "scale_factor": 2,
        "scale_window": 2000,
    }

    criterion_args = {"smoothing": 0.1, "fast_xentropy": True}

    # Horovod is only used for fp16 training on the MPI backend.
    use_horovod = (math_mode
                   == "fp16") and dist.get_backend() == dist.Backend.MPI
    if use_horovod:
        hvd.init()
        logger.info("Using horovod rank={}".format(hvd.rank()))
        # Sanity check: summing 1 over all workers must yield world_size.
        tensor = torch.tensor([1])
        res = hvd.allreduce(tensor, op=hvd.Sum)
        assert res[0] == world_size

    # Load train and validation datasets
    train_set = WMT17Dataset(
        dataset_dir,
        download=True,
        train=True,
        shuffle=True,
        lang=lang,
        left_pad=left_pad,
        max_positions=(max_source_positions, max_target_positions),
        seq_len_multiple=seq_len_multiple,
    )

    validation_set = WMT17Dataset(
        dataset_dir,
        download=False,
        test=True,
        shuffle=True,
        lang=lang,
        left_pad=left_pad,
        max_positions=(max_source_positions, max_target_positions),
        seq_len_multiple=seq_len_multiple,
    )
    src_dict, trg_dict = train_set.src_dict, train_set.trg_dict

    train_batches = get_batches(train_set,
                                max_tokens=max_tokens,
                                bsz_mult=8,
                                shuffle=True,
                                seed=seed)
    val_batches = get_batches(validation_set,
                              max_tokens=max_tokens,
                              bsz_mult=8,
                              shuffle=False)

    # NOTE(review): presumably makes the batch count divisible across ranks so
    # every worker runs the same number of steps — confirm in equalize_batches.
    train_batches = equalize_batches(train_batches, world_size, seed=seed)

    # Partition by rank
    train_batches = partition_dataset_by_rank(train_batches, rank, world_size)
    val_batches = partition_dataset_by_rank(val_batches, rank, world_size)

    total_train_points = sum(len(b) for b in train_batches)

    # Validate roughly every 30% of an epoch, always on an optimizer-step
    # boundary. Clamp to at least one accumulation group: with few batches the
    # rounded value could be 0, and `% validate_every` below would then divide
    # by zero.
    validate_every = update_freq * max(
        1, round(len(train_batches) * 0.30 / update_freq))

    assert (validate_every % update_freq) == 0
    logger.info("Using {} total train points, {} batches".format(
        total_train_points, len(train_batches)))

    train_loader = DataLoader(
        train_set,
        num_workers=1,
        pin_memory=False,
        collate_fn=train_set.collater,
        batch_sampler=train_batches,
    )

    val_loader = DataLoader(
        validation_set,
        num_workers=1,
        pin_memory=False,
        collate_fn=validation_set.collater,
        batch_sampler=val_batches,
    )

    model = TransformerModel(Arguments(model_args), src_dict, trg_dict)
    criterion = LabelSmoothing(padding_idx=src_dict.pad(), **criterion_args)

    if use_cuda:
        model = model.cuda()
        criterion = criterion.cuda()

    fp_optimizer, optimizer, model = build_optimizer(
        model,
        optimizer_args,
        math_mode=math_mode,
        scaling_args=loss_scaling_fp16,
        use_horovod=use_horovod,
        use_cuda=use_cuda,
    )

    scheduler = SQRTTimeDecayLRWithWarmup(optimizer, **scheduler_args)

    metrics = [BLEUScore(use_raw=True)]
    checkpointer = Checkpointer(ckpt_run_dir=ckpt_run_dir,
                                rank=rank,
                                freq=CheckpointFreq.BEST)

    translator = SequenceGenerator(
        model,
        src_dict=deepcopy(src_dict),
        trg_dict=deepcopy(trg_dict),
        beam_size=4,
        stop_early=True,
        normalize_scores=True,
        len_penalty=0.6,
        sampling=False,
        sampling_topk=-1,
        minlen=1,
    )
    if not validation_only:

        if light_target:
            goal = task4_time_to_bleu_goal(20)
        else:
            goal = task4_time_to_bleu_goal(25)

        num_batches_per_device_train = len(train_loader)
        tracker = Tracker(metrics, run_id, rank, goal=goal)

        dist.barrier()
        tracker.start()

        for epoch in range(0, train_epochs):
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            model.train()
            tracker.train()

            iter_sample_size = 0
            for batch_idx, sample in enumerate(train_loader):
                tracker.batch_start()

                sample = prepare_batch(sample, use_cuda=use_cuda)
                tracker.record_batch_load()

                # batch_idx is 0-based, so the final batch has index n - 1.
                # Comparing against n would never match, and a trailing
                # partial accumulation group would never be applied.
                is_last = batch_idx == num_batches_per_device_train - 1
                # Gradient accumulation: zero grads on the first batch of each
                # group of update_freq batches, step on the last one.
                update = (batch_idx % update_freq) == update_freq - 1
                init = (batch_idx % update_freq) == 0

                # Clear gradients in the optimizer.
                if init:
                    fp_optimizer.zero_grad()
                    iter_sample_size = 0
                    tracker.record_batch_init()

                # Compute the output
                output = model(**sample["net_input"])
                tracker.record_batch_fwd_pass()

                loss, sample_size = compute_loss(sample, output, criterion)
                loss_per_sample = loss.item() / sample_size
                iter_sample_size += sample_size
                tracker.record_batch_comp_loss()

                # Backprop
                fp_optimizer.backward_loss(loss)
                tracker.record_batch_backprop()

                if update or is_last:
                    # Get batch size over all workers
                    full_bs = get_full_batch_size(iter_sample_size,
                                                  world_size=world_size,
                                                  use_cuda=use_cuda)

                    updated = opt_step(
                        fp_optimizer,
                        tracker,
                        full_bs,
                        update_freq,
                        math_mode,
                        world_size,
                    )

                    # Only advance the LR schedule when a step was applied.
                    if updated:
                        scheduler.step()

                tracker.batch_end()

                record_train_batch_stats(
                    batch_idx=batch_idx,
                    loss=loss_per_sample,
                    output=torch.Tensor([0]),
                    metric_results={},
                    tracker=tracker,
                    num_batches_per_device_train=num_batches_per_device_train,
                )

                # Intermediate validation inside the epoch.
                if (batch_idx + 1) % validate_every == 0:
                    if torch.cuda.is_available():
                        torch.cuda.empty_cache()

                    metric_values, val_loss = validation_round(
                        val_loader,
                        metrics,
                        criterion,
                        translator,
                        tracker=tracker,
                        use_cuda=use_cuda,
                    )
                    record_validation_stats(metric_values, val_loss, tracker,
                                            rank)
                    if tracker.goal_reached:
                        break

                    # Switch back to training mode after validation.
                    model.train()
                    tracker.train()

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            # End-of-epoch validation and checkpointing.
            metric_values, val_loss = validation_round(
                val_loader,
                metrics,
                criterion,
                translator,
                tracker=tracker,
                use_cuda=use_cuda,
            )
            is_best = record_validation_stats(metric_values, val_loss, tracker,
                                              rank)
            checkpointer.save(
                tracker,
                model,
                optimizer,
                scheduler,
                tracker.current_epoch,
                is_best,
            )
            tracker.epoch_end()

            if tracker.goal_reached:
                print("Goal Reached!")
                time.sleep(10)
                return
    else:
        # Evaluate previously saved checkpoints instead of training.
        cecf = CheckpointsEvaluationControlFlow(
            ckpt_dir=ckpt_run_dir,
            rank=rank,
            world_size=world_size,
            checkpointer=checkpointer,
            model=model,
            epochs=train_epochs,
            loss_function=criterion,
            metrics=metrics,
            use_cuda=use_cuda,
            dtype="fp32",
            max_batch_per_epoch=None,
        )

        train_stats = cecf.evaluate_by_epochs(train_loader)
        with open(os.path.join(output_dir, "train_stats.json"), "w") as f:
            json.dump(train_stats, f)
Ejemplo n.º 2
0
def test_tracker():
    """Smoke test: a Tracker can be built around a single top-5 metric."""
    run_id, rank = 1, 0
    topk_metrics = [TopKAccuracy(5)]
    built = Tracker(topk_metrics, run_id, rank)
    assert built is not None
Ejemplo n.º 3
0
class TrainValidation(object):
    r"""Train and validate a model.

    Args:
        model (:obj:`torch.nn.Module`): a pytorch model to be trained and validated.
        optimizer (:obj:`torch.optim.Optimizer`): an optimizer for the given model.
        loss_function (:obj:`torch.nn.modules.loss._Loss`): loss function.
        metrics (:obj:`list` of :obj:`mlbench_core.evaluation.pytorch.*`): metrics like TopKAccuracy.
        scheduler (:obj:`mlbench_core.lr_scheduler.pytorch.lr.*`): a scheduler for hyperparameters.
        batch_size (int): The size of batches provided by the dataloader
        train_epochs (int): The number of epochs to train for
        rank (int): The rank of the current workers
        world_size (int): The total number of workers
        run_id (str): The id of the current run
        dtype (str): The datatype to use for the dataloader data
        validate (bool): Whether to run validation on the val dataset. Default: `True`
        schedule_per (str): When to perform a step for the lr scheduler, one of
            `epoch` or `batch`. Default: `epoch`
        checkpoint (:obj:`Checkpointer`): Class that handles checkpointing. Default: `None`
        transform_target_type (str): dtype to transform the target to. Forwarded
            to ``train_round``/``validation_round``. Default: `None`
        average_models (bool): Whether to average models together. Accepted for
            API compatibility; not used by this class. Default: `False`
        use_cuda (bool): Whether to train on GPU or not. Default: `False`
        max_batch_per_epoch (int): Maximum number of batches per epoch. Whole dataset
            is used if not specified. Default: `None`
        tracker (:obj:`mlbench_core.utils.Tracker`): Tracker for the controlflow.
            A new ``Tracker(metrics, run_id, rank)`` is created when `None`.
            Default: `None`
    """

    def __init__(self, model, optimizer, loss_function, metrics, scheduler,
                 batch_size, train_epochs, rank, world_size, run_id, dtype,
                 validate=True, schedule_per='epoch', checkpoint=None,
                 transform_target_type=None, average_models=False,
                 use_cuda=False, max_batch_per_epoch=None, tracker=None):
        # Model and optimization components.
        self.model = model
        self.optimizer = optimizer
        self.loss_function = loss_function
        self.metrics = metrics
        self.scheduler = scheduler
        self.schedule_per = schedule_per

        # Run configuration.
        self.batch_size = batch_size
        self.train_epochs = train_epochs
        self.rank = rank
        self.world_size = world_size
        self.run_id = run_id
        self.dtype = dtype
        self.perform_validation = validate
        self.checkpoint = checkpoint
        self.transform_target_type = transform_target_type
        self.use_cuda = use_cuda
        self.max_batch_per_epoch = max_batch_per_epoch

        # Reuse a caller-supplied tracker (e.g. when resuming); otherwise
        # start a fresh one.
        if tracker is not None:
            self.tracker = tracker
        else:
            self.tracker = Tracker(metrics, run_id, rank)

    def _get_dataloader_stats(self, dataloader_train, dataloader_val):
        """ Sets the stats for the supplied dataloaders

        Args:
            dataloader_train (:obj:`torch.utils.data.DataLoader`): The train set
            dataloader_val (:obj:`torch.utils.data.DataLoader`): The validation set
        """
        self.num_batches_per_device_train = len(dataloader_train)
        self.num_batches_per_device_val = len(dataloader_val)

    def run(self, dataloader_train=None, dataloader_val=None,
            dataloader_train_fn=None, dataloader_val_fn=None, resume=False,
            repartition_per_epoch=False):
        """Execute training and (possibly) validation

        `dataloader_train` and `dataloader_train_fn` are mutually exclusive.
        `dataloader_val` and `dataloader_val_fn` are mutually exclusive.
        If both are given anyway, the loader returned by the function wins.

        Args:
            dataloader_train (:obj:`torch.utils.data.DataLoader`): A dataloader for the train set.
                Default: `None`
            dataloader_val (:obj:`torch.utils.data.DataLoader`): A dataloader for the val set.
                Default: `None`
            dataloader_train_fn (:func:`Function`): A function returning a :obj:`torch.utils.data.DataLoader`
                for the train set. Default: `None`
            dataloader_val_fn (:func:`Function`): A function returning a :obj:`torch.utils.data.DataLoader`
                for the val set. Default: `None`
            resume (bool): Whether this is a resume of a previous run or not. Default: `False`
            repartition_per_epoch (bool): Whether to repartition the dataset again every epoch.
                Requires dataloader_train_fn and/or dataloader_val_fn to be set. Default: `False`

        Raises:
            ValueError: If neither a train (resp. val) dataloader nor a
                function producing one is given.
        """
        # Compare against None instead of relying on truthiness. ``bool()`` on
        # a DataLoader falls back to ``len()``, so an empty loader would be
        # rejected here and a loader whose length is undefined would raise.
        if dataloader_train_fn is None and dataloader_train is None:
            raise ValueError(
                "One of dataloader_train_fn or dataloader_train must be set")

        if dataloader_val_fn is None and dataloader_val is None:
            raise ValueError(
                "One of dataloader_val_fn or dataloader_val must be set")

        if dataloader_train_fn is not None:
            dataloader_train = dataloader_train_fn()

        if dataloader_val_fn is not None:
            dataloader_val = dataloader_val_fn()

        self._get_dataloader_stats(dataloader_train, dataloader_val)

        # define some parameters for training.
        logger.info("There are {train_epochs} epochs, {num_batches} "
                    "mini-batches per epoch (batch size: {batch_size})."
                    .format(
                        train_epochs=self.train_epochs,
                        num_batches=self.num_batches_per_device_train,
                        batch_size=self.batch_size))

        # Continue after the tracker's last epoch when resuming.
        if resume:
            start_epoch = self.tracker.current_epoch + 1
        else:
            start_epoch = 0

        dist.barrier()
        for epoch in range(start_epoch, self.train_epochs):
            # Per epoch information.
            logger.info("Current epoch : {} : lr={}"
                        .format(epoch, self.scheduler.get_lr()))

            train_round(dataloader_train, self.model, self.optimizer,
                        self.loss_function, self.metrics, self.scheduler,
                        self.dtype, self.schedule_per,
                        self.transform_target_type, self.use_cuda,
                        self.max_batch_per_epoch, self.tracker)

            is_best = False
            if self.perform_validation:
                is_best = validation_round(dataloader_val, self.model,
                                           self.loss_function, self.metrics,
                                           self.run_id, self.rank, self.dtype,
                                           self.transform_target_type,
                                           self.use_cuda,
                                           self.max_batch_per_epoch,
                                           self.tracker)

            if self.checkpoint:
                self.checkpoint.save(self.tracker, self.model,
                                     self.optimizer, self.scheduler,
                                     self.tracker.current_epoch, is_best)

            # Shuffle the dataset across nodes by rebuilding the loaders.
            if repartition_per_epoch:
                if dataloader_train_fn is not None:
                    dataloader_train = dataloader_train_fn()

                if dataloader_val_fn is not None:
                    dataloader_val = dataloader_val_fn()

                self._get_dataloader_stats(dataloader_train, dataloader_val)

            self.tracker.epoch_end()