Example #1
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):

    device = idist.device()
    _test_distrib_sync_all_reduce_decorator(device)
    _test_invalid_sync_all_reduce(device)
Example #2
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_multilabel_input_NHW(device)
    _test_distrib_integration_multiclass(device)
    _test_distrib_integration_multilabel(device)
    _test_distrib_accumulator_device(device)
Example #3
def test_multinode_distrib_gloo_cpu_or_gpu(
        distributed_context_multi_node_gloo):

    device = idist.device()
    _test_distrib_compute(device)
    _test_distrib_integration(device)
Example #4
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_compute(device)
    _test_distrib_integration(device)
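`_test_distrib_xla_nprocs(index)` receives a process index because it is meant to be spawned once per XLA device. A minimal launch sketch using torch_xla multiprocessing; the process count is an assumption, and Ignite's own test suite wraps this call in a fixture:

import torch_xla.distributed.xla_multiprocessing as xmp

# Spawn the helper on each TPU core; every process receives its index
# as the first positional argument.
xmp.spawn(_test_distrib_xla_nprocs, args=(), nprocs=8)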
Example #5
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):

    device = idist.device()
    _test_neptune_saver_integration(device)
Example #6
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Set up the Ignite trainer:
    # - define the training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - a handler to set up learning rate scheduling,
    #    - ModelCheckpoint,
    #    - `RunningAverage` on the `train_step` output,
    #    - two progress bars, on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(x)
            loss = criterion(y_pred, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
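For reference, `create_trainer` reads at least the following keys from `config`; the values below are illustrative assumptions, not defaults taken from the source, and `get_save_handler(config)` may read further keys such as the output path:

config = {
    "with_amp": False,         # enables torch.cuda.amp autocast/GradScaler in train_step
    "checkpoint_every": 1000,  # save_every_iters for the common training handlers
    "log_every_iters": 15,     # > 0 keeps "batch loss" as a tracked output name
    "resume_from": None,       # path to a checkpoint file, or None to start from scratch
}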
Example #7
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_integration(device)
    _test_distrib_accumulator_device(device)
Example #8
def _test_distrib_binary_and_multilabel_inputs(device):

    rank = idist.get_rank()
    torch.manual_seed(12)

    def _test(y_pred, y, batch_size, metric_device):

        metric_device = torch.device(metric_device)
        ap = AveragePrecision(device=metric_device)
        torch.manual_seed(10 + rank)

        ap.reset()

        if batch_size > 1:
            n_iters = y.shape[0] // batch_size + 1
            for i in range(n_iters):
                idx = i * batch_size
                ap.update(
                    (y_pred[idx:idx + batch_size], y[idx:idx + batch_size]))
        else:
            ap.update((y_pred, y))

        # gather y_pred, y
        y_pred = idist.all_gather(y_pred)
        y = idist.all_gather(y)

        np_y = y.cpu().numpy()
        np_y_pred = y_pred.cpu().numpy()

        res = ap.compute()
        assert isinstance(res, float)
        assert average_precision_score(np_y, np_y_pred) == pytest.approx(res)

    def get_test_cases():

        test_cases = [
            # Binary input data of shape (N,) or (N, 1)
            (torch.randint(0, 2, size=(10, )).long(),
             torch.randint(0, 2, size=(10, )).long(), 1),
            (torch.randint(0, 2, size=(10, 1)).long(),
             torch.randint(0, 2, size=(10, 1)).long(), 1),
            # updated batches
            (torch.randint(0, 2, size=(50, )).long(),
             torch.randint(0, 2, size=(50, )).long(), 16),
            (torch.randint(0, 2, size=(50, 1)).long(),
             torch.randint(0, 2, size=(50, 1)).long(), 16),
            # Binary input data of shape (N, L)
            (torch.randint(0, 2, size=(10, 4)).long(),
             torch.randint(0, 2, size=(10, 4)).long(), 1),
            (torch.randint(0, 2, size=(10, 7)).long(),
             torch.randint(0, 2, size=(10, 7)).long(), 1),
            # updated batches
            (torch.randint(0, 2, size=(50, 4)).long(),
             torch.randint(0, 2, size=(50, 4)).long(), 16),
            (torch.randint(0, 2, size=(50, 7)).long(),
             torch.randint(0, 2, size=(50, 7)).long(), 16),
        ]
        return test_cases

    for _ in range(3):
        test_cases = get_test_cases()
        for y_pred, y, batch_size in test_cases:
            _test(y_pred, y, batch_size, "cpu")
            if device.type != "xla":
                _test(y_pred, y, batch_size, idist.device())
Example #9
def _test_distrib_integration_binary_input(device):

    rank = idist.get_rank()
    torch.manual_seed(12)
    n_iters = 80
    s = 16
    n_classes = 2
    offset = n_iters * s

    def _test(y_preds, y_true, n_epochs, metric_device, update_fn):
        metric_device = torch.device(metric_device)

        engine = Engine(update_fn)

        ap = AveragePrecision(device=metric_device)
        ap.attach(engine, "ap")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        assert "ap" in engine.state.metrics

        res = engine.state.metrics["ap"]

        true_res = average_precision_score(y_true.cpu().numpy(),
                                           y_preds.cpu().numpy())
        assert pytest.approx(res) == true_res

    def get_tests(is_N):
        if is_N:
            y_true = torch.randint(0,
                                   n_classes,
                                   size=(offset *
                                         idist.get_world_size(), )).to(device)
            y_preds = torch.rand(offset * idist.get_world_size(), ).to(device)

            def update_fn(engine, i):
                return (
                    y_preds[i * s + rank * offset:(i + 1) * s + rank * offset],
                    y_true[i * s + rank * offset:(i + 1) * s + rank * offset],
                )

        else:
            y_true = torch.randint(0,
                                   n_classes,
                                   size=(offset * idist.get_world_size(),
                                         10)).to(device)
            y_preds = torch.randint(0,
                                    n_classes,
                                    size=(offset * idist.get_world_size(),
                                          10)).to(device)

            def update_fn(engine, i):
                return (
                    y_preds[i * s + rank * offset:(i + 1) * s +
                            rank * offset, :],
                    y_true[i * s + rank * offset:(i + 1) * s +
                           rank * offset, :],
                )

        return y_preds, y_true, update_fn

    metric_devices = ["cpu"]
    if device.type != "xla":
        metric_devices.append(idist.device())
    for metric_device in metric_devices:
        for _ in range(2):
            # Binary input data of shape (N,)
            y_preds, y_true, update_fn = get_tests(is_N=True)
            _test(y_preds,
                  y_true,
                  n_epochs=1,
                  metric_device=metric_device,
                  update_fn=update_fn)
            _test(y_preds,
                  y_true,
                  n_epochs=2,
                  metric_device=metric_device,
                  update_fn=update_fn)
            # Binary input data of shape (N, L)
            y_preds, y_true, update_fn = get_tests(is_N=False)
            _test(y_preds,
                  y_true,
                  n_epochs=1,
                  metric_device=metric_device,
                  update_fn=update_fn)
            _test(y_preds,
                  y_true,
                  n_epochs=2,
                  metric_device=metric_device,
                  update_fn=update_fn)
Example #10
def _test_func(index, ws, device):
    assert 0 <= index < ws
    assert ws == idist.get_world_size()
    assert device in idist.device().type
Example #11
def _test_distrib_integration(device):

    from ignite.engine import Engine

    rank = idist.get_rank()

    chunks = [
        (CAND_1, [REF_1A, REF_1B]),
        (CAND_2A, [REF_2A, REF_2B, REF_2C]),
        (CAND_2B, [REF_2A, REF_2B, REF_2C]),
        (CAND_1, [REF_1A]),
        (CAND_2A, [REF_2A, REF_2B]),
        (CAND_2B, [REF_2A, REF_2B]),
        (CAND_1, [REF_1B]),
        (CAND_2A, [REF_2B, REF_2C]),
        (CAND_2B, [REF_2B, REF_2C]),
        (CAND_1, [REF_1A, REF_1B]),
        (CAND_2A, [REF_2A, REF_2C]),
        (CAND_2B, [REF_2A, REF_2C]),
        (CAND_1, [REF_1A]),
        (CAND_2A, [REF_2A]),
        (CAND_2B, [REF_2C]),
    ]

    size = len(chunks)

    data = []
    for c in chunks:
        data += idist.get_world_size() * [c]

    def update(_, i):
        candidate, references = data[i + size * rank]
        lower_split_references = [reference.lower().split() for reference in references]
        lower_split_candidate = candidate.lower().split()
        return lower_split_candidate, lower_split_references

    def _test(metric_device):
        engine = Engine(update)
        m = Rouge(variants=[1, 2, "L"], alpha=0.5, device=metric_device)
        m.attach(engine, "rouge")

        engine.run(data=list(range(size)), max_epochs=1)

        assert "rouge" in engine.state.metrics

        evaluator = pyrouge.Rouge(
            metrics=["rouge-n", "rouge-l"],
            max_n=4,
            apply_avg=True,
            apply_best=False,
            alpha=0.5,
            stemming=False,
            ensure_compatibility=False,
        )
        rouge_1_f, rouge_2_f, rouge_l_f = (0, 0, 0)
        for candidate, references in data:
            scores = evaluator.get_scores([candidate], [references])
            rouge_1_f += scores["rouge-1"]["f"]
            rouge_2_f += scores["rouge-2"]["f"]
            rouge_l_f += scores["rouge-l"]["f"]

        assert pytest.approx(engine.state.metrics["Rouge-1-F"], abs=1e-4) == rouge_1_f / len(data)
        assert pytest.approx(engine.state.metrics["Rouge-2-F"], abs=1e-4) == rouge_2_f / len(data)
        assert pytest.approx(engine.state.metrics["Rouge-L-F"], abs=1e-4) == rouge_l_f / len(data)

    _test("cpu")

    if device.type != "xla":
        _test(idist.device())
Example #12
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_sync_all_reduce_decorator(device)
    _test_creating_on_xla_fails(device)
    _test_invalid_sync_all_reduce(device)
Example #13
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_sync_all_reduce_decorator(device)
    _test_creating_on_xla_fails(device)
    _test_invalid_sync_all_reduce(device)
Example #14
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):

    device = idist.device()
    _test_distrib_sync_all_reduce_decorator(device)
    _test_invalid_sync_all_reduce(device)
Example #15
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_compute_on_criterion(device, y_test_1(), y_test_2())
    _test_distrib_accumulator_device(device, y_test_1())
Example #16
def test_distrib_single_device_xla():

    device = idist.device()
    _test_distrib_binary_and_multilabel_inputs(device)
    _test_distrib_integration_binary_input(device)
Example #17
def training(local_rank, config):

    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="CIFAR10-Training")

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        if config["stop_iteration"] is None:
            now = datetime.now().strftime("%Y%m%d-%H%M%S")
        else:
            now = f"stop-on-{config['stop_iteration']}"

        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output_path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

        if config["with_clearml"]:
            try:
                from clearml import Task
            except ImportError:
                # Backwards-compatibility for legacy Trains SDK
                from trains import Task

            task = Task.init("CIFAR10-Training", task_name=output_path.stem)
            task.connect_configuration(config)
            # Log hyperparameters
            hyper_params = [
                "model",
                "batch_size",
                "momentum",
                "weight_decay",
                "num_epochs",
                "learning_rate",
                "num_warmup_epochs",
            ]
            task.connect({k: config[k] for k in hyper_params})

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler,
                             train_loader.sampler, config, logger)

    # Let's now set up the evaluator engines to run the model's validation and compute metrics
    metrics = {
        "Accuracy": Accuracy(),
        "Loss": Loss(criterion),
    }

    # We define two evaluators as they won't have exactly the same role:
    # - `evaluator` will save the best model based on the validation score
    evaluator = create_evaluator(model, metrics=metrics, config=config)
    train_evaluator = create_evaluator(model, metrics=metrics, config=config)

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train",
                    state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test",
                    state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config["validate_every"])
        | Events.COMPLETED, run_validation)

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path,
                                            trainer,
                                            optimizer,
                                            evaluators=evaluators)

    # Store 2 best models by validation accuracy starting from num_epochs / 2:
    best_model_handler = Checkpoint(
        {"model": model},
        get_save_handler(config),
        filename_prefix="best",
        n_saved=2,
        global_step_transform=global_step_from_engine(trainer),
        score_name="test_accuracy",
        score_function=Checkpoint.get_default_score_fn("Accuracy"),
    )
    evaluator.add_event_handler(
        Events.COMPLETED(
            lambda *_: trainer.state.epoch > config["num_epochs"] // 2),
        best_model_handler)

    # To be able to check training resuming, we can stop training at a given iteration
    if config["stop_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["stop_iteration"]))
        def _():
            logger.info(
                f"Stop training on {trainer.state.iteration} iteration")
            trainer.terminate()

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        logger.exception("")
        raise e

    if rank == 0:
        tb_logger.close()
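Functions like `training(local_rank, config)` above are written to be launched through `ignite.distributed.Parallel`, which spawns (or attaches to) the processes of the chosen backend and calls the function with the local rank. A minimal launch sketch; the backend, process count and config values here are illustrative assumptions:

import ignite.distributed as idist

# Only a few of the keys read by `training` are shown here; the real config
# also needs the dataflow, model and checkpointing options used above.
config = {"seed": 543, "num_epochs": 20, "validate_every": 3, "stop_iteration": None}

# Run `training` once per process of the selected backend.
with idist.Parallel(backend="nccl", nproc_per_node=2) as parallel:
    parallel.run(training, config)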
Example #18
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_binary_and_multilabel_inputs(device)
    _test_distrib_integration_binary_input(device)
Example #19
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_integration(device)
    _test_distrib_accumulator_device(device)
Example #20
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_integration_multiclass(device)
    _test_distrib_integration_multilabel(device)
Example #21
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_compute(device)
    _test_distrib_integration(device)
Example #22
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_integration_multiclass(device)
    _test_distrib_integration_multilabel(device)
Example #23
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):

    device = idist.device()
    _test_neptune_saver_integration(device)
Example #24
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):

    device = idist.device()
    _test_distrib_compute_on_criterion(device, y_test_1(), y_test_2())
    _test_distrib_accumulator_device(device, y_test_1())
Example #25
def _test_distrib_multilabel_input_NHW(device):
    # Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)

    rank = idist.get_rank()

    def _test(metric_device):
        metric_device = torch.device(metric_device)
        acc = Accuracy(is_multilabel=True, device=metric_device)

        torch.manual_seed(10 + rank)
        y_pred = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long()
        y = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long()
        acc.update((y_pred, y))

        assert (
            acc._num_correct.device == metric_device
        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"

        # gather y_pred, y
        y_pred = idist.all_gather(y_pred)
        y = idist.all_gather(y)

        np_y_pred = to_numpy_multilabel(
            y_pred.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
        np_y = to_numpy_multilabel(
            y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
        assert acc._type == "multilabel"
        n = acc._num_examples
        res = acc.compute()
        assert n * idist.get_world_size() == acc._num_examples
        assert isinstance(res, float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)

        acc.reset()
        torch.manual_seed(10 + rank)
        y_pred = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long()
        y = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long()
        acc.update((y_pred, y))

        assert (
            acc._num_correct.device == metric_device
        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"

        # gather y_pred, y
        y_pred = idist.all_gather(y_pred)
        y = idist.all_gather(y)

        np_y_pred = to_numpy_multilabel(
            y_pred.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
        np_y = to_numpy_multilabel(
            y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)

        assert acc._type == "multilabel"
        n = acc._num_examples
        res = acc.compute()
        assert n * idist.get_world_size() == acc._num_examples
        assert isinstance(res, float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)
        # check that result is not changed
        res = acc.compute()
        assert n * idist.get_world_size() == acc._num_examples
        assert isinstance(res, float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)

        # Batched Updates
        acc.reset()
        torch.manual_seed(10 + rank)
        y_pred = torch.randint(0, 2, size=(80, 5, 8, 10), device=device).long()
        y = torch.randint(0, 2, size=(80, 5, 8, 10), device=device).long()

        batch_size = 16
        n_iters = y.shape[0] // batch_size + 1

        for i in range(n_iters):
            idx = i * batch_size
            acc.update((y_pred[idx:idx + batch_size], y[idx:idx + batch_size]))

        assert (
            acc._num_correct.device == metric_device
        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"

        # gather y_pred, y
        y_pred = idist.all_gather(y_pred)
        y = idist.all_gather(y)

        np_y_pred = to_numpy_multilabel(
            y_pred.cpu())  # (N, C, L, ...) -> (N * L * ..., C)
        np_y = to_numpy_multilabel(y.cpu())  # (N, C, L, ...) -> (N * L * ..., C)

        assert acc._type == "multilabel"
        n = acc._num_examples
        res = acc.compute()
        assert n * idist.get_world_size() == acc._num_examples
        assert isinstance(res, float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)

    # check multiple random inputs as exact random occurrences are rare
    for _ in range(3):
        _test("cpu")
        if device.type != "xla":
            _test(idist.device())
Example #26
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):

    device = idist.device()
    _test_distrib_compute_on_criterion(device, y_test_1(), y_test_2())
    _test_distrib_accumulator_device(device, y_test_1())
Example #27
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_multilabel_input_NHW(device)
    _test_distrib_integration_multiclass(device)
    _test_distrib_integration_multilabel(device)
    _test_distrib_accumulator_device(device)
Example #28
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_compute_on_criterion(device, y_test_1(), y_test_2())
    _test_distrib_accumulator_device(device, y_test_1())
Example #29
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):

    device = idist.device()
    _test_distrib_compute(device)
    _test_distrib_integration(device)
Example #30
def training(local_rank, cfg):

    logger = setup_logger("FixMatch Training", distributed_rank=idist.get_rank())

    if local_rank == 0:
        logger.info(cfg.pretty())

    rank = idist.get_rank()
    manual_seed(cfg.seed + rank)
    device = idist.device()

    model, ema_model, optimizer, sup_criterion, lr_scheduler = utils.initialize(cfg)

    unsup_criterion = instantiate(cfg.solver.unsupervised_criterion)

    cta = get_default_cta()

    (
        supervised_train_loader,
        test_loader,
        unsup_train_loader,
        cta_probe_loader,
    ) = utils.get_dataflow(cfg, cta=cta, with_unsup=True)

    def train_step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch["sup_batch"]["image"], batch["sup_batch"]["target"]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        weak_x, strong_x = (
            batch["unsup_batch"]["image"],
            batch["unsup_batch"]["strong_aug"],
        )
        if weak_x.device != device:
            weak_x = weak_x.to(device, non_blocking=True)
            strong_x = strong_x.to(device, non_blocking=True)

        # according to TF code: single forward pass on concat data: [x, weak_x, strong_x]
        le = 2 * engine.state.mu_ratio + 1
        # Why interleave: https://github.com/google-research/fixmatch/issues/20#issuecomment-613010277
        # We need to interleave due to multiple-GPU batch norm issues. Let's say we have two GPUs, and our batch is
        # composed of labeled (L) and unlabeled (U) images; we use a labeled batch size of 2 to keep the example
        # easy to visualize.
        #
        # - Without interleaving, we have a batch LLUUUUUU...U (there are 14 U). When the batch is split to be passed
        # to both GPUs, we'll have two batches LLUUUUUU and UUUUUUUU. Note that all labeled examples ended up in the
        # first batch, sent to GPU1. The problem here is that batch norm is computed per batch, so the moments will
        # lack consistency between the two batches.
        #
        # - With interleaving, by contrast, the two batches will be LUUUUUUU and LUUUUUUU. As you can see, the
        # batches have the same distribution of labeled and unlabeled samples and will therefore have more consistent
        # moments.
        #
        # (A minimal sketch of the interleave/deinterleave helpers is given after this example.)
        #
        x_cat = interleave(torch.cat([x, weak_x, strong_x], dim=0), le)
        y_pred_cat = model(x_cat)
        y_pred_cat = deinterleave(y_pred_cat, le)

        idx1 = len(x)
        idx2 = idx1 + len(weak_x)
        y_pred = y_pred_cat[:idx1, ...]
        y_weak_preds = y_pred_cat[idx1:idx2, ...]  # logits_weak
        y_strong_preds = y_pred_cat[idx2:, ...]  # logits_strong

        # supervised learning:
        sup_loss = sup_criterion(y_pred, y)

        # unsupervised learning:
        y_weak_probas = torch.softmax(y_weak_preds, dim=1).detach()
        y_pseudo = y_weak_probas.argmax(dim=1)
        max_y_weak_probas, _ = y_weak_probas.max(dim=1)
        unsup_loss_mask = (
            max_y_weak_probas >= engine.state.confidence_threshold
        ).float()
        unsup_loss = (
            unsup_criterion(y_strong_preds, y_pseudo) * unsup_loss_mask
        ).mean()

        total_loss = sup_loss + engine.state.lambda_u * unsup_loss

        total_loss.backward()

        optimizer.step()

        return {
            "total_loss": total_loss.item(),
            "sup_loss": sup_loss.item(),
            "unsup_loss": unsup_loss.item(),
            "mask": unsup_loss_mask.mean().item(),  # this should not be averaged for DDP
        }

    output_names = ["total_loss", "sup_loss", "unsup_loss", "mask"]

    trainer = trainers.create_trainer(
        train_step,
        output_names=output_names,
        model=model,
        ema_model=ema_model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        supervised_train_loader=supervised_train_loader,
        test_loader=test_loader,
        cfg=cfg,
        logger=logger,
        cta=cta,
        unsup_train_loader=unsup_train_loader,
        cta_probe_loader=cta_probe_loader,
    )

    trainer.state.confidence_threshold = cfg.ssl.confidence_threshold
    trainer.state.lambda_u = cfg.ssl.lambda_u
    trainer.state.mu_ratio = cfg.ssl.mu_ratio

    distributed = idist.get_world_size() > 1

    @trainer.on(Events.ITERATION_COMPLETED(every=cfg.ssl.cta_update_every))
    def update_cta_rates():
        batch = trainer.state.batch
        x, y = batch["cta_probe_batch"]["image"], batch["cta_probe_batch"]["target"]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        policies = batch["cta_probe_batch"]["policy"]

        ema_model.eval()
        with torch.no_grad():
            y_pred = ema_model(x)
            y_probas = torch.softmax(y_pred, dim=1)  # (N, C)

            if not distributed:
                for y_proba, t, policy in zip(y_probas, y, policies):
                    error = y_proba
                    error[t] -= 1
                    error = torch.abs(error).sum()
                    cta.update_rates(policy, 1.0 - 0.5 * error.item())
            else:
                error_per_op = []
                for y_proba, t, policy in zip(y_probas, y, policies):
                    error = y_proba
                    error[t] -= 1
                    error = torch.abs(error).sum()
                    for k, bins in policy:
                        error_per_op.append(pack_as_tensor(k, bins, error))
                error_per_op = torch.stack(error_per_op)
                # all gather
                tensor_list = idist.all_gather(error_per_op)
                # update cta rates
                for t in tensor_list:
                    k, bins, error = unpack_from_tensor(t)
                    cta.update_rates([(k, bins),], 1.0 - 0.5 * error)

    epoch_length = cfg.solver.epoch_length
    num_epochs = cfg.solver.num_epochs if not cfg.debug else 2
    try:
        trainer.run(
            supervised_train_loader, epoch_length=epoch_length, max_epochs=num_epochs
        )
    except Exception as e:
        import traceback

        print(traceback.format_exc())
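The `interleave`/`deinterleave` helpers used in the training step above are not shown in this snippet. A minimal sketch of the commonly used FixMatch-style implementation, given here as an assumption about their behavior rather than the exact code of this project:

import torch

def interleave(x, size):
    # Regroup the batch as `size` interleaved slices so labeled and unlabeled
    # samples are mixed evenly before a multi-GPU forward pass.
    s = list(x.shape)
    return x.reshape([-1, size] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])

def deinterleave(x, size):
    # Inverse of `interleave`: restore the original ordering of the batch.
    s = list(x.shape)
    return x.reshape([size, -1] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])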