Example #1
import os
import random

import pytest
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast

from fairscale.nn.data_parallel import FullyShardedDataParallel
from fairscale.optim.grad_scaler import ShardedGradScaler


def test_scaler_cpu_offload_breaks():
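    # Regression test: ShardedGradScaler.step() is expected to raise a RuntimeError
    # when FSDP is configured with cpu_offload=True and mixed_precision=True
    # (tracked in issue #421).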
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Use a random port: if the next test run starts quickly, reusing the same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(random.randint(2000, 3000))
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        scaler = ShardedGradScaler()
        model = FullyShardedDataParallel(nn.Linear(5, 5),
                                         cpu_offload=True,
                                         mixed_precision=True)
        optim = torch.optim.SGD(model.parameters(), lr=1e-3)

        input = torch.rand((1, 5), dtype=torch.float).to(device)
        optim.zero_grad()
        with autocast():
            output = model(input)
            loss = F.mse_loss(input, output)

        scaler.scale(loss).backward()
        # TODO (Min): Need to fix. Details in issue #421.
        with pytest.raises(RuntimeError):
            scaler.step(optim)
            scaler.update()

    finally:
        # Clean-up is important or the next test in this file may fail to init the PG.
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
Example #2
    def _train_for_several_steps(model,
                                 num_steps,
                                 autocast,
                                 lr=0.01,
                                 norm_type=None):
        model_device = next(model.parameters()).device
        # Use SGD with momentum instead of Adam: Adam is scale-invariant,
        # which makes it a poor choice for testing gradient scaling.
        optim = torch.optim.SGD(params=model.parameters(), lr=lr, momentum=0.9)
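        # ShardedGradScaler is fairscale's drop-in replacement for torch.cuda.amp.GradScaler
        # that keeps the inf/NaN checks consistent when gradients are sharded across ranks.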
        scaler = ShardedGradScaler()
        for _ in range(num_steps):
            optim.zero_grad()
            with torch.cuda.amp.autocast(enabled=autocast):
                # Inputs always cuda regardless of move_grads_cpu, or model.device
                input = model.module.get_input(torch.device("cuda"))
                output = model(*input)
                loss = model.module.get_loss(input, output).to(model_device)
            loss = scaler.scale(loss)
            assert loss.dtype == torch.float32
            model.module.run_backward(loss)
            if norm_type is not None:
                clip_norm = 0.3
                if isinstance(model, FullyShardedDataParallel):
                    model.clip_grad_norm_(clip_norm, norm_type)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   clip_norm, norm_type)
            scaler.step(optim)
            scaler.update()
        if hasattr(model, "assert_idle"):
            model.assert_idle()
        if isinstance(model, FullyShardedDataParallel):
            model.assert_state(TrainingState.IDLE)
        return loss.detach()
Example #3
class SharedDataParallelFairScaleAMPEngine(SharedDataParallelFairScaleEngine):
    """Distributed FairScale MultiGPU training device engine.

    Args:
        address: address to use for backend.
        port: port to use for backend.
        sync_bn: boolean flag for batchnorm synchronization during distributed training.
            If True, applies PyTorch `convert_sync_batchnorm`_ to the model
            (native torch distributed only). Default is False.
        ddp_kwargs: parameters for `fairscale.nn.data_parallel.ShardedDataParallel`.
            Docs for `fairscale.nn.ShardedDataParallel`:
            https://fairscale.readthedocs.io/en/latest/api/nn/sharded_ddp.html
        process_group_kwargs: parameters for `torch.distributed.init_process_group`.
            More info here:
            https://pytorch.org/docs/stable/distributed.html#torch.distributed.init_process_group
        scaler_kwargs: parameters for `fairscale.optim.grad_scaler.ShardedGradScaler`.
            Possible parameters:
            https://fairscale.readthedocs.io/en/latest/api/index.html

    Examples:

    .. code-block:: python

        from catalyst import dl

        runner = dl.SupervisedRunner()
        runner.train(
            engine=dl.SharedDataParallelFairScaleAMPEngine(),
            ...
        )

    .. code-block:: python

        from catalyst import dl

        class MyRunner(dl.IRunner):
            # ...
            def get_engine(self):
                return dl.SharedDataParallelFairScaleAMPEngine(
                    address="0.0.0.0",
                    port=23234,
                    ddp_kwargs={"find_unused_parameters": False},
                    process_group_kwargs={"port": 12345},
                    scaler_kwargs={"growth_factor": 1.5}
                )
            # ...

    .. code-block:: yaml

        args:
            logs: ...

        model:
            _target_: ...
            ...

        engine:
            _target_: SharedDataParallelFairScaleAMPEngine
            address: 0.0.0.0
            port: 23234
            ddp_kwargs:
                find_unused_parameters: false
            process_group_kwargs:
                port: 12345
            scaler_kwargs:
                growth_factor: 1.5

        stages:
            ...

    .. _convert_sync_batchnorm:
        https://pytorch.org/docs/stable/generated/torch.nn.SyncBatchNorm.html#
        torch.nn.SyncBatchNorm.convert_sync_batchnorm
    """
    def __init__(
        self,
        address: str = None,
        port: Union[str, int] = None,
        ddp_kwargs: Dict[str, Any] = None,
        process_group_kwargs: Dict[str, Any] = None,
        scaler_kwargs: Dict[str, Any] = None,
    ):
        """Init."""
        super().__init__(
            address=address,
            port=port,
            ddp_kwargs=ddp_kwargs,
            process_group_kwargs=process_group_kwargs,
        )
        # @TODO: should we support scaler for each optimizer?
        if scaler_kwargs is None:
            scaler_kwargs = {}
        self.scaler_kwargs = scaler_kwargs
        self.scaler = ShardedGradScaler(**self.scaler_kwargs)
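        # A single scaler instance is shared by ``backward_loss`` and ``optimizer_step`` below.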

    def zero_grad(self, loss, model, optimizer) -> None:
        """Abstraction over ``model.zero_grad()`` step."""
        optimizer.zero_grad()

    def backward_loss(self, loss, model, optimizer) -> None:
        """Abstraction over ``loss.backward()`` step."""
        self.scaler.scale(loss).backward()

    def optimizer_step(self, loss, model, optimizer) -> None:
        """Abstraction over ``optimizer.step()`` step."""
        self.scaler.step(optimizer)
        self.scaler.update()

    def autocast(self):
        """AMP context"""
        return amp.autocast()
Example #4
def train(args, *, tbl):
    cfg, tokenizer, _, _ = nlp.models.bert.get_pretrained_bert(args.model_name, load_backbone=False,
                                                               load_mlm=False)
    cfg = nlp.torch.models.bert.BertModel.get_cfg().clone_merge(cfg)
    model = nlp.torch.models.bert.QTBertForPretrain(cfg)
    model.to(args.device)

    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args, 'Loading')
    else:
        model.apply(nlp.torch.models.bert.init_weights)

    writer = None
    if args.local_rank in (-1, 0):
        writer = SummaryWriter(log_dir=os.path.join(args.ckpt_dir, 'tensorboard'))

    # pin_memory=False due to lack of https://github.com/pytorch/pytorch/commit/54ce171f16c8859f829dde09f87c364c8a6b4130
    sampler = RandomSampler(tbl) if args.local_rank == -1 else DistributedSampler(
        tbl, seed=args.seed)
    # batch_size // 2 for QuickThought
    train_dataloader = DataLoader(np.arange(len(tbl)), sampler=sampler,
                                  collate_fn=functools.partial(collate_fn, args=args, tbl=tbl),
                                  batch_size=args.batch_size // 2,
                                  num_workers=args.num_dataloader_workers, pin_memory=True)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay': args.weight_decay
    }, {
        'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        'weight_decay': 0.0
    }]

    optimizer_arguments = {"lr": args.lr}
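    # With ZeRO enabled, shard the optimizer state via fairscale's OSS and wrap the model
    # in ShardedDataParallel; otherwise fall back to plain DistributedDataParallel
    # (or a single-process optimizer).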
    if get_world_size(args) > 1 and args.ZeRO:
        optimizer = OSS(params=model.parameters(), optim=nlp.torch.optimizers.FusedLANS,
                        **optimizer_arguments)
        model = ShardedDataParallel(model, optimizer)
    elif get_world_size(args) > 1:
        optimizer = nlp.torch.optimizers.FusedLANS(optimizer_grouped_parameters,
                                                   **optimizer_arguments)
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank, find_unused_parameters=True)
    else:
        optimizer = nlp.torch.optimizers.FusedLANS(optimizer_grouped_parameters,
                                                   **optimizer_arguments)

    save_interval = args.ckpt_interval
    logging.info(f'#Total Training Steps={args.num_steps}, '
                 f'Warmup Steps={args.warmup_ratio * args.num_steps}, '
                 f'Save Interval={save_interval}')
    scheduler = nlp.torch.optimizers.schedules.get_warmup_linear_const_decay_poly_schedule(
        optimizer, total_steps=args.num_steps, warmup_ratio=args.warmup_ratio,
        const_ratio=args.const_ratio)

    if args.start_step:
        logging.info(f'Restart training from {args.start_step}')
        states_option(args.start_step, optimizer, args, 'Loading')

    ce_loss_fn = th.nn.CrossEntropyLoss()
    step_num = args.start_step
    if args.phase2:
        step_num -= args.phase1_num_steps
    running_num_tks, running_grad_norm = 0, 0
    running_mlm_loss, running_qt_loss, running_mlm_acc, running_qt_acc = 0, 0, 0, 0

    train_start_time = time.time()
    tic = time.time()
    model.zero_grad()
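    # When the optimizer state is sharded (OSS + ShardedDataParallel), the grad scaler must be
    # shard-aware as well; otherwise the standard torch.cuda.amp.GradScaler is sufficient.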
    if get_world_size(args) > 1 and args.ZeRO:
        scaler = ShardedGradScaler() if args.fp16 else None
    else:
        scaler = th.cuda.amp.GradScaler() if args.fp16 else None

    train_iter = repeat(train_dataloader, set_epoch=args.local_rank != -1)
    while step_num < args.num_steps:
        step_num += 1
        for accum_step in range(args.num_accumulated):
            (input_id, segment_id, valid_length, mlm_positions,
             mlm_labels) = (arr.to(args.device) for arr in next(train_iter))

            model.train()
            accumulation = ((accum_step + 1) % args.num_accumulated != 0)
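            # Skip gradient synchronization on intermediate accumulation steps; gradients
            # are only all-reduced on the last micro-batch of the accumulation window.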
            with model.no_sync() if get_world_size(args) > 1 and accumulation else suppress():
                with th.cuda.amp.autocast(enabled=args.fp16):
                    _, pooled_out, mlm_scores, qt_similarity = model(input_id, segment_id,
                                                                     valid_length, mlm_positions)
                    mlm_loss = ce_loss_fn(mlm_scores, mlm_labels)
                    qt_label = th.arange(len(input_id) // 2, device=args.device)
                    qt_loss = ce_loss_fn(qt_similarity, qt_label)
                    loss = mlm_loss + qt_loss
                if args.num_accumulated > 1:
                    loss = loss / args.num_accumulated
                if args.fp16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                with th.no_grad():
                    qt_acc = (qt_similarity.argmax(dim=1) == qt_label).sum() / (len(input_id) // 2)
                    mlm_acc = (mlm_scores.argmax(dim=1) == mlm_labels).sum() / len(mlm_labels)

            # Gather information from all workers for accurate statistics
            reduced_num_tokens = valid_length.sum()
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_num_tokens)
            reduced_num_mlm_tokens = th.tensor(len(mlm_labels), device=args.device)
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_num_mlm_tokens)
            reduced_loss_mlm = mlm_loss.detach().clone() * len(mlm_labels) / reduced_num_mlm_tokens
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_loss_mlm)
            reduced_acc_mlm = mlm_acc.detach().clone() * len(mlm_labels) / reduced_num_mlm_tokens
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_acc_mlm)
            reduced_bs = th.tensor(len(input_id), device=args.device)
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_bs)
            reduced_loss_qt = qt_loss.detach().clone() * len(input_id) / reduced_bs
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_loss_qt)
            reduced_acc_qt = qt_acc.detach().clone() * len(input_id) / reduced_bs
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_acc_qt)

            running_num_tks += reduced_num_tokens.item()
            running_mlm_loss += reduced_loss_mlm.item()
            running_mlm_acc += reduced_acc_mlm.item()
            running_qt_loss += reduced_loss_qt.item()
            running_qt_acc += reduced_acc_qt.item()

            if not accumulation:
                if args.fp16:
                    scaler.unscale_(optimizer)  # unscale for gradient clipping
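                # OSS owns only a shard of the optimizer state, so its clip_grad_norm
                # computes the global norm across ranks internally.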
                if get_world_size(args) > 1 and args.ZeRO:
                    total_norm = optimizer.clip_grad_norm(args.max_grad_norm)
                else:
                    total_norm = th.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    if get_world_size(args) > 1:
                        distributed.all_reduce(total_norm)
                        total_norm /= get_world_size(args)
                running_grad_norm += total_norm

                if args.fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                with warnings.catch_warnings():
                    # Scheduler may warn if optimizer.step() call is skipped
                    # due to invalid gradients detected by scaler.
                    warnings.simplefilter("ignore", UserWarning)
                    scheduler.step()
                optimizer.zero_grad(set_to_none=True)

        if step_num % args.log_interval == 0:
            toc = time.time()
            wps = running_num_tks / (toc - tic)
            eta = (args.num_steps - step_num) / (step_num / (toc - train_start_time)) / 3600
            interval = args.log_interval * args.num_accumulated
            logging.info(f'[Step {step_num}], LR={scheduler.get_last_lr()[0]:.6f}, '
                         f'Loss MLM/QT={running_mlm_loss / interval:.4f}/'
                         f'{running_qt_loss / interval:.4f}, '
                         f'Acc MLM/QT={running_mlm_acc / interval:.4f}/'
                         f'{running_qt_acc / interval:.4f}, '
                         f'Grad_norm={running_grad_norm / interval:.4f}, '
                         f'Time cost={toc - tic:.2f}, '
                         f'Throughput={wps:.2f} tokens/s, ETA={eta:.2f}h')
            if args.local_rank in (-1, 0):
                writer.add_scalar('Throughput_wps', wps, step_num)
                writer.add_scalar('Loss/MLM', running_mlm_loss / interval, step_num)
                writer.add_scalar('Loss/QT', running_qt_loss / interval, step_num)
                writer.add_scalar('Acc/MLM', running_mlm_acc / interval, step_num)
                writer.add_scalar('Acc/QT', running_qt_acc / interval, step_num)
                writer.add_scalar('LR', scheduler.get_last_lr()[0], step_num)
                writer.add_scalar('Grad_norm', running_grad_norm / interval, step_num)
            running_num_tks, running_grad_norm = 0, 0
            running_mlm_loss, running_qt_loss, running_mlm_acc, running_qt_acc = 0, 0, 0, 0
            tic = time.time()

        # Saving
        if step_num % save_interval == 0 or step_num >= args.num_steps:
            states_option(step_num, optimizer, args, 'Saving')
            if args.local_rank in (0, -1):
                parameters_option(step_num, model, args, 'Saving')

    logging.info('Finish training step: %d', step_num)
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time - train_start_time))

    if args.local_rank in (0, -1):
        save_dir = os.path.join(args.ckpt_dir, args.model_name)
        final_save(model, save_dir, tokenizer.vocab, cfg)