Esempio n. 1
0
def run_eval_mode(_unused):
    """Check that eval-mode forward passes run cleanly, unlike training mode.

    A single-process GLOO group is created, a small two-layer model is wrapped
    in ``ShardedDataParallel`` and several forward passes are run back to
    back: first in eval mode, which must succeed, then in training mode, where
    a ``RuntimeError`` is expected.

    Args:
        _unused: Ignored by the body.

    Raises:
        AssertionError: If repeated training-mode forward passes do not raise
            a ``RuntimeError``.
    """
    import os  # local import: only needed to close the mkstemp descriptor

    # mkstemp returns an *open* descriptor along with the path; only the path
    # is needed for the file:// rendezvous, so close the descriptor right away
    # instead of leaking it.
    sync_fd, sync_path = tempfile.mkstemp()
    os.close(sync_fd)

    dist.init_process_group(init_method=f"file://{sync_path}",
                            backend=dist.Backend.GLOO,
                            rank=0,
                            world_size=1)
    try:
        model = Sequential(Linear(2, 3), Linear(3, 4))
        optimizer_params = {"lr": 0.1, "momentum": 0.99}
        ddp = ShardedDataParallel(model,
                                  torch.optim.SGD,
                                  optimizer_params,
                                  1,
                                  broadcast_buffers=False)

        # Eval mode: repeated forward passes must simply work
        ddp.eval()
        for _ in range(5):
            ddp(torch.rand((64, 2)))

        # Training mode: repeated forward passes are expected to be rejected
        ddp.train()
        try:
            for _ in range(5):
                ddp(torch.rand((64, 2)))
        except RuntimeError:
            pass
        else:
            # Raise explicitly: a bare `assert False` disappears under `python -O`
            raise AssertionError(
                "Multiple forward passes on training mode should not pass")
    finally:
        # Tear the group down even when the check above fails
        dist.destroy_process_group()
Esempio n. 2
0
def train(rank, args, model, device, train_loader, num_epochs):
    """Train ``model`` through ``ShardedDataParallel`` and print the total time.

    Args:
        rank: Rank handed to ``dist_init``; also the CUDA device index used
            for the memory-stat reset and the synchronisation points.
        args: Not read by this function.
        model: Module to train; it is wrapped in ``ShardedDataParallel``.
        device: Device the batches are moved to.
        train_loader: Iterable yielding ``(data, target)`` pairs.
        num_epochs: Number of passes over ``train_loader``.
    """
    # ---- setup -----------------------------------------------------------
    dist_init(rank, WORLD_SIZE, BACKEND)
    sharded = ShardedDataParallel(
        module=model,
        optimizer=torch.optim.Adadelta,
        optimizer_params={"lr": 1e-4},
        world_size=WORLD_SIZE,
        broadcast_buffers=True,
    )
    sharded.train()
    sharded_optimizer = sharded.optimizer

    # Start from a clean memory counter and a quiescent device
    torch.cuda.reset_peak_memory_stats(rank)
    torch.cuda.synchronize(rank)
    started_at = time.monotonic()

    criterion = nn.CrossEntropyLoss()
    model.train()

    def forward_backward(inputs, labels):
        """Compute the scaled loss, backpropagate and reduce the gradients."""
        model.zero_grad()
        loss = criterion(model(inputs), labels)
        loss /= WORLD_SIZE
        loss.backward()
        # Send the gradients to the appropriate shards
        sharded.reduce()
        return loss

    # ---- training loop ---------------------------------------------------
    for _ in range(num_epochs):
        for inputs, labels in train_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            sharded_optimizer.step(lambda: forward_backward(inputs, labels))

    torch.cuda.synchronize(rank)
    print("Total Time:", time.monotonic() - started_at)
