Example #1
def train(rank: int, world_size: int, epochs: int, use_oss: bool):

    # DDP
    dist_init(rank, world_size)

    # Problem statement
    model = getModel().to(rank)
    dataloader = getData()
    loss_fn = getLossFun()

    optimizer: Optional[Union[OSS, torch.optim.SGD]] = None

    if not use_oss:
        optimizer = torch.optim.SGD(params=model.parameters(), lr=1e-4)
    else:
        base_optimizer = torch.optim.SGD
        base_optimizer_arguments = {
            "lr": 1e-4
        }  # any optimizer specific arguments, LR, momentum, etc...
        optimizer = OSS(params=model.parameters(),
                        optim=base_optimizer,
                        default=base_optimizer_arguments)

    training_start = time.monotonic()
    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()

    for _ in range(epochs):
        for (data, target) in dataloader:
            data, target = data.to(rank), target.to(rank)

            # Train
            model.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, target)
            loss.backward()

            # if you want to clip the gradients / get the current max:
            max_norm = 1000.0
            norm_type = 1
            if not use_oss:
                _total_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm,
                    norm_type=norm_type)  # type: ignore
            else:
                optimizer = cast(OSS, optimizer)
                _total_norm = optimizer.clip_grad_norm(max_norm,
                                                       norm_type=norm_type)

            optimizer.step()

            print(f"Loss: {loss.item()}")

    training_end = time.monotonic()
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20  # bytes -> MiB

    print(
        f"[{dist.get_rank()}] : Training done. {training_end-training_start:.2f} sec"
    )
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
Example #2
    def prepare(self, param_groups) -> None:
        assert dist.is_initialized(), "torch.distributed needs to be initialized to prepare this rank"

        def optimizer_constructor(param_groups: Any, *args, **kwargs):
            # ClassyOptimizer instances have deferred initialization, while OSS needs
            # access to the raw optimizer instance, hence this trampoline
            logging.debug("Building a ZeRO enabled optimizer")
            self.base_optimizer.prepare(param_groups)
            return self.base_optimizer.optimizer

        self.optimizer = OSS(params=param_groups, optim=optimizer_constructor)
Example #3
class ZeRO(ClassyOptimizer):
    def __init__(self, base_optimizer: ClassyOptimizer):
        """Wraps an arbitrary :class:`ClassyOptimizer <classy_vision.optim.ClassyOptimizer>`
        optimizer and shards its state as described by ZeRO_.
        ::
            opt = OSS(params, optim=torch.optim.Adam, lr=0.01)

        .. _ZeRO: https://arxiv.org/abs/1910.02054

        This instance holds all of the parameters for the model (in the .param_groups attribute)
        but relies on a wrapped optimizer, which only processes its own shard of the parameters.
        At every step, all the parameters are synced across the replicas. This relies on the
        Fairscale library: https://github.com/facebookresearch/fairscale
        """

        assert (
            fairscale_available
        ), "The Fairscale library needs to be installed to use this optimizer."

        super().__init__()
        self.base_optimizer = base_optimizer

    def prepare(self, param_groups) -> None:
        assert dist.is_initialized(), "torch.distributed needs to be initialized to prepare this rank"

        def optimizer_constructor(param_groups: Any, *args, **kwargs):
            # ClassyOptimizer instances have deferred initialization, while OSS needs
            # access to the raw optimizer instance, hence this trampoline
            logging.debug("Building a ZeRO enabled optimizer")
            self.base_optimizer.prepare(param_groups)
            return self.base_optimizer.optimizer

        self.optimizer = OSS(params=param_groups, optim=optimizer_constructor)

    @classmethod
    def from_config(cls, config):
        return cls(base_optimizer=build_optimizer(config["base_optimizer"]))

    def on_epoch(self, where: float) -> None:
        # Run the normal LR schedulers
        super().on_epoch(where)

        # Materialize the optimizer state on the replica in charge of checkpointing
        logging.info("Consolidating sharded state on primary rank. Where: %d" %
                     where)
        self.consolidate_state_dict()

    def consolidate_state_dict(self) -> None:
        self.optimizer.consolidate_state_dict(
            recipient_rank=get_primary_rank())
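Because only the recipient rank holds the full optimizer state after consolidation, checkpointing is typically done on that rank alone. A minimal sketch using only the OSS calls already shown above; save_checkpoint and its arguments are illustrative, and torch.distributed is assumed to be initialized:

import torch
import torch.distributed as dist
from fairscale.optim.oss import OSS


def save_checkpoint(optimizer: OSS, path: str, recipient_rank: int = 0) -> None:
    # Gather every shard of the optimizer state onto the recipient rank
    optimizer.consolidate_state_dict(recipient_rank=recipient_rank)
    if dist.get_rank() == recipient_rank:
        # state_dict() is only complete on the recipient rank after consolidation
        torch.save(optimizer.state_dict(), path)
    dist.barrier()  # keep all ranks in step before training resumes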
Example #4
def make_adam(params):
    # args, lr and get_data_parallel_group() come from the enclosing scope
    if args.ddp_zero:
        return OSS(params=params,
                   optim=Adam,
                   group=get_data_parallel_group(),
                   lr=lr)
    else:
        return Adam(params, lr=lr)
Example #5
def make_adam(model):
    # args, lr and get_data_parallel_group() come from the enclosing scope
    if args.ddp_zero:
        return OSS(params=model.parameters(),
                   optim=Adam,
                   group=get_data_parallel_group(),
                   lr=lr)
    else:
        return Adam(model.parameters(), lr=lr)
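In Examples #4 and #5, the group argument restricts state sharding to the data-parallel replicas, which matters when data parallelism is combined with other forms of parallelism. get_data_parallel_group() is defined elsewhere; a hedged sketch of what it could return (the rank layout below is an assumption):

import torch.distributed as dist

# Hypothetical stand-in for get_data_parallel_group(): a process subgroup so that
# OSS shards optimizer state only across the data-parallel replicas.
# Note that new_group() must be called by every rank, even those not in the group.
data_parallel_ranks = [0, 2, 4, 6]  # assumed layout, e.g. the replicas of one pipeline stage
data_parallel_group = dist.new_group(ranks=data_parallel_ranks)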
Example #6
def train(rank: int, world_size: int, epochs: int, use_oss: bool):

    # DDP
    dist_init(rank, world_size)

    # Problem statement
    model = getModel().to(rank)
    dataloader = getData()
    loss_fn = getLossFun()

    base_optimizer_arguments = {
        "lr": 1e-4
    }  # any optimizer specific arguments, LR, momentum, etc...
    if not use_oss:  # note: ~use_oss (bitwise NOT) is truthy for both True and False
        optimizer = torch.optim.SGD(params=model.parameters(),
                                    **base_optimizer_arguments)
    else:
        base_optimizer = torch.optim.SGD
        optimizer = OSS(params=model.parameters(),
                        optim=base_optimizer,
                        **base_optimizer_arguments)

    training_start = time.monotonic()
    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    for e in range(epochs):
        for (data, target) in dataloader:
            data, target = data.to(rank), target.to(rank)
            # Train
            model.zero_grad()
            outputs = model(data)
            loss = loss_fn(outputs, target)
            loss /= world_size
            loss.backward()
            optimizer.step()
            print(f"Loss: {loss.item()}")

    training_end = time.monotonic()
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20  # bytes -> MiB

    print(
        f"[{dist.get_rank()}] : Training done. {training_end-training_start:.2f} sec"
    )
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")
Example #7
def build_optimizer(model, config):
    optimizer_config = config.optimizer
    if "type" not in optimizer_config:
        raise ValueError(
            "Optimizer attributes must have a 'type' key "
            "specifying the type of optimizer. "
            "(Custom or PyTorch, e.g. 'adam_w' or 'SGD')"
        )
    optimizer_type = optimizer_config.type

    if "params" not in optimizer_config:
        warnings.warn("optimizer attributes has no params defined, defaulting to {}.")

    params = optimizer_config.get("params", {})

    if hasattr(torch.optim, optimizer_type):
        optimizer_class = getattr(torch.optim, optimizer_type)
    else:
        optimizer_class = registry.get_optimizer_class(optimizer_type)
        if optimizer_class is None:
            raise ValueError(
                "No optimizer class of type {} present in "
                "either torch or registered to registry".format(optimizer_type)
            )

    parameters = get_optimizer_parameters(model, config)

    if optimizer_config.get("enable_state_sharding", False):
        # TODO(vedanuj): Remove once OSS is moved to PT upstream
        try:
            from fairscale.optim.oss import OSS
        except ImportError:
            print(
                "Optimizer state sharding requires fairscale. "
                + "Install using pip install fairscale."
            )
            raise

        assert (
            is_dist_initialized()
        ), "Optimizer state sharding can only be used in distributed mode."

        is_fp16 = config.get("training", {}).get("fp16", False)
        optimizer = OSS(
            params=parameters, optim=optimizer_class, broadcast_fp16=is_fp16, **params
        )
    else:
        optimizer = optimizer_class(parameters, **params)
    return optimizer
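For reference, a hedged sketch of a config that would take the state-sharding branch above. OmegaConf is assumed here only because the code mixes attribute access (optimizer_config.type) with .get(); any mapping with the same behavior would do:

from omegaconf import OmegaConf

config = OmegaConf.create({
    "optimizer": {
        "type": "SGD",                   # looked up in torch.optim first, then in the registry
        "enable_state_sharding": True,   # wraps the optimizer in fairscale's OSS
        "params": {"lr": 1e-4, "momentum": 0.9},
    },
    "training": {"fp16": False},         # forwarded to OSS as broadcast_fp16
})
# optimizer = build_optimizer(model, config)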
Example #8
def train(rank: int, args, world_size: int, epochs: int):

    # DDP init example
    dist_init(rank, world_size)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Setup
    if not args.cpu:
        torch.cuda.set_device(rank)
        torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    # Problem statement
    model = NeuralNet(input_size=784, hidden_size=500, num_classes=10).to(rank)

    if args.use_ortmodule:
        print("Converting to ORTModule....")
        model = ORTModule(model)

    train_dataloader, test_dataloader = get_dataloader(args, rank,
                                                       args.batch_size)
    loss_fn = my_loss
    base_optimizer = torch.optim.SGD  # pick any pytorch compliant optimizer here
    base_optimizer_arguments = {
    }  # pass any optimizer specific arguments here, or directly below when instantiating OSS
    if args.use_sharded_optimizer:
        # Wrap the optimizer in its state sharding brethren
        optimizer = OSS(params=model.parameters(),
                        optim=base_optimizer,
                        lr=args.lr)

        # Wrap the model into ShardedDDP, which will reduce gradients to the proper ranks
        model = ShardedDDP(model, optimizer)
    else:
        device_ids = None if args.cpu else [rank]
        model = DDP(model, device_ids=device_ids,
                    find_unused_parameters=False)  # type: ignore

        optimizer = torch.optim.SGD(model.parameters(), lr=args.lr)
    # Any relevant training loop, nothing specific to OSS. For example:
    model.train()
    total_training_time, total_test_time, epoch_0_training, validation_accuracy = 0, 0, 0, 0
    for epoch in range(epochs):
        total_training_time += train_step(args, model, rank, optimizer,
                                          loss_fn, train_dataloader, epoch)
        if epoch == 0:
            epoch_0_training = total_training_time
        if args.test_batch_size > 0:
            test_time, validation_accuracy = test(args, model, rank, loss_fn,
                                                  test_dataloader)
            total_test_time += test_time

    print('\n======== Global stats ========')
    if args.use_ortmodule:
        estimated_export = 0
        if args.epochs > 1:
            estimated_export = epoch_0_training - (
                total_training_time - epoch_0_training) / (args.epochs - 1)
            print("  Estimated ONNX export took:               {:.4f}s".format(
                estimated_export))
        else:
            print(
                "  Estimated ONNX export took:               Estimate available when epochs > 1 only"
            )
        print("  Accumulated training without export took: {:.4f}s".format(
            total_training_time - estimated_export))
    print("  Accumulated training took:                {:.4f}s".format(
        total_training_time))
    print("  Accumulated validation took:              {:.4f}s".format(
        total_test_time))

    dist.destroy_process_group()
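The ordering in the sharded branch above is the important part: the OSS optimizer is built first and then handed to ShardedDDP, so each gradient is reduced only to the rank that owns the corresponding optimizer shard. A minimal sketch of that pairing, assuming torch.distributed is already initialized and using a placeholder model:

import torch
from fairscale.nn.data_parallel import ShardedDataParallel as ShardedDDP
from fairscale.optim.oss import OSS

model = torch.nn.Linear(784, 10).cuda()
optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=1e-2)
model = ShardedDDP(model, optimizer)  # reduces each gradient to its owning rank only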
Example #9
def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    use_sdp: bool = False,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
    reference_loss: float = -1.0,
):
    assert not use_sdp or use_oss, "ShardedDataParallel requires OSS"
    # DDP
    dist_init(rank=rank, world_size=world_size, backend=backend)

    # Setup
    torch.cuda.set_device(rank)
    torch.cuda.manual_seed(0)
    torch.manual_seed(0)  # also sets the cuda seed
    np.random.seed(0)

    if backend == "nccl":
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False

    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
    optimizer: Optional[torch.optim.Optimizer] = None

    if use_sdp:
        ddp = ShardedDataParallel(
            module=model,
            optimizer=OPTIM,
            optimizer_params={
                "lr": 1e-4,
                "momentum": 0.9
            },
            world_size=world_size,
            broadcast_buffers=False,
        )
        ddp.train()
        optimizer = ddp.optimizer
        model = ddp
    else:
        optimizer = (
            OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
            if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9))

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                if use_sdp:
                    ddp.reduce(
                    )  # Send the gradients to the appropriate shards

                return loss

            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(
                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss:.3f}"
            )

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20

    print(
        f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall"
    )
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of img per second
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean +
                3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        assert abs(cast(float, final_loss) -
                   reference_loss) < 1e-3, "Loss regression detected"

        print("[Regression Test] VALID")
Example #10
def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
):

    # DDP
    dist_init(rank, world_size)

    # Standard RN101
    model = resnet101(pretrained=False, progress=True).to(rank)

    # Data setup, dummy data
    def collate(inputs: List[Any]):
        return {
            "inputs":
            torch.stack([i[0] for i in inputs]).to(torch.device(rank)),
            "label":
            torch.stack([i[1] for i in inputs]).to(torch.device(rank)),
        }

    dataloader = DataLoader(dataset=FakeData(transform=ToTensor(),
                                             size=data_size),
                            batch_size=batch_size,
                            collate_fn=collate)
    loss_fn = nn.CrossEntropyLoss()

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Shard the optimizer
    optimizer: Union[OSS, OPTIM] = OSS(
        params=model.parameters(), optim=OPTIM, lr=1e-4,
        momentum=0.9) if use_oss else OPTIM(
            model.parameters(), lr=1e-4, momentum=0.9)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                dist.all_reduce(loss, op=dist.ReduceOp.SUM)
                loss /= world_size
                loss.backward()
                return loss

            optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(
                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec"
            )

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20

    print(
        f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall"
    )
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of img per second
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean +
                3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        print("[Regression Test] VALID")
Example #11
def train(
    rank: int,
    world_size: int,
    num_epochs: int = 10,
    batch_size: int = 32,
    data_size: int = 200,
    backend: str = "gloo",
    use_oss: bool = True,
    check_regression: bool = True,
    reference_speed: float = -1.0,
    reference_memory: float = -1.0,
):
    # DDP
    dist_init(rank, world_size, backend)

    # Setup
    model, dataloader, loss_fn = get_problem(rank, data_size, batch_size)

    # Shard the optimizer
    optimizer: torch.optim.Optimizer = (
        OSS(params=model.parameters(), optim=OPTIM, lr=1e-4, momentum=0.9)
        if use_oss else OPTIM(model.parameters(), lr=1e-4, momentum=0.9))

    # Reset the memory use counter
    torch.cuda.reset_peak_memory_stats(rank)

    # Dummy training loop
    torch.cuda.synchronize(rank)
    training_start = time.monotonic()
    model.train()

    measurements = []
    final_loss: Optional[float] = -1.0

    for epoch in range(num_epochs):
        epoch_start = time.monotonic()

        for batch in dataloader:

            def closure():
                model.zero_grad()
                outputs = model(batch["inputs"])
                loss = loss_fn(outputs, batch["label"])
                loss /= world_size
                loss.backward()

                dist.all_reduce(loss, op=dist.ReduceOp.SUM)

                return loss

            final_loss = optimizer.step(closure)

        epoch_end = time.monotonic()

        if use_oss:
            # Check the checkpointing in the case of the OSS optimizer
            # Memory usage could spill over from there
            optimizer = cast(OSS, optimizer)
            optimizer.consolidate_state_dict()
            if dist.get_rank() == 0:
                _ = optimizer.state_dict()
                print("... State dict collected")

        measurements.append(data_size / (epoch_end - epoch_start))
        if dist.get_rank() == 0:
            print(
                f"Epoch {epoch} - processed {measurements[-1]:.2f} img per sec. Loss {final_loss}"
            )

    torch.cuda.synchronize(rank)
    training_stop = time.monotonic()
    img_per_sec = data_size / (training_stop - training_start) * num_epochs
    max_memory = torch.cuda.max_memory_allocated(rank) / 2**20

    print(
        f"[{dist.get_rank()}] : Training done. {img_per_sec:.2f} img per sec overall"
    )
    print(f"[{dist.get_rank()}] : Peak memory {max_memory:.1f}MiB")

    # Compute the mean and standard deviation of img per second
    mean = sum(measurements) / len(measurements)
    diff = map(lambda x: pow(x - mean, 2.0), measurements)
    std = math.sqrt(sum(diff) / (len(measurements) - 1))
    print(f"[{dist.get_rank()}] : Mean speed: {mean:.2f} +/- {std:.2f}")

    if use_oss and check_regression and dist.get_rank() == 0:
        assert (mean +
                3.0 * std) > reference_speed, "Speed regression detected"
        assert max_memory < 1.05 * reference_memory, "Memory use regression detected"
        print("[Regression Test] VALID")
Example #12
def train(args, *, tbl):
    cfg, tokenizer, _, _ = nlp.models.bert.get_pretrained_bert(args.model_name, load_backbone=False,
                                                               load_mlm=False)
    cfg = nlp.torch.models.bert.BertModel.get_cfg().clone_merge(cfg)
    model = nlp.torch.models.bert.QTBertForPretrain(cfg)
    model.to(args.device)

    if args.start_step:
        logging.info('Restart training from {}'.format(args.start_step))
        parameters_option(args.start_step, model, args, 'Loading')
    else:
        model.apply(nlp.torch.models.bert.init_weights)

    writer = None
    if args.local_rank in (-1, 0):
        writer = SummaryWriter(log_dir=os.path.join(args.ckpt_dir, 'tensorboard'))

    # pin_memory=False due to lack of https://github.com/pytorch/pytorch/commit/54ce171f16c8859f829dde09f87c364c8a6b4130
    sampler = RandomSampler(tbl) if args.local_rank == -1 else DistributedSampler(
        tbl, seed=args.seed)
    # batch_size // 2 for QuickThought
    train_dataloader = DataLoader(np.arange(len(tbl)), sampler=sampler,
                                  collate_fn=functools.partial(collate_fn, args=args, tbl=tbl),
                                  batch_size=args.batch_size // 2,
                                  num_workers=args.num_dataloader_workers, pin_memory=True)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    optimizer_arguments = {"lr": args.lr}
    if get_world_size(args) > 1 and args.ZeRO:
        optimizer = OSS(params=model.parameters(), optim=nlp.torch.optimizers.FusedLANS,
                        **optimizer_arguments)
        model = ShardedDataParallel(model, optimizer)
    elif get_world_size(args) > 1:
        optimizer = nlp.torch.optimizers.FusedLANS(optimizer_grouped_parameters,
                                                   **optimizer_arguments)
        model = DistributedDataParallel(model, device_ids=[args.local_rank],
                                        output_device=args.local_rank, find_unused_parameters=True)
    else:
        optimizer = nlp.torch.optimizers.FusedLANS(optimizer_grouped_parameters,
                                                   **optimizer_arguments)

    save_interval = args.ckpt_interval
    logging.info(f'#Total Training Steps={args.num_steps}, '
                 f'Warmup Steps={args.warmup_ratio * args.num_steps}, '
                 f'Save Interval={save_interval}')
    scheduler = nlp.torch.optimizers.schedules.get_warmup_linear_const_decay_poly_schedule(
        optimizer, total_steps=args.num_steps, warmup_ratio=args.warmup_ratio,
        const_ratio=args.const_ratio)

    if args.start_step:
        logging.info(f'Restart training from {args.start_step}')
        states_option(args.start_step, optimizer, args, 'Loading')

    ce_loss_fn = th.nn.CrossEntropyLoss()
    step_num = args.start_step
    if args.phase2:
        step_num -= args.phase1_num_steps
    running_num_tks, running_grad_norm = 0, 0
    running_mlm_loss, running_qt_loss, running_mlm_acc, running_qt_acc = 0, 0, 0, 0

    train_start_time = time.time()
    tic = time.time()
    model.zero_grad()
    if get_world_size(args) > 1 and args.ZeRO:
        scaler = ShardedGradScaler() if args.fp16 else None
    else:
        scaler = th.cuda.amp.GradScaler() if args.fp16 else None

    train_iter = repeat(train_dataloader, set_epoch=args.local_rank != -1)
    while step_num < args.num_steps:
        step_num += 1
        for accum_step in range(args.num_accumulated):
            (input_id, segment_id, valid_length, mlm_positions,
             mlm_labels) = (arr.to(args.device) for arr in next(train_iter))

            model.train()
            accumulation = ((accum_step + 1) % args.num_accumulated != 0)
            with model.no_sync() if get_world_size(args) > 1 and accumulation else suppress():
                with th.cuda.amp.autocast(enabled=args.fp16):
                    _, pooled_out, mlm_scores, qt_similarity = model(input_id, segment_id,
                                                                     valid_length, mlm_positions)
                    mlm_loss = ce_loss_fn(mlm_scores, mlm_labels)
                    qt_label = th.arange(len(input_id) // 2, device=args.device)
                    qt_loss = ce_loss_fn(qt_similarity, qt_label)
                    loss = mlm_loss + qt_loss
                if args.num_accumulated > 1:
                    loss = loss / args.num_accumulated
                if args.fp16:
                    scaler.scale(loss).backward()
                else:
                    loss.backward()

                with th.no_grad():
                    qt_acc = (qt_similarity.argmax(dim=1) == qt_label).sum() / (len(input_id) // 2)
                    mlm_acc = (mlm_scores.argmax(dim=1) == mlm_labels).sum() / len(mlm_labels)

            # Gather information from all workers for accurate statistics
            reduced_num_tokens = valid_length.sum()
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_num_tokens)
            reduced_num_mlm_tokens = th.tensor(len(mlm_labels), device=args.device)
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_num_mlm_tokens)
            reduced_loss_mlm = mlm_loss.detach().clone() * len(mlm_labels) / reduced_num_mlm_tokens
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_loss_mlm)
            reduced_acc_mlm = mlm_acc.detach().clone() * len(mlm_labels) / reduced_num_mlm_tokens
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_acc_mlm)
            reduced_bs = th.tensor(len(input_id), device=args.device)
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_bs)
            reduced_loss_qt = qt_loss.detach().clone() * len(input_id) / reduced_bs
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_loss_qt)
            reduced_acc_qt = qt_acc.detach().clone() * len(input_id) / reduced_bs
            if get_world_size(args) > 1:
                distributed.all_reduce(reduced_acc_qt)

            running_num_tks += reduced_num_tokens.item()
            running_mlm_loss += reduced_loss_mlm.item()
            running_mlm_acc += reduced_acc_mlm.item()
            running_qt_loss += reduced_loss_qt.item()
            running_qt_acc += reduced_acc_qt.item()

            if not accumulation:
                if args.fp16:
                    scaler.unscale_(optimizer)  # unscale for gradient clipping
                if get_world_size(args) > 1 and args.ZeRO:
                    total_norm = optimizer.clip_grad_norm(args.max_grad_norm)
                else:
                    total_norm = th.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
                    if get_world_size(args) > 1:
                        distributed.all_reduce(total_norm)
                        total_norm /= get_world_size(args)
                running_grad_norm += total_norm

                if args.fp16:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()
                with warnings.catch_warnings():
                    # Scheduler may warn if optimizer.step() call is skipped
                    # due to invalid gradients detected by scaler.
                    warnings.simplefilter("ignore", UserWarning)
                    scheduler.step()
                optimizer.zero_grad(set_to_none=True)

        if step_num % args.log_interval == 0:
            toc = time.time()
            wps = running_num_tks / (toc - tic)
            eta = (args.num_steps - step_num) / (step_num / (toc - train_start_time)) / 3600
            interval = args.log_interval * args.num_accumulated
            logging.info(f'[Step {step_num}], LR={scheduler.get_last_lr()[0]:.6f}, '
                         f'Loss MLM/QT={running_mlm_loss / interval:.4f}/'
                         f'{running_qt_loss / interval:.4f}, '
                         f'Acc MLM/QT={running_mlm_acc / interval:.4f}/'
                         f'{running_qt_acc / interval:.4f}, '
                         f'Grad_norm={running_grad_norm / interval:.4f}, '
                         f'Time cost={toc - tic:.2f}, '
                         f'Throughput={wps:.2f} tokens/s, ETA={eta:.2f}h')
            if args.local_rank in (-1, 0):
                writer.add_scalar('Throughput_wps', wps, step_num)
                writer.add_scalar('Loss/MLM', running_mlm_loss / interval, step_num)
                writer.add_scalar('Loss/QT', running_qt_loss / interval, step_num)
                writer.add_scalar('Acc/MLM', running_mlm_acc / interval, step_num)
                writer.add_scalar('Acc/QT', running_qt_acc / interval, step_num)
                writer.add_scalar('LR', scheduler.get_last_lr()[0], step_num)
                writer.add_scalar('Grad_norm', running_grad_norm / interval, step_num)
            running_num_tks, running_grad_norm = 0, 0
            running_mlm_loss, running_qt_loss, running_mlm_acc, running_qt_acc = 0, 0, 0, 0
            tic = time.time()

        # Saving
        if step_num % save_interval == 0 or step_num >= args.num_steps:
            states_option(step_num, optimizer, args, 'Saving')
            if args.local_rank in (0, -1):
                parameters_option(step_num, model, args, 'Saving')

    logging.info('Finish training step: %d', step_num)
    train_end_time = time.time()
    logging.info('Train cost={:.1f} s'.format(train_end_time - train_start_time))

    if args.local_rank in (0, -1):
        save_dir = os.path.join(args.ckpt_dir, args.model_name)
        final_save(model, save_dir, tokenizer.vocab, cfg)
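The fp16 update in the example above (unscale, clip, step, update) follows the standard torch.cuda.amp pattern, independent of ZeRO. A self-contained sketch with placeholder model and data, for reference:

import torch

model = torch.nn.Linear(16, 4).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = torch.nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler()

inputs = torch.randn(8, 16, device="cuda")
targets = torch.randint(0, 4, (8,), device="cuda")

optimizer.zero_grad(set_to_none=True)
with torch.cuda.amp.autocast():
    loss = loss_fn(model(inputs), targets)
scaler.scale(loss).backward()
scaler.unscale_(optimizer)  # unscale once so clipping sees the true gradient magnitudes
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
scaler.step(optimizer)      # skipped internally if inf/NaN gradients are detected
scaler.update()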