def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
    device = idist.device()
    _test_distrib_sync_all_reduce_decorator(device)
    _test_invalid_sync_all_reduce(device)
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_multilabel_input_NHW(device)
    _test_distrib_integration_multiclass(device)
    _test_distrib_integration_multilabel(device)
    _test_distrib_accumulator_device(device)
def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
    device = idist.device()
    _test_distrib_compute(device)
    _test_distrib_integration(device)
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_compute(device)
    _test_distrib_integration(device)
def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
    device = idist.device()
    _test_neptune_saver_integration(device)
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):
    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):
        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(x)
            loss = criterion(y_pred, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
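# NOTE: `get_save_handler(config)` used above is defined elsewhere in this example. As a minimal
# sketch of what such a helper can look like, assuming plain on-disk checkpoints under
# `config["output_path"]` and ignoring any experiment-tracking integration, one could write
# (the actual implementation may differ):
def _sketch_get_save_handler(config):
    from ignite.handlers import DiskSaver

    # Save checkpoints into the configured output folder; allow a non-empty directory.
    return DiskSaver(config["output_path"], require_empty=False)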
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_integration(device)
    _test_distrib_accumulator_device(device)
def _test_distrib_binary_and_multilabel_inputs(device):
    rank = idist.get_rank()
    torch.manual_seed(12)

    def _test(y_pred, y, batch_size, metric_device):
        metric_device = torch.device(metric_device)
        ap = AveragePrecision(device=metric_device)

        torch.manual_seed(10 + rank)
        ap.reset()
        if batch_size > 1:
            n_iters = y.shape[0] // batch_size + 1
            for i in range(n_iters):
                idx = i * batch_size
                ap.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))
        else:
            ap.update((y_pred, y))

        # gather y_pred, y
        y_pred = idist.all_gather(y_pred)
        y = idist.all_gather(y)

        np_y = y.cpu().numpy()
        np_y_pred = y_pred.cpu().numpy()

        res = ap.compute()
        assert isinstance(res, float)
        assert average_precision_score(np_y, np_y_pred) == pytest.approx(res)

    def get_test_cases():
        test_cases = [
            # Binary input data of shape (N,) or (N, 1)
            (torch.randint(0, 2, size=(10,)).long(), torch.randint(0, 2, size=(10,)).long(), 1),
            (torch.randint(0, 2, size=(10, 1)).long(), torch.randint(0, 2, size=(10, 1)).long(), 1),
            # updated batches
            (torch.randint(0, 2, size=(50,)).long(), torch.randint(0, 2, size=(50,)).long(), 16),
            (torch.randint(0, 2, size=(50, 1)).long(), torch.randint(0, 2, size=(50, 1)).long(), 16),
            # Binary input data of shape (N, L)
            (torch.randint(0, 2, size=(10, 4)).long(), torch.randint(0, 2, size=(10, 4)).long(), 1),
            (torch.randint(0, 2, size=(10, 7)).long(), torch.randint(0, 2, size=(10, 7)).long(), 1),
            # updated batches
            (torch.randint(0, 2, size=(50, 4)).long(), torch.randint(0, 2, size=(50, 4)).long(), 16),
            (torch.randint(0, 2, size=(50, 7)).long(), torch.randint(0, 2, size=(50, 7)).long(), 16),
        ]
        return test_cases

    for _ in range(3):
        test_cases = get_test_cases()
        for y_pred, y, batch_size in test_cases:
            _test(y_pred, y, batch_size, "cpu")
            if device.type != "xla":
                _test(y_pred, y, batch_size, idist.device())
def _test_distrib_integration_binary_input(device):
    rank = idist.get_rank()
    torch.manual_seed(12)

    n_iters = 80
    s = 16
    n_classes = 2
    offset = n_iters * s

    def _test(y_preds, y_true, n_epochs, metric_device, update_fn):
        metric_device = torch.device(metric_device)
        engine = Engine(update_fn)

        ap = AveragePrecision(device=metric_device)
        ap.attach(engine, "ap")

        data = list(range(n_iters))
        engine.run(data=data, max_epochs=n_epochs)

        assert "ap" in engine.state.metrics

        res = engine.state.metrics["ap"]
        true_res = average_precision_score(y_true.cpu().numpy(), y_preds.cpu().numpy())

        assert pytest.approx(res) == true_res

    def get_tests(is_N):
        if is_N:
            y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(),)).to(device)
            y_preds = torch.rand(offset * idist.get_world_size()).to(device)

            def update_fn(engine, i):
                return (
                    y_preds[i * s + rank * offset : (i + 1) * s + rank * offset],
                    y_true[i * s + rank * offset : (i + 1) * s + rank * offset],
                )

        else:
            y_true = torch.randint(0, n_classes, size=(offset * idist.get_world_size(), 10)).to(device)
            y_preds = torch.randint(0, n_classes, size=(offset * idist.get_world_size(), 10)).to(device)

            def update_fn(engine, i):
                return (
                    y_preds[i * s + rank * offset : (i + 1) * s + rank * offset, :],
                    y_true[i * s + rank * offset : (i + 1) * s + rank * offset, :],
                )

        return y_preds, y_true, update_fn

    metric_devices = ["cpu"]
    if device.type != "xla":
        metric_devices.append(idist.device())
    for metric_device in metric_devices:
        for _ in range(2):
            # Binary input data of shape (N,)
            y_preds, y_true, update_fn = get_tests(is_N=True)

            _test(y_preds, y_true, n_epochs=1, metric_device=metric_device, update_fn=update_fn)
            _test(y_preds, y_true, n_epochs=2, metric_device=metric_device, update_fn=update_fn)

            # Binary input data of shape (N, L)
            y_preds, y_true, update_fn = get_tests(is_N=False)

            _test(y_preds, y_true, n_epochs=1, metric_device=metric_device, update_fn=update_fn)
            _test(y_preds, y_true, n_epochs=2, metric_device=metric_device, update_fn=update_fn)
def _test_func(index, ws, device):
    assert 0 <= index < ws
    assert ws == idist.get_world_size()
    assert device in idist.device().type
def _test_distrib_integration(device):
    from ignite.engine import Engine

    rank = idist.get_rank()

    chunks = [
        (CAND_1, [REF_1A, REF_1B]),
        (CAND_2A, [REF_2A, REF_2B, REF_2C]),
        (CAND_2B, [REF_2A, REF_2B, REF_2C]),
        (CAND_1, [REF_1A]),
        (CAND_2A, [REF_2A, REF_2B]),
        (CAND_2B, [REF_2A, REF_2B]),
        (CAND_1, [REF_1B]),
        (CAND_2A, [REF_2B, REF_2C]),
        (CAND_2B, [REF_2B, REF_2C]),
        (CAND_1, [REF_1A, REF_1B]),
        (CAND_2A, [REF_2A, REF_2C]),
        (CAND_2B, [REF_2A, REF_2C]),
        (CAND_1, [REF_1A]),
        (CAND_2A, [REF_2A]),
        (CAND_2B, [REF_2C]),
    ]

    size = len(chunks)

    data = []
    for c in chunks:
        data += idist.get_world_size() * [c]

    def update(_, i):
        candidate, references = data[i + size * rank]
        lower_split_references = [reference.lower().split() for reference in references]
        lower_split_candidate = candidate.lower().split()
        return lower_split_candidate, lower_split_references

    def _test(metric_device):
        engine = Engine(update)
        m = Rouge(variants=[1, 2, "L"], alpha=0.5, device=metric_device)
        m.attach(engine, "rouge")

        engine.run(data=list(range(size)), max_epochs=1)

        assert "rouge" in engine.state.metrics

        evaluator = pyrouge.Rouge(
            metrics=["rouge-n", "rouge-l"],
            max_n=4,
            apply_avg=True,
            apply_best=False,
            alpha=0.5,
            stemming=False,
            ensure_compatibility=False,
        )
        rouge_1_f, rouge_2_f, rouge_l_f = (0, 0, 0)
        for candidate, references in data:
            scores = evaluator.get_scores([candidate], [references])
            rouge_1_f += scores["rouge-1"]["f"]
            rouge_2_f += scores["rouge-2"]["f"]
            rouge_l_f += scores["rouge-l"]["f"]

        assert pytest.approx(engine.state.metrics["Rouge-1-F"], abs=1e-4) == rouge_1_f / len(data)
        assert pytest.approx(engine.state.metrics["Rouge-2-F"], abs=1e-4) == rouge_2_f / len(data)
        assert pytest.approx(engine.state.metrics["Rouge-L-F"], abs=1e-4) == rouge_l_f / len(data)

    _test("cpu")
    if device.type != "xla":
        _test(idist.device())
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_sync_all_reduce_decorator(device)
    _test_creating_on_xla_fails(device)
    _test_invalid_sync_all_reduce(device)
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_sync_all_reduce_decorator(device)
    _test_creating_on_xla_fails(device)
    _test_invalid_sync_all_reduce(device)
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
    device = idist.device()
    _test_distrib_sync_all_reduce_decorator(device)
    _test_invalid_sync_all_reduce(device)
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_compute_on_criterion(device, y_test_1(), y_test_2())
    _test_distrib_accumulator_device(device, y_test_1())
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_binary_and_multilabel_inputs(device)
    _test_distrib_integration_binary_input(device)
def training(local_rank, config):

    rank = idist.get_rank()
    manual_seed(config["seed"] + rank)
    device = idist.device()

    logger = setup_logger(name="CIFAR10-Training")

    log_basic_info(logger, config)

    output_path = config["output_path"]
    if rank == 0:
        if config["stop_iteration"] is None:
            now = datetime.now().strftime("%Y%m%d-%H%M%S")
        else:
            now = f"stop-on-{config['stop_iteration']}"

        folder_name = f"{config['model']}_backend-{idist.backend()}-{idist.get_world_size()}_{now}"
        output_path = Path(output_path) / folder_name
        if not output_path.exists():
            output_path.mkdir(parents=True)
        config["output_path"] = output_path.as_posix()
        logger.info(f"Output path: {config['output_path']}")

        if "cuda" in device.type:
            config["cuda device name"] = torch.cuda.get_device_name(local_rank)

        if config["with_clearml"]:
            try:
                from clearml import Task
            except ImportError:
                # Backwards-compatibility for legacy Trains SDK
                from trains import Task

            task = Task.init("CIFAR10-Training", task_name=output_path.stem)
            task.connect_configuration(config)
            # Log hyper parameters
            hyper_params = [
                "model",
                "batch_size",
                "momentum",
                "weight_decay",
                "num_epochs",
                "learning_rate",
                "num_warmup_epochs",
            ]
            task.connect({k: config[k] for k in hyper_params})

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = get_dataflow(config)

    config["num_iters_per_epoch"] = len(train_loader)
    model, optimizer, criterion, lr_scheduler = initialize(config)

    # Create trainer for current task
    trainer = create_trainer(model, optimizer, criterion, lr_scheduler, train_loader.sampler, config, logger)

    # Let's now setup evaluator engine to perform model's validation and compute metrics
    metrics = {
        "Accuracy": Accuracy(),
        "Loss": Loss(criterion),
    }

    # We define two evaluators as they won't play exactly the same role:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_evaluator(model, metrics=metrics, config=config)
    train_evaluator = create_evaluator(model, metrics=metrics, config=config)

    def run_validation(engine):
        epoch = trainer.state.epoch
        state = train_evaluator.run(train_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Train", state.metrics)
        state = evaluator.run(test_loader)
        log_metrics(logger, epoch, state.times["COMPLETED"], "Test", state.metrics)

    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=config["validate_every"]) | Events.COMPLETED, run_validation
    )

    if rank == 0:
        # Setup TensorBoard logging on trainer and evaluators. Logged values are:
        #  - Training metrics, e.g. running average loss values
        #  - Learning rate
        #  - Evaluation train/test metrics
        evaluators = {"training": train_evaluator, "test": evaluator}
        tb_logger = common.setup_tb_logging(output_path, trainer, optimizer, evaluators=evaluators)

    # Store 2 best models by validation accuracy starting from num_epochs / 2:
    best_model_handler = Checkpoint(
        {"model": model},
        get_save_handler(config),
        filename_prefix="best",
        n_saved=2,
        global_step_transform=global_step_from_engine(trainer),
        score_name="test_accuracy",
        score_function=Checkpoint.get_default_score_fn("Accuracy"),
    )
    evaluator.add_event_handler(
        Events.COMPLETED(lambda *_: trainer.state.epoch > config["num_epochs"] // 2), best_model_handler
    )

    # In order to check training resuming we can stop training on a given iteration
    if config["stop_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["stop_iteration"]))
        def _():
            logger.info(f"Stop training on {trainer.state.iteration} iteration")
            trainer.terminate()

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception as e:
        logger.exception("")
        raise e

    if rank == 0:
        tb_logger.close()
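# NOTE: `Checkpoint.get_default_score_fn("Accuracy")` used above selects the value to rank
# checkpoints by. A hand-written equivalent (sketch only, shown here to make the ranking
# criterion explicit) reads the metric from the attached evaluator's state:
def _sketch_accuracy_score_fn(engine):
    # Higher validation accuracy -> better checkpoint.
    return engine.state.metrics["Accuracy"]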
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_binary_and_multilabel_inputs(device)
    _test_distrib_integration_binary_input(device)
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_integration(device)
    _test_distrib_accumulator_device(device)
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_integration_multiclass(device)
    _test_distrib_integration_multilabel(device)
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_compute(device)
    _test_distrib_integration(device)
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_integration_multiclass(device)
    _test_distrib_integration_multilabel(device)
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
    device = idist.device()
    _test_neptune_saver_integration(device)
def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
    device = idist.device()
    _test_distrib_compute_on_criterion(device, y_test_1(), y_test_2())
    _test_distrib_accumulator_device(device, y_test_1())
def _test_distrib_multilabel_input_NHW(device):
    # Multilabel input data of shape (N, C, H, W, ...) and (N, C, H, W, ...)
    rank = idist.get_rank()

    def _test(metric_device):
        metric_device = torch.device(metric_device)
        acc = Accuracy(is_multilabel=True, device=metric_device)

        torch.manual_seed(10 + rank)
        y_pred = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long()
        y = torch.randint(0, 2, size=(4, 5, 8, 10), device=device).long()
        acc.update((y_pred, y))

        assert (
            acc._num_correct.device == metric_device
        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"

        # gather y_pred, y
        y_pred = idist.all_gather(y_pred)
        y = idist.all_gather(y)

        np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
        np_y = to_numpy_multilabel(y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)

        assert acc._type == "multilabel"
        n = acc._num_examples
        res = acc.compute()
        assert n * idist.get_world_size() == acc._num_examples
        assert isinstance(res, float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)

        acc.reset()
        torch.manual_seed(10 + rank)
        y_pred = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long()
        y = torch.randint(0, 2, size=(4, 7, 10, 8), device=device).long()
        acc.update((y_pred, y))

        assert (
            acc._num_correct.device == metric_device
        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"

        # gather y_pred, y
        y_pred = idist.all_gather(y_pred)
        y = idist.all_gather(y)

        np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)
        np_y = to_numpy_multilabel(y.cpu())  # (N, C, H, W, ...) -> (N * H * W ..., C)

        assert acc._type == "multilabel"
        n = acc._num_examples
        res = acc.compute()
        assert n * idist.get_world_size() == acc._num_examples
        assert isinstance(res, float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)

        # check that result is not changed
        res = acc.compute()
        assert n * idist.get_world_size() == acc._num_examples
        assert isinstance(res, float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)

        # Batched Updates
        acc.reset()
        torch.manual_seed(10 + rank)
        y_pred = torch.randint(0, 2, size=(80, 5, 8, 10), device=device).long()
        y = torch.randint(0, 2, size=(80, 5, 8, 10), device=device).long()

        batch_size = 16
        n_iters = y.shape[0] // batch_size + 1

        for i in range(n_iters):
            idx = i * batch_size
            acc.update((y_pred[idx : idx + batch_size], y[idx : idx + batch_size]))

        assert (
            acc._num_correct.device == metric_device
        ), f"{type(acc._num_correct.device)}:{acc._num_correct.device} vs {type(metric_device)}:{metric_device}"

        # gather y_pred, y
        y_pred = idist.all_gather(y_pred)
        y = idist.all_gather(y)

        np_y_pred = to_numpy_multilabel(y_pred.cpu())  # (N, C, L, ...) -> (N * L * ..., C)
        np_y = to_numpy_multilabel(y.cpu())  # (N, C, L, ...) -> (N * L ..., C)

        assert acc._type == "multilabel"
        n = acc._num_examples
        res = acc.compute()
        assert n * idist.get_world_size() == acc._num_examples
        assert isinstance(res, float)
        assert accuracy_score(np_y, np_y_pred) == pytest.approx(res)

    # check multiple random inputs as exact random matches are rare
    for _ in range(3):
        _test("cpu")
        if device.type != "xla":
            _test(idist.device())
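# NOTE: `to_numpy_multilabel` used above flattens (N, C, H, W, ...) tensors into a 2D
# (N * H * W * ..., C) numpy array so that sklearn's `accuracy_score` can consume them.
# A minimal sketch of such a helper (the actual test utility may differ) is:
def _sketch_to_numpy_multilabel(y):
    # (N, C, H, W, ...) -> (C, N, H, W, ...) -> (C, N * H * W * ...) -> (N * H * W * ..., C)
    y = y.transpose(1, 0).cpu().numpy()
    num_classes = y.shape[0]
    return y.reshape((num_classes, -1)).transpose(1, 0)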
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
    device = idist.device()
    _test_distrib_compute_on_criterion(device, y_test_1(), y_test_2())
    _test_distrib_accumulator_device(device, y_test_1())
def _test_distrib_xla_nprocs(index):
    device = idist.device()
    _test_distrib_multilabel_input_NHW(device)
    _test_distrib_integration_multiclass(device)
    _test_distrib_integration_multilabel(device)
    _test_distrib_accumulator_device(device)
def test_distrib_single_device_xla():
    device = idist.device()
    _test_distrib_compute_on_criterion(device, y_test_1(), y_test_2())
    _test_distrib_accumulator_device(device, y_test_1())
def test_multinode_distrib_nccl_gpu(distributed_context_multi_node_nccl):
    device = idist.device()
    _test_distrib_compute(device)
    _test_distrib_integration(device)
def training(local_rank, cfg):

    logger = setup_logger("FixMatch Training", distributed_rank=idist.get_rank())

    if local_rank == 0:
        logger.info(cfg.pretty())

    rank = idist.get_rank()
    manual_seed(cfg.seed + rank)
    device = idist.device()

    model, ema_model, optimizer, sup_criterion, lr_scheduler = utils.initialize(cfg)

    unsup_criterion = instantiate(cfg.solver.unsupervised_criterion)

    cta = get_default_cta()

    (
        supervised_train_loader,
        test_loader,
        unsup_train_loader,
        cta_probe_loader,
    ) = utils.get_dataflow(cfg, cta=cta, with_unsup=True)

    def train_step(engine, batch):
        model.train()
        optimizer.zero_grad()

        x, y = batch["sup_batch"]["image"], batch["sup_batch"]["target"]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        weak_x, strong_x = (
            batch["unsup_batch"]["image"],
            batch["unsup_batch"]["strong_aug"],
        )
        if weak_x.device != device:
            weak_x = weak_x.to(device, non_blocking=True)
            strong_x = strong_x.to(device, non_blocking=True)

        # according to TF code: single forward pass on concat data: [x, weak_x, strong_x]
        le = 2 * engine.state.mu_ratio + 1

        # Why interleave: https://github.com/google-research/fixmatch/issues/20#issuecomment-613010277
        # We need to interleave due to multi-GPU batch norm issues. Let's say we have two GPUs, and our batch is
        # comprised of labeled (L) and unlabeled (U) images. Let's use a batch size of 2 to make the following
        # example easier to visualize.
        #
        # - Without interleaving, we have a batch LLUUUUUU...U (there are 14 U). When the batch is split to be passed
        #   to both GPUs, we'll have two batches LLUUUUUU and UUUUUUUU. Note that all labeled examples ended up in
        #   batch1 sent to GPU1. The problem here is that batch norm will be computed per batch and the moments will
        #   lack consistency between batches.
        #
        # - With interleaving, by contrast, the two batches will be LUUUUUUU and LUUUUUUU. As you can see, the
        #   batches have the same distribution of labeled and unlabeled samples and will therefore have more
        #   consistent moments.
        #
        x_cat = interleave(torch.cat([x, weak_x, strong_x], dim=0), le)
        y_pred_cat = model(x_cat)
        y_pred_cat = deinterleave(y_pred_cat, le)

        idx1 = len(x)
        idx2 = idx1 + len(weak_x)
        y_pred = y_pred_cat[:idx1, ...]
        y_weak_preds = y_pred_cat[idx1:idx2, ...]  # logits_weak
        y_strong_preds = y_pred_cat[idx2:, ...]  # logits_strong

        # supervised learning:
        sup_loss = sup_criterion(y_pred, y)

        # unsupervised learning:
        y_weak_probas = torch.softmax(y_weak_preds, dim=1).detach()
        y_pseudo = y_weak_probas.argmax(dim=1)
        max_y_weak_probas, _ = y_weak_probas.max(dim=1)
        unsup_loss_mask = (max_y_weak_probas >= engine.state.confidence_threshold).float()
        unsup_loss = (unsup_criterion(y_strong_preds, y_pseudo) * unsup_loss_mask).mean()

        total_loss = sup_loss + engine.state.lambda_u * unsup_loss

        total_loss.backward()

        optimizer.step()

        return {
            "total_loss": total_loss.item(),
            "sup_loss": sup_loss.item(),
            "unsup_loss": unsup_loss.item(),
            "mask": unsup_loss_mask.mean().item(),  # this should not be averaged for DDP
        }

    output_names = ["total_loss", "sup_loss", "unsup_loss", "mask"]

    trainer = trainers.create_trainer(
        train_step,
        output_names=output_names,
        model=model,
        ema_model=ema_model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        supervised_train_loader=supervised_train_loader,
        test_loader=test_loader,
        cfg=cfg,
        logger=logger,
        cta=cta,
        unsup_train_loader=unsup_train_loader,
        cta_probe_loader=cta_probe_loader,
    )

    trainer.state.confidence_threshold = cfg.ssl.confidence_threshold
    trainer.state.lambda_u = cfg.ssl.lambda_u
    trainer.state.mu_ratio = cfg.ssl.mu_ratio

    distributed = idist.get_world_size() > 1

    @trainer.on(Events.ITERATION_COMPLETED(every=cfg.ssl.cta_update_every))
    def update_cta_rates():
        batch = trainer.state.batch
        x, y = batch["cta_probe_batch"]["image"], batch["cta_probe_batch"]["target"]
        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        policies = batch["cta_probe_batch"]["policy"]

        ema_model.eval()
        with torch.no_grad():
            y_pred = ema_model(x)
            y_probas = torch.softmax(y_pred, dim=1)  # (N, C)

            if not distributed:
                # single process: update CTA rates directly from local probe errors
                for y_proba, t, policy in zip(y_probas, y, policies):
                    error = y_proba
                    error[t] -= 1
                    error = torch.abs(error).sum()
                    cta.update_rates(policy, 1.0 - 0.5 * error.item())
            else:
                # distributed: pack local errors, all-gather across processes, then update CTA rates
                error_per_op = []
                for y_proba, t, policy in zip(y_probas, y, policies):
                    error = y_proba
                    error[t] -= 1
                    error = torch.abs(error).sum()
                    for k, bins in policy:
                        error_per_op.append(pack_as_tensor(k, bins, error))
                error_per_op = torch.stack(error_per_op)
                # all gather
                tensor_list = idist.all_gather(error_per_op)
                # update cta rates
                for t in tensor_list:
                    k, bins, error = unpack_from_tensor(t)
                    cta.update_rates([(k, bins)], 1.0 - 0.5 * error)

    epoch_length = cfg.solver.epoch_length
    num_epochs = cfg.solver.num_epochs if not cfg.debug else 2

    try:
        trainer.run(supervised_train_loader, epoch_length=epoch_length, max_epochs=num_epochs)
    except Exception as e:
        import traceback

        print(traceback.format_exc())
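# NOTE: `interleave` / `deinterleave` used in `train_step` above come from this example's
# utilities. A minimal sketch of the usual reshape/transpose trick (following the approach of the
# FixMatch reference implementation; the actual helpers may differ) is given below:
def _sketch_interleave(x, size):
    # (B, ...) -> (B // size, size, ...) -> swap the two leading dims -> flatten back to (B, ...),
    # so labeled and unlabeled samples are spread evenly across per-GPU chunks.
    s = list(x.shape)
    return x.reshape([-1, size] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])


def _sketch_deinterleave(x, size):
    # Inverse of `_sketch_interleave`: undo the grouping so outputs line up with the original order.
    s = list(x.shape)
    return x.reshape([size, -1] + s[1:]).transpose(0, 1).reshape([-1] + s[1:])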