def test_train_eval_change():
    """ShardedDDP must keep behaving across train -> eval -> train switches."""
    rendezvous = "file://" + tempfile.mkstemp()[1]
    dist.init_process_group(init_method=rendezvous,
                            backend="gloo",
                            rank=0,
                            world_size=1)

    mlp = _get_mlp()
    mlp.train()
    sharded_optim = OSS(params=mlp.parameters(),
                        optim=torch.optim.SGD,
                        lr=1e-3,
                        momentum=0.99)
    wrapped = ShardedDataParallel(mlp, sharded_optim)
    sample = torch.rand((2, 2))

    # Training-mode step, so that the gradients get reduced
    wrapped(sample).sum().backward()

    # Clear the gradients, then run a forward pass in eval mode
    wrapped.zero_grad()
    wrapped.eval()
    wrapped(sample)
    first_grad = next(wrapped.parameters()).grad
    if first_grad is not None:
        assert torch.norm(first_grad) < 1e-6

    # Back to training: gradients must show up again
    wrapped = wrapped.train()
    wrapped(sample).sum().backward()
    assert torch.norm(next(wrapped.parameters()).grad) > 0.0

    dist.destroy_process_group()
Esempio n. 4
0
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    """Benchmark one rank: build model and optimizer, train, then validate.

    Args:
        rank: Rank of this process; also used as the CUDA device index unless
            ``args.cpu`` is set.
        args: Parsed options. Fields read here: ``debug``,
            ``multi_tensor_optim``, ``world_size``, ``cpu``, ``batch_size``,
            ``model``, ``amp``, ``profile`` and ``epochs``.
        backend: Backend forwarded to ``dist_init``; ``"nccl"`` additionally
            switches cuDNN to deterministic, non-benchmark mode.
        optim_type: Configuration to benchmark. ``oss_sharded_ddp`` wraps the
            model in ``ShardedDDP`` with an ``OSS`` optimizer, ``oss_ddp``
            uses ``DDP`` with ``OSS``, anything else uses ``DDP`` with the
            plain optimizer.
        check_regression: Forwarded to ``validate_benchmark``.
    """
    logging.basicConfig(
        level=logging.INFO if not args.debug else logging.DEBUG)

    use_multi_tensor = args.multi_tensor_optim and hasattr(
        torch.optim, "_multi_tensor")
    OPTIM = torch.optim._multi_tensor.RMSprop if use_multi_tensor else torch.optim.RMSprop  # type: ignore  # attr is  checked but mypy misses that
    logging.info("Multi tensor optimizer: {}".format(use_multi_tensor))

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size,
                                             args.batch_size, device,
                                             args.model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)
    # Choose the scaler from the configuration actually being run
    # (`optim_type`), like every other branch below, rather than from
    # `args.optim_type`: the two are passed separately and need not agree.
    scaler = (TorchGradScaler() if optim_type == OptimType.vanilla else
              ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(),
                        optim=OPTIM,
                        lr=1e-4,
                        momentum=0.9)
        # Single node run typically, no need for reduce buckets
        model = ShardedDDP(model, optimizer, reduce_buffer_size=0)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids,
                    find_unused_parameters=False)  # type: ignore
        optimizer = (OSS(
            params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
                     if optim_type == OptimType.oss_ddp else OPTIM(
                         model.parameters(), lr=1e-4, momentum=0.9))
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    def run_closure(closure, scaler, optimizer):
        """Run one optimisation step and return the loss from ``closure``."""
        if scaler is not None:
            # AMP scaler.step does not support closures
            final_loss = closure(grad_scaler=scaler)
            scaler.step(optimizer)
            scaler.update()
            return final_loss
        return optimizer.step(closure)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            # `data=batch` binds the current batch at definition time
            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(
                        model.parameters()).grad is not None:
                    logging.debug("\nbefore:  param {} -- grad {}".format(
                        next(model.parameters()).norm().item(),
                        next(model.parameters()).grad.norm().item()))
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(
                        model.parameters()).grad is not None:
                    logging.debug("after BW: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(),
                        next(model.parameters()).grad.norm().item()))
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(
                        use_cuda=True, record_shapes=True,
                        profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                prof.export_chrome_trace(
                    f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once

            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(
                    next(model.buffers()).norm().item()))
                logging.debug("after update: param {} -- grad {}".format(
                    next(model.parameters()).norm().item(),
                    next(model.parameters()).grad.norm().item()))

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(
                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}"
            )

    training_stop = time.monotonic()
    # n_items holds the item count of the last epoch, hence the scaling by epochs
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(
        f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint"
    )

    validate_benchmark(measurements, final_loss, args, check_regression)

    dist.destroy_process_group()  # type: ignore
Esempio n. 5
0
def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    use_sdp: bool = False,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
    reference_loss: float = -1.0,
):
    """Run a dummy training benchmark on one rank and optionally check regressions.

    Args:
        rank: Rank of this process; also used as the CUDA device index.
        world_size: Number of participating processes; the loss is divided by
            it before the backward pass.
        num_epochs: Number of passes over the dataloader.
        batch_size: Forwarded to ``get_problem``.
        data_size: Forwarded to ``get_problem`` and used as the per-epoch item
            count in the throughput figures.
        backend: Backend forwarded to ``dist_init``; ``"nccl"`` additionally
            switches cuDNN to deterministic, non-benchmark mode.
        use_oss: Use the sharded ``OSS`` optimizer instead of the plain one.
        use_sdp: Wrap the model in ``ShardedDataParallel`` (requires
            ``use_oss``).
        check_regression: When set (together with ``use_oss``, on rank 0),
            compare speed, peak memory and final loss against the references.
        reference_speed: Minimum acceptable ``mean + 3 * std`` throughput.
        reference_memory: Peak memory reference in MiB (5% tolerance).
        reference_loss: Expected final loss (absolute tolerance ``1e-3``).

    Raises:
        AssertionError: If ``use_sdp`` is set without ``use_oss``, or if a
            regression check fails.
    """
    assert use_oss or not use_sdp, "ShardedDataParallel requires OSS"
    # DDP
    dist_init(rank=rank, world_size=world_size, backend=backend)

    # Setup
    torch.cuda.set_device(rank)
    torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None

    if use_sdp:
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={
                "lr": 1e-4,
                "momentum": 0.9
            },
            world_size=world_size,
            broadcast_buffers=False,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9))

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                # Sum the per-rank (already scaled) losses for reporting
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                if use_sdp:
                    ddp.reduce(
                    )  # Send the gradients to the appropriate shards

                return loss

            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(
                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}"
            )

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20

    print(
        f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall"
    )
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and sample standard deviation of the img per second
    n_measurements = len(measurements)
    mean = sum(measurements) / n_measurements if n_measurements else 0.0
    if n_measurements > 1:
        squared_dev = (pow(x - mean, 2.0) for x in measurements)
        std = math.sqrt(sum(squared_dev) / (n_measurements - 1))
    else:
        # A sample standard deviation needs at least two epochs; report no
        # spread instead of dividing by zero.
        std = 0.0
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean +
                3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) -
                   reference_loss) < 1e-3, "Loss regression detected"

        print("[Regression Test] VALID")
Esempio n. 6
0
def train(args, *, tbl):
    """Pretrain a QuickThought BERT model with MLM and QT objectives.

    Builds the model, data loader, optimizer (optionally sharded with
    ``OSS`` / ``ShardedDataParallel`` when ``args.ZeRO`` is set and more than
    one worker is used), then runs ``args.num_steps`` optimisation steps with
    gradient accumulation, periodic logging and checkpointing.

    Args:
        args: Run configuration. Among others this reads ``model_name``,
            ``device``, ``start_step``, ``local_rank``, ``ckpt_dir``, ``seed``,
            ``batch_size``, ``num_dataloader_workers``, ``weight_decay``,
            ``lr``, ``ZeRO``, ``ckpt_interval``, ``num_steps``,
            ``warmup_ratio``, ``const_ratio``, ``phase2``,
            ``phase1_num_steps``, ``fp16``, ``num_accumulated``,
            ``max_grad_norm`` and ``log_interval``.
        tbl: Keyword-only dataset table; its length defines the sample indices
            and it is handed to the samplers and to ``collate_fn``.
    """
    cfg, tokenizer, _, _ = nlp.models.bert.get_pretrained_bert(args.model_name, load_backbone=False,
                                                               load_mlm=False)
    cfg = nlp.torch.models.bert.BertModel.get_cfg().clone_merge(cfg)
    model = nlp.torch.models.bert.QTBertForPretrain(cfg)
    model.to(args.device)

    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args, 'Loading')
    else:
        model.apply(nlp.torch.models.bert.init_weights)

    # Only the main process (or a non-distributed run) writes tensorboard logs
    writer = None
    if args.local_rank in (-1, 0):
        writer = SummaryWriter(log_dir=os.path.join(args.ckpt_dir, 'tensorboard'))

    # NOTE(review): the remark below mentions pin_memory=False, yet the DataLoader
    # is created with pin_memory=True - confirm which one is intended.
    # pin_memory=False due to lack of https://github.com/pytorch/pytorch/commit/54ce171f16c8859f829dde09f87c364c8a6b4130
    sampler = RandomSampler(tbl) if args.local_rank == -1 else DistributedSampler(
        tbl, seed=args.seed)
    # batch_size // 2 for QuickThought
    train_dataloader = DataLoader(np.arange(len(tbl)), sampler=sampler,
                                  collate_fn=functools.partial(collate_fn, args=args, tbl=tbl),
                                  batch_size=args.batch_size // 2,
                                  num_workers=args.num_dataloader_workers, pin_memory=True)

    # Biases and LayerNorm weights are placed in a group without weight decay
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer_arguments = {"lr": args.lr}
    if get_world_size(args) > 1 and args.ZeRO:
        # NOTE(review): this branch hands OSS the raw parameters, so the
        # per-group weight-decay split built above is not applied here -
        # confirm this is intended.
        optimizer = OSS(params=model.parameters(), optim=nlp.torch.optimizers.FusedLANS,
                        **optimizer_arguments)
        model = ShardedDataParallel(model, optimizer)
    elif get_world_size(args) > 1:
        optimizer = nlp.torch.optimizers.FusedLANS(optimizer_grouped_parameters,
                                                   **optimizer_arguments)
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank, find_unused_parameters=True)
    else:
        optimizer = nlp.torch.optimizers.FusedLANS(optimizer_grouped_parameters,
                                                   **optimizer_arguments)

    save_interval = args.ckpt_interval
    logging.info(f'#Total Training Steps={args.num_steps}, '
                 f'Warmup Steps={args.warmup_ratio * args.num_steps}, '
                 f'Save Interval={save_interval}')
    scheduler = nlp.torch.optimizers.schedules.get_warmup_linear_const_decay_poly_schedule(
        optimizer, total_steps=args.num_steps, warmup_ratio=args.warmup_ratio,
        const_ratio=args.const_ratio)

    if args.start_step:
        logging.info(f'Restart training from {args.start_step}')
        states_option(args.start_step, optimizer, args, 'Loading')

    ce_loss_fn = th.nn.CrossEntropyLoss()
    step_num = args.start_step
    if args.phase2:
        step_num -= args.phase1_num_steps
    running_num_tks, running_grad_norm = 0, 0
    running_mlm_loss, running_qt_loss, running_mlm_acc, running_qt_acc = 0, 0, 0, 0

    train_start_time = time.time()
    tic = time.time()
    model.zero_grad()
    if get_world_size(args) > 1 and args.ZeRO:
        scaler = ShardedGradScaler() if args.fp16 else None
    else:
        scaler = th.cuda.amp.GradScaler() if args.fp16 else None

    train_iter = repeat(train_dataloader, set_epoch=args.local_rank != -1)
    while step_num < args.num_steps:
        step_num += 1
        for accum_step in range(args.num_accumulated):
            # Fetch exactly one batch per micro-step and move it to the device
            (input_id, segment_id, valid_length, mlm_positions,
             mlm_labels) = (arr.to(args.device) for arr in next(train_iter))

            model.train()
            # True for every micro-step except the last one of the current step
            accumulation = ((accum_step + 1) % args.num_accumulated != 0)
            # Skip gradient synchronisation while still accumulating
            with model.no_sync() if get_world_size(args) > 1 and accumulation else suppress():
                with th.cuda.amp.autocast(enabled=args.fp16):
                    _, pooled_out, mlm_scores, qt_similarity = model(input_id, segment_id,
                                                                     valid_length, mlm_positions)
                    mlm_loss = ce_loss_fn(mlm_scores, mlm_labels)
                    qt_label = th.arange(len(input_id) // 2, device=args.device)
                    qt_loss = ce_loss_fn(qt_similarity, qt_label)
                    loss = mlm_loss + qt_loss
                if args.num_accumulated > 1:
                    loss = loss / args.num_accumulated
                if args.fp16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                with th.no_grad():
                    qt_acc = (qt_similarity.argmax(dim=1) == qt_label).sum() / (len(input_id) // 2)
                    mlm_acc = (mlm_scores.argmax(dim=1) == mlm_labels).sum() / len(mlm_labels)

            # Gather information from all workers for accurate statistics
            reduced_num_tokens = valid_length.sum()
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_num_tokens)
            reduced_num_mlm_tokens = th.tensor(len(mlm_labels), device=args.device)
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_num_mlm_tokens)
            reduced_loss_mlm = mlm_loss.detach().clone() * len(mlm_labels) / reduced_num_mlm_tokens
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_loss_mlm)
            reduced_acc_mlm = mlm_acc.detach().clone() * len(mlm_labels) / reduced_num_mlm_tokens
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_acc_mlm)
            reduced_bs = th.tensor(len(input_id), device=args.device)
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_bs)
            reduced_loss_qt = qt_loss.detach().clone() * len(input_id) / reduced_bs
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_loss_qt)
            reduced_acc_qt = qt_acc.detach().clone() * len(input_id) / reduced_bs
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_acc_qt)

            running_num_tks += reduced_num_tokens.item()
            running_mlm_loss += reduced_loss_mlm.item()
            running_mlm_acc += reduced_acc_mlm.item()
            running_qt_loss += reduced_loss_qt.item()
            running_qt_acc += reduced_acc_qt.item()

            if not accumulation:
                if args.fp16:
                    scaler.unscale_(optimizer)  # unscale for gradient clipping
                if get_world_size(args) > 1 and args.ZeRO:
                    total_norm = optimizer.clip_grad_norm(args.max_grad_norm)
                else:
                    total_norm = th.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    if get_world_size(args) > 1:
                        distributed.all_reduce(total_norm)
                        total_norm /= get_world_size(args)
                running_grad_norm += total_norm

                if args.fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                with warnings.catch_warnings():
                    # Scheduler may warn if optimizer.step() call is skipped
                    # due to invalid gradients detected by scaler.
                    warnings.simplefilter("ignore", UserWarning)
                    scheduler.step()
                optimizer.zero_grad(set_to_none=True)

        if step_num % args.log_interval == 0:
            toc = time.time()
            wps = running_num_tks / (toc - tic)
            eta = (args.num_steps - step_num) / (step_num / (toc - train_start_time)) / 3600
            # Losses and accuracies are accumulated once per micro-batch ...
            interval = args.log_interval * args.num_accumulated
            # ... whereas the gradient norm is accumulated once per optimizer
            # step, so it is averaged over the number of steps only.
            avg_grad_norm = running_grad_norm / args.log_interval
            logging.info(f'[Step {step_num}], LR={scheduler.get_last_lr()[0]:.6f}, '
                         f'Loss MLM/QT={running_mlm_loss / interval:.4f}/'
                         f'{running_qt_loss / interval:.4f}, '
                         f'Acc MLM/QT={running_mlm_acc / interval:.4f}/'
                         f'{running_qt_acc / interval:.4f}, '
                         f'Grad_norm={avg_grad_norm:.4f}, '
                         f'Time cost={toc - tic:.2f}, '
                         f'Throughput={wps:.2f} tokens/s, ETA={eta:.2f}h')
            if args.local_rank in (-1, 0):
                writer.add_scalar('Throughput_wps', wps, step_num)
                writer.add_scalar('Loss/MLM', running_mlm_loss / interval, step_num)
                writer.add_scalar('Loss/QT', running_qt_loss / interval, step_num)
                writer.add_scalar('Acc/MLM', running_mlm_acc / interval, step_num)
                writer.add_scalar('Acc/QT', running_qt_acc / interval, step_num)
                writer.add_scalar('LR', scheduler.get_last_lr()[0], step_num)
                writer.add_scalar('Grad_norm', avg_grad_norm, step_num)
            running_num_tks, running_grad_norm = 0, 0
            running_mlm_loss, running_qt_loss, running_mlm_acc, running_qt_acc = 0, 0, 0, 0
            tic = time.time()

        # Saving
        if step_num % save_interval == 0 or step_num >= args.num_steps:
            states_option(step_num, optimizer, args, 'Saving')
            if args.local_rank in (0, -1):
                parameters_option(step_num, model, args, 'Saving')

    logging.info('Finish training step: %d', step_num)
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time - train_start_time))

    if args.local_rank in (0, -1):
        save_dir = os.path.join(args.ckpt_dir, args.model_name)
        final_save(model, save_dir, tokenizer.vocab, cfg)
Esempio n. 7
0
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
):
    """Benchmark one rank and optionally check speed/memory/loss regressions.

    Args:
        rank: Rank of this process; also used as the CUDA device index unless
            ``args.cpu`` is set.
        args: Parsed options. Fields read here: ``debug``, ``world_size``,
            ``cpu``, ``batch_size``, ``torchvision_model``, ``amp``,
            ``profile``, ``epochs``, ``reference_speed``, ``reference_memory``
            and ``reference_loss``.
        backend: Backend forwarded to ``dist_init``; ``"nccl"`` additionally
            switches cuDNN to deterministic, non-benchmark mode.
        optim_type: Configuration to benchmark. ``oss_sharded_ddp`` wraps the
            model in ``ShardedDDP`` with an ``OSS`` optimizer, ``oss_ddp``
            uses ``DDP`` with ``OSS``, anything else uses ``DDP`` with the
            plain optimizer.
        check_regression: When set, rank 0 asserts that speed, peak memory and
            final loss stay within the reference values.

    Raises:
        AssertionError: If a regression check fails.
    """
    logging.basicConfig(
        level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size,
                                             args.batch_size, device,
                                             args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)
    # Choose the scaler from the configuration actually being run
    # (`optim_type`), like every other branch below, rather than from
    # `args.optim_type`: the two are passed separately and need not agree.
    scaler = (TorchGradScaler() if optim_type == OptimType.vanilla else
              ShardedGradScaler()) if args.amp else None

    if optim_type == OptimType.oss_sharded_ddp:
        optimizer = OSS(params=model.parameters(),
                        optim=OPTIM,
                        lr=1e-4,
                        momentum=0.9)
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids,
                    find_unused_parameters=False)  # type: ignore
        optimizer = (OSS(
            params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
                     if optim_type == OptimType.oss_ddp else OPTIM(
                         model.parameters(), lr=1e-4, momentum=0.9))
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.empty_cache()
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    def run_closure(closure, scaler, optimizer):
        """Run one optimisation step and return the loss from ``closure``."""
        if scaler is not None:
            # AMP scaler.step does not support closures
            step_loss = closure(grad_scaler=scaler)
            scaler.step(optimizer)
            scaler.update()
            return step_loss
        return optimizer.step(closure)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            if not args.cpu:
                torch.cuda.synchronize(rank)
            batch_start = time.monotonic()

            # `data=batch` binds the current batch at definition time
            def closure(data=batch, grad_scaler=None):
                model.zero_grad()
                if args.debug and rank == 0 and next(
                        model.parameters()).grad is not None:
                    logging.debug("\nbefore:  param {} -- grad {}".format(
                        next(model.parameters()).norm().item(),
                        next(model.parameters()).grad.norm().item()))
                if grad_scaler is not None:
                    # Automatically computes the FW pass in half precision
                    with torch.cuda.amp.autocast():
                        outputs = model(data["inputs"])
                        loss = loss_fn(outputs, data["label"])

                        # Accumulates scaled gradients.
                        grad_scaler.scale(loss).backward()
                else:
                    outputs = model(data["inputs"])
                    loss = loss_fn(outputs, data["label"])
                    loss.backward()

                if args.debug and rank == 0 and next(
                        model.parameters()).grad is not None:
                    logging.debug("after BW: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(),
                        next(model.parameters()).grad.norm().item()))
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(
                        use_cuda=True, record_shapes=True,
                        profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = run_closure(closure, scaler, optimizer)

                prof.export_chrome_trace(
                    f"{optim_type}_trace_rank_{rank}.json")
                need_profiling = False  # only profile once

            else:
                final_loss = run_closure(closure, scaler, optimizer)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(
                    next(model.buffers()).norm().item()))
                logging.debug("after update: param {} -- grad {}".format(
                    next(model.parameters()).norm().item(),
                    next(model.parameters()).grad.norm().item()))

            n_items += args.batch_size

            if not args.cpu:
                # make sure that the cuda kernels are finished before taking a timestamp
                torch.cuda.synchronize(rank)

            batch_end = time.monotonic()
            epoch_runtime += batch_end - batch_start

        if optim_type == OptimType.oss_ddp or optim_type == OptimType.oss_sharded_ddp:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(
                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}"
            )

    # -1.0 stands for "not measured" on CPU-only runs
    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2**20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    # n_items holds the item count of the last epoch, hence the scaling by epochs
    img_per_sec = n_items / (training_stop - training_start) * args.epochs
    logging.info(
        f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint"
    )

    # Compute the median and median of absolute differences img per second
    measurements.sort()
    median = measurements[len(measurements) // 2]

    abs_diff = sorted(abs(x - median) for x in measurements)
    # NOTE(review): with two epochs or fewer the -1 placeholder below still
    # enters the `median + 3.0 * mad` check further down - confirm intended.
    mad = abs_diff[len(measurements) // 2] if args.epochs > 2 else -1

    logging.info(
        f"[{dist.get_rank()}] : Median speed: {median:.2f} +/- {mad:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (median +
                3.0 * mad) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) -
                   args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore
Esempio n. 8
0
def train(
    rank: int,
    args: argparse.Namespace,
    backend: str = "gloo",
    optim_type: OptimType = OptimType.vanilla,
    check_regression: bool = True,
) -> None:
    """Run the benchmark training loop on one rank and optionally check for regressions.

    The process group is initialised on entry and destroyed on exit. Per-epoch
    throughput (items per second) is recorded, peak CUDA memory is measured when
    running on GPU, and rank 0 may assert that speed, memory and final loss are
    within the reference values carried by ``args``.

    Args:
        rank: Rank of this process; also used as the CUDA device index when
            ``args.cpu`` is false.
        args: Parsed command line options. The attributes read here are
            ``debug``, ``world_size``, ``cpu``, ``batch_size``,
            ``torchvision_model``, ``profile``, ``epochs``, ``reference_speed``,
            ``reference_memory`` and ``reference_loss``.
        backend: Name of the ``torch.distributed`` backend handed to ``dist_init``.
        optim_type: Selects the wrapper/optimizer combination: sharded DDP with
            its own sharded optimizer, DDP + OSS, or DDP + the plain optimizer.
        check_regression: When true, rank 0 asserts against the reference values.

    Raises:
        AssertionError: If ``check_regression`` is set and, on rank 0, the speed,
            memory or loss falls outside the reference bounds.
    """
    logging.basicConfig(level=logging.INFO if not args.debug else logging.DEBUG)

    # DDP
    dist_init(rank=rank, world_size=args.world_size, backend=backend)

    # Setup: fix all seeds so that the final loss is comparable to the reference
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    device = torch.device("cpu") if args.cpu else torch.device(rank)
    model, dataloader, loss_fn = get_problem(rank, args.world_size, args.batch_size, device, args.torchvision_model)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None
    model = cast(nn.Module, model)
    uses_oss = optim_type in (OptimType.oss_ddp, OptimType.oss_sharded_ddp)

    if optim_type == OptimType.oss_sharded_ddp:
        model = ShardedDDP(
            model,
            optimizer=OPTIM,
            optimizer_params={"lr": 1e-4, "momentum": 0.9},
            world_size=args.world_size,
            broadcast_buffers=True,
        )
        optimizer = model.sharded_optimizer

    else:
        # DDP requires device_ids to be None for CPU modules; only pin the
        # device index when the model actually lives on a GPU.
        ddp_device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=ddp_device_ids, find_unused_parameters=False)  # type: ignore
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if optim_type == OptimType.oss_ddp
            else OPTIM(model.parameters(), lr=1e-4, momentum=0.9)
        )
    optimizer = cast(torch.optim.Optimizer, optimizer)

    # Reset the memory use counter
    if not args.cpu:
        torch.cuda.reset_peak_memory_stats(rank)
        torch.cuda.synchronize(rank)

    # Standard training loop
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0
    need_profiling = args.profile

    for epoch in range(args.epochs):
        n_items = 0
        epoch_runtime = 0.0

        for batch in dataloader:
            batch_start = time.monotonic()

            def closure():
                """Zero the grads, run forward/backward on the current batch, return the loss."""
                model.zero_grad()
                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "\nbefore:  param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )

                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss.backward()

                # The sharded DDP wrapper needs an explicit gradient reduction
                if optim_type == OptimType.oss_sharded_ddp:
                    model.reduce()

                if args.debug and rank == 0 and next(model.parameters()).grad is not None:
                    logging.debug(
                        "after BW: param {} -- grad {}".format(
                            next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                        )
                    )
                return loss

            if need_profiling and not args.cpu:
                logging.info("Profiling the run")
                with profiler.profile(use_cuda=True, record_shapes=True, profile_memory=True) as prof:  # type: ignore
                    with profiler.record_function("batch"):
                        final_loss = optimizer.step(closure)
                        logging.info("profiling done")

                if rank == 0:
                    prof.export_chrome_trace(f"{optim_type}_trace.json")

                need_profiling = False  # only profile once

            else:
                final_loss = optimizer.step(closure)

            if args.debug and rank == 0:
                logging.debug("buffer: {}".format(next(model.buffers()).norm().item()))
                logging.debug(
                    "after update: param {} -- grad {}".format(
                        next(model.parameters()).norm().item(), next(model.parameters()).grad.norm().item()
                    )
                )

            n_items += args.batch_size
            epoch_runtime += time.monotonic() - batch_start

        if uses_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                logging.info("... State dict collected")

        measurements.append(n_items / epoch_runtime)
        if dist.get_rank() == 0:
            logging.info(f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}")

    # Peak memory is only meaningful (and only queryable) on GPU; -1.0 is the
    # CPU sentinel and must not be overwritten by a CUDA query afterwards.
    max_memory = -1.0
    if not args.cpu:
        torch.cuda.synchronize(rank)
        max_memory = torch.cuda.max_memory_allocated(rank) / 2 ** 20
        logging.info(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    training_stop = time.monotonic()
    # n_items holds the count of the last epoch; scale by the number of epochs
    img_per_sec = n_items / (training_stop - training_start) * args.epochs

    logging.info(f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec inc. checkpoint")

    # Compute the mean and sample standard deviation of the img per second.
    # With too few epochs the deviation is not computed and -1 is used instead.
    mean = sum(measurements) / len(measurements)
    squared_diff = (pow(x - mean, 2.0) for x in measurements)
    std = math.sqrt(sum(squared_diff) / (len(measurements) - 1)) if args.epochs > 2 else -1
    logging.info(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if check_regression and dist.get_rank() == 0:
        assert (mean + 3.0 * std) > args.reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * args.reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) - args.reference_loss) < 1e-3, "Loss regression detected"

        logging.info("[Regression Test] VALID")

    dist.destroy_process_group()  # type: ignore