Example #1
def test_asserts_setup_common_training_handlers():
    trainer = Engine(lambda e, b: None)

    with pytest.raises(
        ValueError,
        match=r"If to_save argument is provided then output_path or save_handler arguments should be also defined",
    ):
        setup_common_training_handlers(trainer, to_save={})

    with pytest.raises(ValueError, match=r"Arguments output_path and save_handler are mutually exclusive"):
        setup_common_training_handlers(trainer, to_save={}, output_path="abc", save_handler=lambda c, f, m: None)

    with pytest.warns(UserWarning, match=r"Argument train_sampler is a distributed sampler"):
        train_sampler = MagicMock(spec=DistributedSampler)
        setup_common_training_handlers(trainer, train_sampler=train_sampler)

    with pytest.raises(RuntimeError, match=r"This contrib module requires available GPU"):
        setup_common_training_handlers(trainer, with_gpu_stats=True)

    with pytest.raises(TypeError, match=r"Unhandled type of update_function's output."):
        trainer = Engine(lambda e, b: None)
        setup_common_training_handlers(
            trainer,
            output_names=["loss"],
            with_pbar_on_iters=False,
            with_pbars=False,
            with_gpu_stats=False,
            stop_on_nan=False,
            clear_cuda_cache=False,
        )
        trainer.run([1])
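The assertions above encode the argument rules of setup_common_training_handlers: when to_save is given, exactly one of output_path or save_handler must also be given, and with_gpu_stats requires an available GPU. Below is a minimal sketch of a valid call (not taken from the tests; it assumes a writable "./checkpoints" directory):

import torch
from ignite.engine import Engine
from ignite.contrib.engines.common import setup_common_training_handlers

model = torch.nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Dummy update function returning a dict so that output_names=["loss"] can be averaged.
trainer = Engine(lambda engine, batch: {"loss": 0.0})

setup_common_training_handlers(
    trainer,
    to_save={"model": model, "optimizer": optimizer},
    output_path="./checkpoints",  # assumption: any writable directory; mutually exclusive with save_handler
    save_every_iters=100,
    output_names=["loss"],
    with_gpu_stats=False,  # keep False unless a GPU (and pynvml) is available
    with_pbars=False,
)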
Example #2
def create_trainer(model, optimizer, criterion, train_sampler, config, logger):
    prepare_batch = config.prepare_batch
    device = config.device

    # Setup trainer
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    model_output_transform = getattr(config, "model_output_transform", lambda x: x)

    def train_update_function(engine, batch):

        model.train()

        x, y = prepare_batch(batch, device=device, non_blocking=True)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        loss = criterion(y_pred, y)

        if isinstance(loss, Mapping):
            assert "supervised batch loss" in loss
            loss_dict = loss
            output = {k: v.item() for k, v in loss_dict.items()}
            loss = loss_dict["supervised batch loss"] / accumulation_steps
        else:
            output = {"supervised batch loss": loss.item()}

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return output

    output_names = getattr(config, "output_names", ["supervised batch loss",])
    lr_scheduler = config.lr_scheduler

    trainer = Engine(train_update_function)
    trainer.logger = logger

    to_save = {"model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler, "trainer": trainer, "amp": amp}

    save_every_iters = getattr(config, "save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        with_gpu_stats=exp_tracking.has_mlflow,
        output_names=output_names,
        with_pbars=False,
    )

    if idist.get_rank() == 0:
        common.ProgressBar(persist=False).attach(trainer, metric_names="all")

    return trainer
Example #3
def test_no_warning_with_train_sampler(recwarn):
    from torch.utils.data import RandomSampler

    trainer = Engine(lambda e, b: None)
    train_sampler = RandomSampler([0, 1, 2])
    setup_common_training_handlers(trainer, train_sampler=train_sampler)
    assert len(recwarn) == 0, recwarn.pop()
Example #4
def create_trainer(model, optimizer, criterion, train_sampler, config, logger):
    prepare_batch = config.prepare_batch
    device = config.device

    # Setup trainer
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    def train_update_function(engine, batch):

        model.train()

        x, y = prepare_batch(batch, device=device, non_blocking=True)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        loss = criterion(y_pred, y) / accumulation_steps

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return {
            "supervised batch loss": loss.item(),
        }

    output_names = getattr(config, "output_names", ["supervised batch loss"])
    lr_scheduler = config.lr_scheduler

    trainer = Engine(train_update_function)
    trainer.logger = logger

    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": amp
    }

    save_every_iters = getattr(config, "save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        output_path=config.output_path.as_posix(),
        lr_scheduler=lr_scheduler,
        with_gpu_stats=True,
        output_names=output_names,
        with_pbars=False,
    )

    common.ProgressBar(persist=False).attach(trainer, metric_names="all")

    return trainer
Example #5
def test_assert_setup_common_training_handlers_wrong_train_sampler(distributed_context_single_node_gloo):
    trainer = Engine(lambda e, b: None)

    from torch.utils.data.sampler import RandomSampler

    with pytest.raises(TypeError, match=r"Train sampler should be torch DistributedSampler"):
        train_sampler = RandomSampler([0, 1, 2, 3])
        setup_common_training_handlers(trainer, train_sampler)
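In a distributed configuration the helper rejects plain samplers, as this test shows; it expects a torch DistributedSampler so that the sampler's epoch can be advanced for proper shuffling. A small sketch of the expected setup, assuming the process group has already been initialized (e.g. via torch.distributed or ignite.distributed):

import torch
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler

# Hypothetical toy dataset; any dataset with __len__ and __getitem__ works.
dataset = torch.arange(100, dtype=torch.float32).unsqueeze(1)
train_sampler = DistributedSampler(dataset)  # shards indices across processes
train_loader = DataLoader(dataset, batch_size=8, sampler=train_sampler)

# Passing this sampler, e.g.
#   setup_common_training_handlers(trainer, train_sampler=train_sampler)
# lets the common handlers call train_sampler.set_epoch(...) once per epoch.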
Example #6
def _test_setup_common_training_handlers(dirname, device, rank=0, local_rank=0, distributed=False):

    lr = 0.01
    step_size = 100
    gamma = 0.5

    model = DummyModel().to(device)
    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[local_rank, ],
                                                          output_device=local_rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)

    def update_fn(engine, batch):
        optimizer.zero_grad()
        x = torch.tensor([batch], requires_grad=True, device=device)
        y_pred = model(x)
        loss = y_pred.mean()
        loss.backward()
        optimizer.step()
        return loss

    train_sampler = MagicMock()
    train_sampler.set_epoch = MagicMock()

    trainer = Engine(update_fn)
    setup_common_training_handlers(trainer, train_sampler=train_sampler,
                                   to_save={"model": model, "optimizer": optimizer},
                                   save_every_iters=75, output_path=dirname,
                                   lr_scheduler=lr_scheduler, with_gpu_stats=False,
                                   output_names=['batch_loss', ],
                                   with_pbars=True, with_pbar_on_iters=True, log_every_iters=50,
                                   device=device)

    num_iters = 100
    num_epochs = 10
    data = [i * 0.1 for i in range(num_iters)]
    trainer.run(data, max_epochs=num_epochs)

    # check handlers
    handlers = trainer._event_handlers[Events.ITERATION_COMPLETED]
    for cls in [TerminateOnNan, ]:
        assert any([isinstance(h[0], cls) for h in handlers]), "{}".format(handlers)
    assert 'batch_loss' in trainer.state.metrics

    # Check saved checkpoint
    if rank == 0:
        checkpoints = list(os.listdir(dirname))
        assert len(checkpoints) == 1
        for v in ["training_checkpoint", ]:
            assert any([v in c for c in checkpoints])

    # Check LR scheduling
    assert optimizer.param_groups[0]['lr'] <= lr * gamma ** (num_iters * num_epochs / step_size), \
        "{} vs {}".format(optimizer.param_groups[0]['lr'], lr * gamma ** (num_iters * num_epochs / step_size))
Example #7
def test_asserts_setup_common_training_handlers():
    trainer = Engine(lambda e, b: None)

    with pytest.raises(ValueError, match=r"If to_save argument is provided then output_path argument should be "
                                         r"also defined"):
        setup_common_training_handlers(trainer, to_save={})

    with pytest.warns(UserWarning, match=r"Argument train_sampler distributed sampler used to call "
                                         r"`set_epoch` method on epoch"):
        train_sampler = MagicMock()
        setup_common_training_handlers(trainer, train_sampler=train_sampler, with_gpu_stats=False)
Example #8
def test_asserts_setup_common_training_handlers():
    trainer = Engine(lambda e, b: None)

    with pytest.raises(
            ValueError,
            match=
            r"If to_save argument is provided then output_path or save_handler arguments should be also defined",
    ):
        setup_common_training_handlers(trainer, to_save={})

    with pytest.raises(
            ValueError,
            match=
            r"Arguments output_path and save_handler are mutually exclusive"):
        setup_common_training_handlers(trainer,
                                       to_save={},
                                       output_path="abc",
                                       save_handler=lambda c, f, m: None)

    with pytest.warns(
            UserWarning,
            match=r"Argument train_sampler is a distributed sampler"):
        train_sampler = MagicMock(spec=DistributedSampler)
        setup_common_training_handlers(trainer, train_sampler=train_sampler)

    with pytest.warns(UserWarning,
                      match=r"Argument device is unused and deprecated"):
        setup_common_training_handlers(trainer, device="cpu")
Example #9
def create_trainer_and_evaluators(
    model: nn.Module,
    optimizer: Optimizer,
    criterion: nn.Module,
    data_loaders: Dict[str, DataLoader],
    metrics: Dict[str, Metric],
    config: ConfigSchema,
    logger: Logger,
) -> Tuple[Engine, Dict[str, Engine]]:
    trainer = get_trainer(model, criterion, optimizer)
    trainer.logger = logger

    evaluators = get_evaluators(model, metrics)
    setup_evaluation(trainer, evaluators, data_loaders, logger)

    lr_scheduler = get_lr_scheduler(config, optimizer, trainer, evaluators["val"])

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }

    common.setup_common_training_handlers(
        trainer=trainer,
        to_save=to_save,
        save_every_iters=config.checkpoint_every,
        save_handler=get_save_handler(config),
        with_pbars=False,
        train_sampler=data_loaders["train"].sampler,
    )
    trainer.add_event_handler(Events.EPOCH_STARTED, lr_scheduler)
    ProgressBar(persist=False).attach(
        trainer,
        metric_names="all",
        event_name=Events.ITERATION_COMPLETED(every=config.log_every_iters),
    )

    resume_from = config.resume_from
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix()
        )
        logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer, evaluators
Example #10
def create_trainer(model, optimizer, criterion, train_sampler, config, logger, with_clearml):
    device = config.device
    prepare_batch = data.prepare_image_mask

    # Setup trainer
    accumulation_steps = config.get("accumulation_steps", 1)
    model_output_transform = config.get("model_output_transform", lambda x: x)

    with_amp = config.get("with_amp", True)
    scaler = GradScaler(enabled=with_amp)

    def forward_pass(batch):
        model.train()
        x, y = prepare_batch(batch, device=device, non_blocking=True)
        with autocast(enabled=with_amp):
            y_pred = model(x)
            y_pred = model_output_transform(y_pred)
            loss = criterion(y_pred, y) / accumulation_steps
        return loss

    def amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        if engine.state.iteration % accumulation_steps == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

    def hvd_amp_backward_pass(engine, loss):
        scaler.scale(loss).backward()
        optimizer.synchronize()
        with optimizer.skip_synchronize():
            scaler.step(optimizer)
            scaler.update()
        optimizer.zero_grad()

    if idist.backend() == "horovod" and with_amp:
        backward_pass = hvd_amp_backward_pass
    else:
        backward_pass = amp_backward_pass

    def training_step(engine, batch):
        loss = forward_pass(batch)
        output = {"supervised batch loss": loss.item()}
        backward_pass(engine, loss)
        return output

    trainer = Engine(training_step)
    trainer.logger = logger

    output_names = [
        "supervised batch loss",
    ]
    lr_scheduler = config.lr_scheduler

    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer,
        "amp": scaler,
    }

    save_every_iters = config.get("save_every_iters", 1000)

    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=save_every_iters,
        save_handler=utils.get_save_handler(config.output_path.as_posix(), with_clearml),
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        with_pbars=not with_clearml,
        log_every_iters=1,
    )

    resume_from = config.get("resume_from", None)
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example #11
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        input_ids = batch["input_ids"]
        attention_mask = batch["attention_mask"]
        token_type_ids = batch["token_type_ids"]
        labels = batch["label"].view(-1, 1)

        if input_ids.device != device:
            input_ids = input_ids.to(device,
                                     non_blocking=True,
                                     dtype=torch.long)
            attention_mask = attention_mask.to(device,
                                               non_blocking=True,
                                               dtype=torch.long)
            token_type_ids = token_type_ids.to(device,
                                               non_blocking=True,
                                               dtype=torch.long)
            labels = labels.to(device, non_blocking=True, dtype=torch.float)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(input_ids, attention_mask, token_type_ids)
            loss = criterion(y_pred, labels)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]
    if config["log_every_iters"] == 0:
        # Disable logging training metrics:
        metric_names = None
        config["log_every_iters"] = 15

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=utils.get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        log_every_iters=config["log_every_iters"],
        with_pbars=not config["with_clearml"],
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(
        ), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example #12
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    cutmix_beta = config["cutmix_beta"]
    cutmix_prob = config["cutmix_prob"]
    with_amp = config["with_amp"]
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            r = torch.rand(1).item()
            if cutmix_beta > 0 and r < cutmix_prob:
                output, loss = utils.cutmix_forward(model, x, criterion, y,
                                                    cutmix_beta)
            else:
                output = model(x)
                loss = criterion(output, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()

        if idist.backend() == "horovod":
            optimizer.synchronize()
            with optimizer.skip_synchronize():
                scaler.step(optimizer)
                scaler.update()
        else:
            scaler.step(optimizer)
            scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    if config["with_pbar"] and idist.get_rank() == 0:
        ProgressBar().attach(trainer)

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert (checkpoint_fp.exists()
                ), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example #13
def main():
  parser = argparse.ArgumentParser()

  # Required parameters
  parser.add_argument("--model", type=str, default='ffn', help="model's name")
  parser.add_argument("--mode", type=int, choices=[0, 1, 2], default=None)
  parser.add_argument("--SNRdb", type=float, default=None)
  parser.add_argument("--pilot_version", type=int, choices=[1, 2], default=1)
  parser.add_argument("--loss_type", type=str, default="BCELoss")
  parser.add_argument("--train_batch_size", type=int, default=128)
  parser.add_argument("--valid_batch_size", type=int, default=128)
  parser.add_argument("--gradient_accumulation_steps", type=int, default=1)
  parser.add_argument("--max_norm", type=float, default=-1)
  parser.add_argument("--lr", type=float, default=1e-3)
  parser.add_argument("--noise_lambda", type=float, default=1.0)
  parser.add_argument("--lr_scheduler", type=str, choices=["linear", "cycle", "cosine"], default="linear")
  parser.add_argument("--reset_lr_scheduler", type=str, choices=["linear", "cycle", "cosine"], default=None)
  parser.add_argument("--reset_trainer", action='store_true')
  parser.add_argument("--modify_model", action='store_true')
  parser.add_argument("--wd", type=float, default=1e-4, help="weight decay")
  parser.add_argument("--eval_iter", type=int, default=10)
  parser.add_argument("--save_iter", type=int, default=10)
  parser.add_argument("--n_epochs", type=int, default=10)
  parser.add_argument("--flush_dataset", type=int, default=0)
  parser.add_argument("--no_cache", action='store_true')
  parser.add_argument("--with_pure_y", action='store_true') 
  parser.add_argument("--with_h", action='store_true') 
  parser.add_argument("--only_l1", action='store_true', help="Only loss 1")
  parser.add_argument("--interpolation", action='store_true', help="if interpolate between pure and reconstruction.") 
  parser.add_argument("--data_dir", type=str, default="data")
  parser.add_argument("--cache_dir", type=str, default="train_cache")
  parser.add_argument("--output_path", type=str, default="runs", help="model save")
  parser.add_argument("--resume_from", type=str, default=None, help="resume training.")
  parser.add_argument("--first_cache_index", type=int, default=0)
  parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu",
                      help="Device (cuda or cpu)")
  parser.add_argument("--local_rank", type=int, default=-1,
                      help="Local rank for distributed training (-1: not distributed)")
  parser.add_argument("--seed", type=int, default=43)
  parser.add_argument("--debug", action='store_true')
  args = parser.parse_args()

  args.output_path = os.path.join(args.output_path, f'pilot_{args.pilot_version}')
  args.cache_dir = os.path.join(args.data_dir, args.cache_dir)
  # Setup CUDA, GPU & distributed training
  args.distributed = (args.local_rank != -1)
  if not args.distributed:
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    torch.distributed.init_process_group(backend="nccl", init_method='env://')
  args.n_gpu = torch.cuda.device_count() if not args.distributed else 1
  args.device = device

  # Set seed
  set_seed(args)
  logger = setup_logger("trainer", distributed_rank=args.local_rank)

  # Model construction
  model = getattr(models, args.model)(args)
  model = model.to(device)
  optimizer = AdamW(model.parameters(), lr = args.lr, weight_decay=args.wd)

  if args.loss_type == "MSELoss":
    criterion = nn.MSELoss(reduction='sum').to(device)
  else:
    criterion = getattr(nn, args.loss_type, getattr(auxiliary, args.loss_type, None))().to(device)
  criterion2 = nn.MSELoss(reduction='sum').to(device)

  if args.local_rank != -1:
    model = torch.nn.parallel.DistributedDataParallel(
        model, device_ids=[args.local_rank], output_device=args.local_rank, find_unused_parameters=True
    )

  train_dataset = SIGDataset(args, data_type="train")
  valid_dataset = SIGDataset(args, data_type="valid")
  train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset) if args.distributed else None
  valid_sampler = torch.utils.data.distributed.DistributedSampler(valid_dataset) if args.distributed else None
  train_loader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size, pin_memory=True, shuffle=(not args.distributed))
  valid_loader = DataLoader(valid_dataset, sampler=valid_sampler, batch_size=args.valid_batch_size, pin_memory=True, shuffle=False)
  
  lr_scheduler = None
  if args.lr_scheduler == "linear":
    lr_scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
  elif args.lr_scheduler == "cycle":
    lr_scheduler = LinearCyclicalScheduler(optimizer, 'lr', 0.0, args.lr, args.eval_iter * len(train_loader))
  elif args.lr_scheduler == "cosine":
    lr_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, args.eval_iter * len(train_loader))

  # Training function and trainer
  def update(engine, batch):
      model.train()
      y, x_label, y_pure, H = train_dataset.prepare_batch(batch, device=args.device)

      if args.with_pure_y and args.with_h:
        x_pred, y_pure_pred, H_pred = model(y, pure=y_pure, H=H, opp=True)
        loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
        if args.loss_type == "MSELoss":
          loss_1 = loss_1 / x_pred.size(0)
        loss_noise = criterion2(y_pure_pred, y_pure) / y.size(0) / args.gradient_accumulation_steps
        loss_noise_h = criterion2(H_pred, H) / H.size(0) / args.gradient_accumulation_steps
        if args.only_l1:
          loss = loss_1
        else:
          loss = loss_1 + loss_noise * args.noise_lambda + loss_noise_h
        output = (loss.item(), loss_1.item(), loss_noise.item(), loss_noise_h.item())
      elif args.with_pure_y:
        x_pred, y_pure_pred = model(y, pure=y_pure if args.interpolation else None, opp=True)
        loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
        loss_noise = criterion2(y_pure_pred, y_pure) / y.size(0) / args.gradient_accumulation_steps
        loss = loss_1 + loss_noise * args.noise_lambda
        output = (loss.item(), loss_1.item(), loss_noise.item())
      elif args.with_h:
        x_pred, H_pred = model(y, opp=True)
        loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
        loss_noise = criterion2(H_pred, H) / H.size(0) / args.gradient_accumulation_steps
        loss = loss_1 + loss_noise * args.noise_lambda
        output = (loss.item(), loss_1.item(), loss_noise.item())
      else:
        x_pred = model(y)
        loss_1 = criterion(x_pred, x_label) / args.gradient_accumulation_steps
        loss = loss_1
        output = (loss.item(), loss_1.item(), torch.zeros_like(loss_1).item())

      loss.backward()
      if args.max_norm > 0:
        torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_norm)
      if engine.state.iteration % args.gradient_accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
      return output
  trainer = Engine(update)

  to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
  metric_names = ["loss", "l1", "ln"]
  if args.with_pure_y and args.with_h:
    metric_names.append("lnH")

  common.setup_common_training_handlers(
    trainer=trainer,
    train_sampler=train_loader.sampler,
    to_save=to_save,
    save_every_iters=len(train_loader) * args.save_iter,
    lr_scheduler=lr_scheduler,
    output_names=metric_names,
    with_pbars=False,
    clear_cuda_cache=False,
    output_path=args.output_path,
    n_saved=2,
  )

  resume_from = args.resume_from
  if resume_from is not None:
    checkpoint_fp = Path(resume_from)
    assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
    logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
    checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
    if args.reset_trainer:
      to_save.pop("trainer")
    checkpoint_to_load = to_save if 'validation' not in resume_from else {"model": model}
    Checkpoint.load_objects(to_load=checkpoint_to_load, checkpoint=checkpoint)
    if args.reset_lr_scheduler is not None:
      if args.reset_lr_scheduler == "linear":
        lr_scheduler = PiecewiseLinear(optimizer, "lr", [(0, args.lr), (args.n_epochs * len(train_loader), 0.0)])
      elif args.reset_lr_scheduler == "cycle":
        lr_scheduler = LinearCyclicalScheduler(optimizer, 'lr', 0.0, args.lr, args.eval_iter * len(train_loader))
      elif args.reset_lr_scheduler == "cosine":
        lr_scheduler = CosineAnnealingScheduler(optimizer, 'lr', args.lr, 0.0, args.eval_iter * len(train_loader))

  metrics = {
    "accuracy": Accuracy(lambda output: (torch.round(output[0][0]), output[1][0])), 
    "loss_1": Loss(criterion, output_transform=lambda output: (output[0][0], output[1][0])),
    "loss_noise": Loss(criterion2, output_transform=lambda output: (output[0][1], output[1][1]))
  }
  if args.with_pure_y and args.with_h:
    metrics["loss_noise_h"] = Loss(criterion2, output_transform=lambda output: (output[0][2], output[1][2]))

  def _inference(engine: Engine, batch: Sequence[torch.Tensor]) -> Union[Any, Tuple[torch.Tensor]]:
    model.eval()
    with torch.no_grad():
      x, y, x_pure, H = valid_dataset.prepare_batch(batch, device=args.device, non_blocking=True)
      if args.with_pure_y and args.with_h:
        y_pred, x_pure_pred, h_pred = model(x, opp=True)
        outputs = (y_pred, x_pure_pred, h_pred), (y, x_pure, H)
      elif args.with_pure_y:
        y_pred, x_pure_pred = model(x, opp=True)
        outputs = (y_pred, x_pure_pred), (y, x_pure)
      elif args.with_h:
        y_pred, h_pred = model(x, opp=True)
        outputs = (y_pred, h_pred), (y, H)
      else:
        y_pred = model(x)
        x_pure_pred = x_pure
        outputs = (y_pred, x_pure_pred), (y, x_pure)       
      return outputs
  evaluator = Engine(_inference)
  for name, metric in metrics.items():
      metric.attach(evaluator, name)

  trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.eval_iter), lambda _: evaluator.run(valid_loader))

  if args.flush_dataset > 0:
    trainer.add_event_handler(Events.EPOCH_COMPLETED(every=args.n_epochs//args.flush_dataset), 
                  lambda _: train_loader.dataset.reset() if args.no_cache else train_loader.dataset.reload())

  # On the main process: add progress bar, tensorboard, checkpoints and save model, configuration and tokenizer before we start to train
  if args.local_rank in [-1, 0]:
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=metric_names, output_transform=lambda _: {"lr": f"{optimizer.param_groups[0]['lr']:.2e}"})
    evaluator.add_event_handler(Events.COMPLETED, lambda _: pbar.log_message("Validation: %s" % pformat(evaluator.state.metrics)))

    tb_logger = common.setup_tb_logging(args.output_path, trainer, optimizer, evaluators={'validation': evaluator}, log_every_iters=1)

  # Store 3 best models by validation accuracy:
  common.gen_save_best_models_by_val_score(
    save_handler=DiskSaver(args.output_path, require_empty=False),
    evaluator=evaluator,
    models={"model": model},
    metric_name="accuracy",
    n_saved=3,
    trainer=trainer,
    tag="validation"
  )

  # Run the training
  trainer.run(train_loader, max_epochs=args.n_epochs)

  if args.local_rank in [-1, 0]:
    tb_logger.close()
Example #14
def training(config,
             local_rank=None,
             with_mlflow_logging=False,
             with_plx_logging=False):

    if not getattr(config, "use_fp16", True):
        raise RuntimeError("This training script uses by default fp16 AMP")

    set_seed(config.seed + local_rank)
    torch.cuda.set_device(local_rank)
    device = "cuda"

    torch.backends.cudnn.benchmark = True

    train_loader = config.train_loader
    train_sampler = getattr(train_loader, "sampler", None)
    assert train_sampler is not None, "Train loader of type '{}' " "should have attribute 'sampler'".format(
        type(train_loader))
    assert hasattr(train_sampler, "set_epoch") and callable(
        train_sampler.set_epoch
    ), "Train sampler should have a callable method `set_epoch`"

    train_eval_loader = config.train_eval_loader
    val_loader = config.val_loader

    model = config.model.to(device)
    optimizer = config.optimizer
    model, optimizer = amp.initialize(model,
                                      optimizer,
                                      opt_level=getattr(
                                          config, "fp16_opt_level", "O2"),
                                      num_losses=1)
    model = DDP(model, delay_allreduce=True)
    criterion = config.criterion.to(device)

    prepare_batch = getattr(config, "prepare_batch", _prepare_batch)
    non_blocking = getattr(config, "non_blocking", True)

    # Setup trainer
    accumulation_steps = getattr(config, "accumulation_steps", 1)
    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    def train_update_function(engine, batch):

        model.train()

        x, y = prepare_batch(batch, device=device, non_blocking=non_blocking)
        y_pred = model(x)
        y_pred = model_output_transform(y_pred)
        loss = criterion(y_pred, y) / accumulation_steps

        with amp.scale_loss(loss, optimizer, loss_id=0) as scaled_loss:
            scaled_loss.backward()

        if engine.state.iteration % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

        return {
            "supervised batch loss": loss.item(),
        }

    trainer = Engine(train_update_function)

    lr_scheduler = config.lr_scheduler
    to_save = {
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
        "trainer": trainer
    }
    common.setup_common_training_handlers(
        trainer,
        train_sampler,
        to_save=to_save,
        save_every_iters=1000,
        output_path=config.output_path.as_posix(),
        lr_scheduler=lr_scheduler,
        with_gpu_stats=True,
        output_names=[
            "supervised batch loss",
        ],
        with_pbars=True,
        with_pbar_on_iters=with_mlflow_logging,
        log_every_iters=1,
    )

    if getattr(config, "benchmark_dataflow", False):
        benchmark_dataflow_num_iters = getattr(config,
                                               "benchmark_dataflow_num_iters",
                                               1000)
        DataflowBenchmark(benchmark_dataflow_num_iters,
                          prepare_batch=prepare_batch,
                          device=device).attach(trainer, train_loader)

    # Setup evaluators
    val_metrics = {
        "Accuracy": Accuracy(device=device),
        "Top-5 Accuracy": TopKCategoricalAccuracy(k=5, device=device),
    }

    if hasattr(config, "val_metrics") and isinstance(config.val_metrics, dict):
        val_metrics.update(config.val_metrics)

    model_output_transform = getattr(config, "model_output_transform",
                                     lambda x: x)

    evaluator_args = dict(
        model=model,
        metrics=val_metrics,
        device=device,
        non_blocking=non_blocking,
        prepare_batch=prepare_batch,
        output_transform=lambda x, y, y_pred: (
            model_output_transform(y_pred),
            y,
        ),
    )
    train_evaluator = create_supervised_evaluator(**evaluator_args)
    evaluator = create_supervised_evaluator(**evaluator_args)

    if dist.get_rank() == 0 and with_mlflow_logging:
        ProgressBar(persist=False,
                    desc="Train Evaluation").attach(train_evaluator)
        ProgressBar(persist=False, desc="Val Evaluation").attach(evaluator)

    def run_validation(_):
        train_evaluator.run(train_eval_loader)
        evaluator.run(val_loader)

    if getattr(config, "start_by_validation", False):
        trainer.add_event_handler(Events.STARTED, run_validation)
    trainer.add_event_handler(
        Events.EPOCH_COMPLETED(every=getattr(config, "val_interval", 1)),
        run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    score_metric_name = "Accuracy"

    if hasattr(config, "es_patience"):
        common.add_early_stopping_by_val_score(config.es_patience,
                                               evaluator,
                                               trainer,
                                               metric_name=score_metric_name)

    if dist.get_rank() == 0:

        tb_logger = common.setup_tb_logging(
            config.output_path.as_posix(),
            trainer,
            optimizer,
            evaluators={
                "training": train_evaluator,
                "validation": evaluator
            },
        )
        if with_mlflow_logging:
            common.setup_mlflow_logging(trainer,
                                        optimizer,
                                        evaluators={
                                            "training": train_evaluator,
                                            "validation": evaluator
                                        })

        if with_plx_logging:
            common.setup_plx_logging(trainer,
                                     optimizer,
                                     evaluators={
                                         "training": train_evaluator,
                                         "validation": evaluator
                                     })

        common.save_best_model_by_val_score(config.output_path.as_posix(),
                                            evaluator,
                                            model,
                                            metric_name=score_metric_name,
                                            trainer=trainer)

        # Log train/val predictions:
        tb_logger.attach(
            evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="validation"),
            event_name=Events.ITERATION_COMPLETED(once=len(val_loader) // 2),
        )

        tb_logger.attach(
            train_evaluator,
            log_handler=predictions_gt_images_handler(
                img_denormalize_fn=config.img_denormalize,
                n_images=15,
                another_engine=trainer,
                prefix_tag="training"),
            event_name=Events.ITERATION_COMPLETED(
                once=len(train_eval_loader) // 2),
        )

    trainer.run(train_loader, max_epochs=config.num_epochs)
Example #15
    def setup(self):
        self._init_distribution()

        self.trainer = Engine(self.train_step)
        self.trainer.logger = setup_logger(name="trainer",
                                           distributed_rank=self.local_rank)
        self.log_basic_info(self.trainer.logger)

        self.load_trainer_from_checkpoint()

        if self.scheduler:
            self.scheduler_event = self.trainer.add_event_handler(
                Events.ITERATION_STARTED, self.scheduler)
        else:
            self.scheduler_event = None
        self.attach_metrics(self.trainer, self.train_metrics)

        if idist.get_world_size() > 1:

            def set_epoch(engine):
                self.train_loader.sampler.set_epoch(engine.state.epoch)

            self.trainer.add_event_handler(Events.EPOCH_STARTED, set_epoch)

        common.setup_common_training_handlers(
            self.trainer,
            train_sampler=self.train_loader.sampler,
            to_save=None,
            save_every_iters=0,
            output_path=None,
            lr_scheduler=None,
            output_names=None,
            with_pbars=self.hparams.add_pbar,
            clear_cuda_cache=True,
            stop_on_nan=False)

        self.evaluator = Engine(self.eval_step)
        self.evaluator.logger = setup_logger("evaluator",
                                             distributed_rank=self.local_rank)
        if self.hparams.add_pbar:
            ProgressBar(persist=False).attach(self.evaluator)

        def complete_clear(engine):
            engine.state.batch = None
            engine.state.output = None
            import gc
            gc.collect()

        self.trainer.add_event_handler(Events.EPOCH_COMPLETED, complete_clear)

        self.validation_handler_event = self.trainer.add_event_handler(
            Events.EPOCH_COMPLETED(every=self.hparams.eval_every),
            self.validate(self.valid_loader))
        self.evaluator.add_event_handler(Events.EPOCH_COMPLETED,
                                         complete_clear)

        train_handler_params = {
            "model": self.model,
            "optimizer": self.optimizer,
            "scheduler": self.scheduler
        }

        eval_handler_params = {
            "model": self.model,
            "optimizer": self.optimizer,
            "scheduler": self.scheduler
        }

        to_save = {
            "model": self.model,
            "trainer": self.trainer,
            "optimizer": self.optimizer
        }
        if self.scheduler is not None:
            to_save["scheduler"] = self.scheduler
        if USE_AMP:
            to_save["amp"] = amp
        self.attach_metrics(self.evaluator, self.validation_metrics)
        self.setup_checkpoint_saver(to_save)

        if self.rank == 0:
            self._init_logger()
            if self.logger:
                self.logger._init_logger(self.trainer, self.evaluator)
                self.logger._add_train_events(**train_handler_params)
                self.logger._add_eval_events(**eval_handler_params)
Example #16
def finetune_model(model_class,
                   project_path,
                   batch_size,
                   num_workers=0,
                   pin_memory=True,
                   non_blocking=True,
                   device=None,
                   base_lr=1e-4,
                   max_lr=1e-3,
                   lr_gamma=0.9,
                   lr_decay_iters=None,
                   weight_decay=0.0,
                   loss_func=None,
                   n_epochs=1,
                   patience=-1,
                   data_augmentation=True,
                   combination_module=simple_concatenation,
                   combination_size=KinshipClassifier.FACENET_OUT_SIZE * 2,
                   simple_fc_layers=None,
                   custom_fc_layers=None,
                   final_fc_layers=None,
                   train_ds_name=None,
                   dev_ds_name=None,
                   logging_rate=-1,
                   saving_rate=-1,
                   experiment_name=None,
                   checkpoint_name=None,
                   hof_size=1,
                   checkpoint_exp=None):
    if device is None:
        device = torch.device('cpu')

    if loss_func is None:
        loss_func = torch.nn.CrossEntropyLoss()

    if simple_fc_layers is None:
        simple_fc_layers = [1024]

    if custom_fc_layers is None:
        custom_fc_layers = [1024]

    if final_fc_layers is None:
        final_fc_layers = []

    if train_ds_name is None:
        train_ds_name = 'train_dataset.pkl'

    if dev_ds_name is None:
        dev_ds_name = 'dev_dataset.pkl'

    if checkpoint_exp is None:
        checkpoint_exp = experiment_name

    model = model_class(combination_module, combination_size, simple_fc_layers,
                        custom_fc_layers, final_fc_layers)

    data_path = os.path.join(project_path, 'data')
    processed_path = os.path.join(data_path, 'processed')

    dataset_names = {'train': train_ds_name, 'dev': dev_ds_name}
    dataset_paths = {
        partition: os.path.join(data_path, dataset_names[partition])
        for partition in dataset_names
    }
    raw_paths = {
        partition: os.path.join(processed_path, partition)
        for partition in dataset_paths
    }

    relationships_path = os.path.join(data_path, 'raw',
                                      'train_relationships.csv')
    datasets = {
        partition: KinshipDataset.get_dataset(
            dataset_paths[partition], raw_paths[partition], relationships_path,
            data_augmentation and (partition == 'train'))
        for partition in raw_paths
    }

    dataloaders = {
        partition: DataLoader(datasets[partition],
                              batch_size=batch_size,
                              shuffle=(partition == 'train'),
                              num_workers=num_workers,
                              pin_memory=pin_memory)
        for partition in datasets
    }

    params_to_train = list(
        filter(lambda x: x.requires_grad, model.parameters()))
    optimizer = optim.AdamW(params_to_train,
                            lr=base_lr,
                            weight_decay=weight_decay)
    lr_decay_iters = len(
        dataloaders['train']) if lr_decay_iters is None else lr_decay_iters
    lr_scheduler = optim.lr_scheduler.CyclicLR(optimizer,
                                               base_lr=base_lr,
                                               max_lr=max_lr,
                                               step_size_up=lr_decay_iters //
                                               2,
                                               mode='exp_range',
                                               gamma=lr_gamma,
                                               cycle_momentum=False)

    train_engine = create_supervised_trainer(model,
                                             optimizer,
                                             loss_fn=loss_func,
                                             device=device,
                                             non_blocking=non_blocking)

    if checkpoint_exp is not None and checkpoint_name is not None:
        experiment_dir = os.path.join(project_path, 'experiments',
                                      checkpoint_exp)
        model, optimizer, loss_func, lr_scheduler, train_engine = load_checkpoint(
            model_class, experiment_dir, checkpoint_name, device)

    eval_engine = create_supervised_evaluator(
        model,
        metrics=dict(accuracy=Accuracy(), cross_entropy=Loss(loss_func)),
        device=device,
        non_blocking=non_blocking)

    metrics = {}

    if logging_rate > 0:
        metrics['ce_history'] = []
        metrics['smoothed_loss_history'] = []
        beta = 0.98
        avg_loss = 0.0

        @train_engine.on(Events.ITERATION_COMPLETED(every=logging_rate))
        def log_iteration_training_metrics(engine):
            nonlocal metrics
            metrics['ce_history'].append(engine.state.output)

        @train_engine.on(Events.ITERATION_COMPLETED(every=logging_rate))
        def log_smoothed_lr(engine: Engine):
            nonlocal avg_loss
            avg_loss = (avg_loss * beta) + (engine.state.output * (1 - beta))
            metrics['smoothed_loss_history'].append(
                avg_loss /
                (1 - (beta**(len(metrics['smoothed_loss_history']) + 1))))

        @train_engine.on(Events.EPOCH_COMPLETED)
        def plot_metrics(engine):
            plot_metric(metrics['smoothed_loss_history'],
                        f"Smoothed loss epoch #{engine.state.epoch}",
                        "Cross Entropy",
                        index_scale=logging_rate)

    if patience >= 0:
        common.add_early_stopping_by_val_score(patience, eval_engine,
                                               train_engine, 'accuracy')

    # Replaced by setup_common_training_handlers
    # nan_terminate = TerminateOnNan()
    # train_engine.add_event_handler(Events.ITERATION_COMPLETED, nan_terminate)

    @train_engine.on(Events.EPOCH_COMPLETED)
    def print_training_metrics(engine):
        print(f"Finished epoch {engine.state.epoch}")
        if train_ds_name == dev_ds_name:
            print(f"Epoch {engine.state.epoch}: CE = {engine.state.output}")
            metrics['final_dev_loss'] = engine.state.output
            return
        eval_engine.run(dataloaders['dev'])
        metrics['final_dev_loss'] = eval_engine.state.metrics['cross_entropy']
        print(
            f"Epoch {engine.state.epoch}: CE = {eval_engine.state.metrics['cross_entropy']}, "
            f"Acc = {eval_engine.state.metrics['accuracy']}")

    # Replaced by setup_common_training_handlers
    # @train_engine.on(Events.ITERATION_COMPLETED)
    # def change_lr(engine):
    #     lr_scheduler.step()

    to_save = None
    output_path = None

    if saving_rate > 0:
        if experiment_name is None:
            print("Warning: saving rate specified but experiment name is None")
            exit()

        experiment_path = os.path.join(project_path, 'experiments',
                                       experiment_name)
        if not os.path.isdir(experiment_path):
            os.mkdir(experiment_path)

        with open(os.path.join(experiment_path, 'model.config'),
                  'wb+') as config_file:
            pickle.dump(model.get_configuration(), config_file)

        to_save = {
            'model': model,
            'optimizer': optimizer,
            'loss_func': loss_func,
            'lr_scheduler': lr_scheduler,
            'train_engine': train_engine
        }
        output_path = experiment_path

        best_models_dir = os.path.join(output_path, 'best_models')
        if not os.path.isdir(best_models_dir):
            os.mkdir(best_models_dir)
        common.save_best_model_by_val_score(best_models_dir,
                                            eval_engine,
                                            model,
                                            'accuracy',
                                            n_saved=hof_size,
                                            trainer=train_engine,
                                            tag='acc')

        # Replaced by setup_common_training_handlers
        # checkpointer = ModelCheckpoint(experiment_path, 'iter', n_saved=50,
        #                                global_step_transform=lambda engine, _:
        #                                f"{engine.state.epoch}-{engine.state.iteration}", require_empty=False)
        # train_engine.add_event_handler(Events.ITERATION_COMPLETED(every=saving_rate), checkpointer, to_save)

    common.setup_common_training_handlers(train_engine,
                                          to_save=to_save,
                                          save_every_iters=saving_rate,
                                          output_path=output_path,
                                          lr_scheduler=lr_scheduler,
                                          with_pbars=True,
                                          with_pbar_on_iters=True,
                                          log_every_iters=1,
                                          device=device)

    # Replaced by setup_common_training_handlers
    # train_pbar = ProgressBar()
    # train_pbar.attach(train_engine)
    #
    eval_pbar = ProgressBar(persist=False, desc="Evaluation")
    eval_pbar.attach(eval_engine)

    print(model)
    print("Running on:", device)
    train_engine.run(dataloaders['train'], max_epochs=n_epochs)

    return model, metrics
Example #17
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()
        y_pred = model(x)

        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # This can be helpful for XLA to avoid performance slow down if fetch loss.item() every iteration
        if (config["log_every_iters"] > 0 and
            (engine.state.iteration - 1) % config["log_every_iters"] == 0):
            batch_loss = loss.item()
            engine.state.saved_batch_loss = batch_loss
        else:
            batch_loss = engine.state.saved_batch_loss

        return {"batch loss": batch_loss}

    trainer = Engine(train_step)
    trainer.state.saved_batch_loss = -1.0
    trainer.state_dict_user_keys.append("saved_batch_loss")
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = ["batch loss"]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(
        ), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Example #18
def _test_setup_common_training_handlers(
    dirname, device, rank=0, local_rank=0, distributed=False, lr_scheduler=None, save_handler=None
):

    lr = 0.01
    step_size = 100
    gamma = 0.5
    num_iters = 100
    num_epochs = 10

    model = DummyModel().to(device)
    if distributed and "cuda" in device:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[local_rank,], output_device=local_rank)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)

    if lr_scheduler is None:
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma)
    elif isinstance(lr_scheduler, str) and lr_scheduler == "ignite|LRScheduler":
        from ignite.contrib.handlers import LRScheduler

        lr_scheduler = LRScheduler(torch.optim.lr_scheduler.StepLR(optimizer, step_size=step_size, gamma=gamma))
    elif isinstance(lr_scheduler, str) and lr_scheduler == "ignite":
        from ignite.contrib.handlers import PiecewiseLinear

        milestones_values = [(0, 0.0), (step_size, lr), (num_iters * (num_epochs - 1), 0.0)]
        lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values)
    else:
        raise ValueError(f"Unknown lr_scheduler: {lr_scheduler}")

    def update_fn(engine, batch):
        optimizer.zero_grad()
        x = torch.tensor([batch], requires_grad=True, device=device)
        y_pred = model(x)
        loss = y_pred.mean()
        loss.backward()
        optimizer.step()
        return loss

    train_sampler = None
    if distributed and idist.get_world_size() > 1:
        train_sampler = MagicMock(spec=DistributedSampler)
        train_sampler.set_epoch = MagicMock()

    trainer = Engine(update_fn)
    setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save={"model": model, "optimizer": optimizer},
        save_every_iters=75,
        output_path=dirname,
        save_handler=save_handler,
        lr_scheduler=lr_scheduler,
        with_gpu_stats=False,
        output_names=["batch_loss",],
        with_pbars=True,
        with_pbar_on_iters=True,
        log_every_iters=50,
    )

    data = [i * 0.1 for i in range(num_iters)]
    trainer.run(data, max_epochs=num_epochs)

    # check handlers
    handlers = trainer._event_handlers[Events.ITERATION_COMPLETED]
    for cls in [
        TerminateOnNan,
    ]:
        assert any([isinstance(h[0], cls) for h in handlers]), f"{handlers}"
    assert "batch_loss" in trainer.state.metrics

    # Check saved checkpoint
    if rank == 0:
        if save_handler is not None:
            dirname = save_handler.dirname
        checkpoints = list(os.listdir(dirname))
        assert len(checkpoints) == 1
        for v in [
            "training_checkpoint",
        ]:
            assert any([v in c for c in checkpoints])

    # Check LR scheduling
    assert optimizer.param_groups[0]["lr"] <= lr * gamma ** (
        num_iters * num_epochs / step_size
    ), f"{optimizer.param_groups[0]['lr']} vs {lr * gamma ** (num_iters * num_epochs / step_size)}"
Exemple #19
0
def test_setup_common_training_handlers(dirname, capsys):

    lr = 0.01
    step_size = 100
    gamma = 0.5

    model = DummyModel()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                   step_size=step_size,
                                                   gamma=gamma)

    def update_fn(engine, batch):
        optimizer.zero_grad()
        x = torch.tensor([batch], requires_grad=True)
        y_pred = model(x)
        loss = y_pred.mean()
        loss.backward()
        optimizer.step()
        return loss

    trainer = Engine(update_fn)
    setup_common_training_handlers(trainer,
                                   to_save={
                                       "model": model,
                                       "optimizer": optimizer
                                   },
                                   save_every_iters=75,
                                   output_path=dirname,
                                   lr_scheduler=lr_scheduler,
                                   with_gpu_stats=False,
                                   output_names=[
                                       'batch_loss',
                                   ],
                                   with_pbars=True,
                                   with_pbar_on_iters=True,
                                   log_every_iters=50)

    num_iters = 100
    num_epochs = 10
    data = [i * 0.1 for i in range(num_iters)]
    trainer.run(data, max_epochs=num_epochs)

    # check handlers
    handlers = trainer._event_handlers[Events.ITERATION_COMPLETED]
    for cls in [
            TerminateOnNan,
    ]:
        assert any([isinstance(h[0], cls) for h in handlers]), \
            "{}".format(trainer._event_handlers[Events.ITERATION_COMPLETED])
    assert 'batch_loss' in trainer.state.metrics

    # Check epoch-wise pbar
    captured = capsys.readouterr()
    out = captured.err.split('\r')
    out = list(map(lambda x: x.strip(), out))
    out = list(filter(None, out))
    assert u"Epoch:" in out[-1], "{}".format(out[-1])

    # Check saved checkpoint
    checkpoints = list(os.listdir(dirname))
    assert len(checkpoints) == 1
    for v in [
            "training_checkpoint",
    ]:
        assert any([v in c for c in checkpoints])

    # Check LR scheduling
    assert optimizer.param_groups[0]['lr'] <= lr * gamma ** (num_iters * num_epochs / step_size), \
        "{} vs {}".format(optimizer.param_groups[0]['lr'], lr * gamma ** (num_iters * num_epochs / step_size))
Exemple #20
0
def run(output_path, config):

    distributed = dist.is_available() and dist.is_initialized()
    rank = dist.get_rank() if distributed else 0

    manual_seed(config["seed"] + rank)

    # Setup dataflow, model, optimizer, criterion
    train_loader, test_loader = utils.get_dataflow(config, distributed)
    model, optimizer = utils.get_model_optimizer(config, distributed)
    criterion = nn.CrossEntropyLoss().to(utils.device)

    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer, param_name="lr", milestones_values=milestones_values)

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations
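    #
    # All of the handlers listed above are wired by the single
    # common.setup_common_training_handlers(...) call further below; the
    # TensorBoard loggers and the evaluators are attached separately.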

    def train_step(engine, batch):

        x = convert_tensor(batch[0], device=utils.device, non_blocking=True)
        y = convert_tensor(batch[1], device=utils.device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }
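
    # DeterministicEngine provides a reproducible data flow, so a run can be
    # resumed from a checkpoint with the same batch order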

    if config["deterministic"] and rank == 0:
        print("Setup deterministic trainer")
    trainer = Engine(train_step) if not config["deterministic"] else DeterministicEngine(train_step)
    train_sampler = train_loader.sampler if distributed else None
    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        # Setup the TensorBoard logger - a thin wrapper around SummaryWriter
        tb_logger = TensorboardLogger(log_dir=output_path)
        # Attach the logger to the trainer and log the trainer's metrics (stored in trainer.state.metrics) every iteration
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        # Log the optimizer's "lr" parameter every iteration
        tb_logger.attach(
            trainer, log_handler=OptimizerParamsHandler(optimizer, param_name="lr"), event_name=Events.ITERATION_STARTED
        )

    # Let's now set up evaluator engines to perform the model's validation and compute metrics
    metrics = {
        "accuracy": Accuracy(device=utils.device if distributed else None),
        "loss": Loss(criterion, device=utils.device if distributed else None),
    }

    # We define two evaluators as they won't have exactly the same role:
    # - `evaluator` will save the best model based on validation score
    evaluator = create_supervised_evaluator(model, metrics=metrics, device=utils.device, non_blocking=True)
    train_evaluator = create_supervised_evaluator(model, metrics=metrics, device=utils.device, non_blocking=True)

    def run_validation(engine):
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(Events.EPOCH_STARTED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        # Setup progress bar on evaluation engines
        if config["display_iters"]:
            ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False, desc="Test evaluation").attach(evaluator)

        # Log metrics of `train_evaluator` stored in `train_evaluator.state.metrics` when the validation run is done
        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train", metric_names="all", global_step_transform=global_step_from_engine(trainer)
            ),
            event_name=Events.COMPLETED,
        )

        # Log metrics of `evaluator` stored in `evaluator.state.metrics` when the validation run is done
        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test", metric_names="all", global_step_transform=global_step_from_engine(trainer)
            ),
            event_name=Events.COMPLETED,
        )

        # Store 3 best models by validation accuracy:
        common.save_best_model_by_val_score(
            output_path, evaluator, model=model, metric_name="accuracy", n_saved=3, trainer=trainer, tag="test"
        )

        # Optionally log model gradients
        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model, tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(every=config["log_model_grads_every"]),
            )

    # To check that training resumes correctly, we can emulate a crash
    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(engine.state.iteration))

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
Exemple #21
0
def run(output_dir, config):
    device = torch.device("cuda" if args.use_cuda else "cpu")

    torch.manual_seed(config['seed'])
    np.random.seed(config['seed'])

    # Rescale batch_size and num_workers
    ngpus_per_node = 1
    batch_size = config['batch_size']
    num_workers = int(
        (config['num_workers'] + ngpus_per_node - 1) / ngpus_per_node)

    (train_loader, test_loader,
     mislabeled_train_loader) = get_train_test_loaders(
         path=config['data_path'],
         batch_size=batch_size,
         num_workers=num_workers,
         random_seed=config['seed'],
         random_labels_fraction=config['random_labels_fraction'],
     )

    model = get_mnist_model(args, device)

    optimizer = AdamFlexibleWeightDecay(
        model.parameters(),
        lr=config['init_lr'],
        weight_decay_order=config['weight_decay_order'],
        weight_decay=config['weight_decay'])

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_loader)
    lr_scheduler = MultiStepLR(optimizer,
                               milestones=[le * config['epochs'] * 3 // 4],
                               gamma=0.1)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(y, device=device, non_blocking=non_blocking))

    def process_function(unused_engine, batch):
        x, y = _prepare_batch(batch, device=device, non_blocking=True)
        model.train()
        optimizer.zero_grad()
        y_pred = model(x)

        if config['agreement_threshold'] > 0.0:
            # The "batch_size" in this function refers to the batch size per env
            # Since we treat every example as one env, we should set the parameter
            # n_agreement_envs equal to batch size
            mean_loss, masks = and_mask_utils.get_grads(
                agreement_threshold=config['agreement_threshold'],
                batch_size=1,
                loss_fn=criterion,
                n_agreement_envs=config['batch_size'],
                params=optimizer.param_groups[0]['params'],
                output=y_pred,
                target=y,
                method=args.method,
                scale_grad_inverse_sparsity=config[
                    'scale_grad_inverse_sparsity'],
            )
        else:
            mean_loss = criterion(y_pred, y)
            mean_loss.backward()

        optimizer.step()

        return {}

    trainer = Engine(process_function)
    metric_names = []
    common.setup_common_training_handlers(trainer,
                                          output_path=output_dir,
                                          lr_scheduler=lr_scheduler,
                                          output_names=metric_names,
                                          with_pbar_on_iters=True,
                                          log_every_iters=10)

    tb_logger = TensorboardLogger(log_dir=output_dir)
    tb_logger.attach(trainer,
                     log_handler=OutputHandler(tag="train",
                                               metric_names=metric_names),
                     event_name=Events.ITERATION_COMPLETED)

    metrics = {"accuracy": Accuracy(), "loss": Loss(criterion)}

    test_evaluator = create_supervised_evaluator(model,
                                                 metrics=metrics,
                                                 device=device,
                                                 non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)
    mislabeled_train_evaluator = create_supervised_evaluator(model,
                                                             metrics=metrics,
                                                             device=device,
                                                             non_blocking=True)

    def run_validation(engine):
        if args.use_cuda:
            torch.cuda.synchronize()
        train_evaluator.run(train_loader)
        if config['random_labels_fraction'] > 0.0:
            mislabeled_train_evaluator.run(mislabeled_train_loader)
        test_evaluator.run(test_loader)

    def flush_metrics(engine):
        tb_logger.writer.flush()

    trainer.add_event_handler(Events.EPOCH_STARTED(every=1), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)
    trainer.add_event_handler(Events.EPOCH_COMPLETED, flush_metrics)

    ProgressBar(persist=False, desc="Train evaluation").attach(train_evaluator)
    ProgressBar(persist=False, desc="Test evaluation").attach(test_evaluator)
    ProgressBar(persist=False,
                desc="Train (mislabeled portion) evaluation").attach(
                    mislabeled_train_evaluator)

    tb_logger.attach(
        train_evaluator,
        log_handler=OutputHandler(
            tag="train",
            metric_names=list(metrics.keys()),
            global_step_transform=global_step_from_engine(trainer)),
        event_name=Events.COMPLETED)
    tb_logger.attach(
        test_evaluator,
        log_handler=OutputHandler(
            tag="test",
            metric_names=list(metrics.keys()),
            global_step_transform=global_step_from_engine(trainer)),
        event_name=Events.COMPLETED)
    tb_logger.attach(
        mislabeled_train_evaluator,
        log_handler=OutputHandler(
            tag="train_wrong",
            metric_names=list(metrics.keys()),
            global_step_transform=global_step_from_engine(trainer)),
        event_name=Events.COMPLETED)

    trainer_rng = np.random.RandomState()
    trainer.run(train_loader,
                max_epochs=config['epochs'],
                seed=trainer_rng.randint(2**32))

    tb_logger.close()
Exemple #22
0
def create_supervised_trainer(model, optimizer, criterion, lr_scheduler,
                              train_sampler, config, logger):
    device = idist.device()

    def _update(engine, batch):

        model.train()

        (imgs, targets) = batch
        imgs = imgs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        outputs = model(imgs)
        loss = criterion(outputs, targets)

        # Compute gradients and update the parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Top-1 / top-5 accuracies (not used further in this snippet)
        acc1, acc5 = utils.accuracy(outputs, targets, topk=(1, 5))

        # Fetching loss.item() only every log_every_iters iterations avoids a
        # performance slow down on XLA devices
        if config["log_every_iters"] > 0 and (
                engine.state.iteration - 1) % config["log_every_iters"] == 0:
            batch_loss = loss.item()
            engine.state.saved_batch_loss = batch_loss
        else:
            batch_loss = engine.state.saved_batch_loss

        return {
            "batch loss": batch_loss,
        }

    trainer = Engine(_update)
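    # "saved_batch_loss" lives on the engine state and is registered as a user
    # state-dict key below, so it is saved in and restored from checkpoints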
    trainer.state.saved_batch_loss = -1.0
    trainer.state_dict_user_keys.append("saved_batch_loss")
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        logger.info("Resume from a checkpoint: {}".format(
            checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Exemple #23
0
def run(output_path, config):
    device = "cuda"

    local_rank = config["local_rank"]
    distributed = backend is not None
    if distributed:
        torch.cuda.set_device(local_rank)
        device = "cuda"
    rank = dist.get_rank() if distributed else 0

    torch.manual_seed(config["seed"] + rank)

    # Rescale batch_size and num_workers
    ngpus_per_node = torch.cuda.device_count()
    ngpus = dist.get_world_size() if distributed else 1
    batch_size = config["batch_size"] // ngpus
    num_workers = int(
        (config["num_workers"] + ngpus_per_node - 1) / ngpus_per_node)

    train_loader, test_loader = get_train_test_loaders(
        path=config["data_path"],
        batch_size=batch_size,
        distributed=distributed,
        num_workers=num_workers,
    )

    model = get_model(config["model"])
    model = model.to(device)

    if distributed:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[
                local_rank,
            ],
            output_device=local_rank,
        )

    optimizer = optim.SGD(
        model.parameters(),
        lr=config["learning_rate"],
        momentum=config["momentum"],
        weight_decay=config["weight_decay"],
        nesterov=True,
    )

    criterion = nn.CrossEntropyLoss().to(device)

    le = len(train_loader)
    milestones_values = [
        (0, 0.0),
        (le * config["num_warmup_epochs"], config["learning_rate"]),
        (le * config["num_epochs"], 0.0),
    ]
    lr_scheduler = PiecewiseLinear(optimizer,
                                   param_name="lr",
                                   milestones_values=milestones_values)

    def _prepare_batch(batch, device, non_blocking):
        x, y = batch
        return (
            convert_tensor(x, device=device, non_blocking=non_blocking),
            convert_tensor(y, device=device, non_blocking=non_blocking),
        )

    def process_function(engine, batch):

        x, y = _prepare_batch(batch, device=device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(process_function)
    train_sampler = train_loader.sampler if distributed else None
    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler,
    }
    metric_names = [
        "batch loss",
    ]
    common.setup_common_training_handlers(
        trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        output_path=output_path,
        lr_scheduler=lr_scheduler,
        output_names=metric_names,
        with_pbar_on_iters=config["display_iters"],
        log_every_iters=10,
    )

    if rank == 0:
        tb_logger = TensorboardLogger(log_dir=output_path)
        tb_logger.attach(
            trainer,
            log_handler=OutputHandler(tag="train", metric_names=metric_names),
            event_name=Events.ITERATION_COMPLETED,
        )
        tb_logger.attach(
            trainer,
            log_handler=OptimizerParamsHandler(optimizer, param_name="lr"),
            event_name=Events.ITERATION_STARTED,
        )

    metrics = {
        "accuracy": Accuracy(device=device if distributed else None),
        "loss": Loss(criterion, device=device if distributed else None),
    }

    evaluator = create_supervised_evaluator(model,
                                            metrics=metrics,
                                            device=device,
                                            non_blocking=True)
    train_evaluator = create_supervised_evaluator(model,
                                                  metrics=metrics,
                                                  device=device,
                                                  non_blocking=True)

    def run_validation(engine):
        torch.cuda.synchronize()
        train_evaluator.run(train_loader)
        evaluator.run(test_loader)

    trainer.add_event_handler(
        Events.EPOCH_STARTED(every=config["validate_every"]), run_validation)
    trainer.add_event_handler(Events.COMPLETED, run_validation)

    if rank == 0:
        if config["display_iters"]:
            ProgressBar(persist=False,
                        desc="Train evaluation").attach(train_evaluator)
            ProgressBar(persist=False,
                        desc="Test evaluation").attach(evaluator)

        tb_logger.attach(
            train_evaluator,
            log_handler=OutputHandler(
                tag="train",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )

        tb_logger.attach(
            evaluator,
            log_handler=OutputHandler(
                tag="test",
                metric_names=list(metrics.keys()),
                global_step_transform=global_step_from_engine(trainer),
            ),
            event_name=Events.COMPLETED,
        )

        # Store the best model by validation accuracy:
        common.save_best_model_by_val_score(
            output_path,
            evaluator,
            model=model,
            metric_name="accuracy",
            n_saved=3,
            trainer=trainer,
            tag="test",
        )

        if config["log_model_grads_every"] is not None:
            tb_logger.attach(
                trainer,
                log_handler=GradsHistHandler(model,
                                             tag=model.__class__.__name__),
                event_name=Events.ITERATION_COMPLETED(
                    every=config["log_model_grads_every"]),
            )

    if config["crash_iteration"] is not None:

        @trainer.on(Events.ITERATION_STARTED(once=config["crash_iteration"]))
        def _(engine):
            raise Exception("STOP at iteration: {}".format(
                engine.state.iteration))

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        print("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix())
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    try:
        trainer.run(train_loader, max_epochs=config["num_epochs"])
    except Exception:
        import traceback

        print(traceback.format_exc())

    if rank == 0:
        tb_logger.close()
Exemple #24
0
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler, config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    with_amp = config["with_amp"]
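    # GradScaler and autocast are assumed to be imported from torch.cuda.amp
    # in the (not shown) header of the original module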
    scaler = GradScaler(enabled=with_amp)

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()

        with autocast(enabled=with_amp):
            y_pred = model(x)
            loss = criterion(y_pred, y)

        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        return {
            "batch loss": loss.item(),
        }

    trainer = Engine(train_step)
    trainer.logger = logger

    to_save = {"trainer": trainer, "model": model, "optimizer": optimizer, "lr_scheduler": lr_scheduler}
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), f"Checkpoint '{checkpoint_fp.as_posix()}' is not found"
        logger.info(f"Resume from a checkpoint: {checkpoint_fp.as_posix()}")
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Exemple #25
0
def create_trainer(model, optimizer, criterion, lr_scheduler, train_sampler,
                   config, logger):

    device = idist.device()

    # Setup Ignite trainer:
    # - let's define training step
    # - add other common handlers:
    #    - TerminateOnNan,
    #    - handler to setup learning rate scheduling,
    #    - ModelCheckpoint
    #    - `RunningAverage` on `train_step` output
    #    - Two progress bars on epochs and optionally on iterations

    def train_step(engine, batch):

        x, y = batch[0], batch[1]

        if x.device != device:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

        model.train()
        # Supervised part
        y_pred = model(x)
        loss = criterion(y_pred, y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Fetching loss.item() only every log_every_iters iterations avoids a performance slow down on XLA devices
        if config["log_every_iters"] > 0 and (
                engine.state.iteration - 1) % config["log_every_iters"] == 0:
            batch_loss = loss.item()
            engine.state.saved_batch_loss = batch_loss
        else:
            batch_loss = engine.state.saved_batch_loss

        return {
            "batch loss": batch_loss,
        }

    trainer = Engine(train_step)
    trainer.state.saved_batch_loss = -1.0
    trainer.state_dict_user_keys.append("saved_batch_loss")
    trainer.logger = logger

    to_save = {
        "trainer": trainer,
        "model": model,
        "optimizer": optimizer,
        "lr_scheduler": lr_scheduler
    }
    metric_names = [
        "batch loss",
    ]

    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        save_every_iters=config["checkpoint_every"],
        save_handler=get_save_handler(config),
        lr_scheduler=lr_scheduler,
        output_names=metric_names if config["log_every_iters"] > 0 else None,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    resume_from = config["resume_from"]
    if resume_from is not None:
        checkpoint_fp = Path(resume_from)
        assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
            checkpoint_fp.as_posix())
        logger.info("Resume from a checkpoint: {}".format(
            checkpoint_fp.as_posix()))
        checkpoint = torch.load(checkpoint_fp.as_posix(), map_location="cpu")
        Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    return trainer
Exemple #26
0
def create_trainer(
    train_step,
    output_names,
    model,
    ema_model,
    optimizer,
    lr_scheduler,
    supervised_train_loader,
    test_loader,
    cfg,
    logger,
    cta=None,
    unsup_train_loader=None,
    cta_probe_loader=None,
):

    trainer = Engine(train_step)
    trainer.logger = logger

    output_path = os.getcwd()

    to_save = {
        "model": model,
        "ema_model": ema_model,
        "optimizer": optimizer,
        "trainer": trainer,
        "lr_scheduler": lr_scheduler,
    }
    if cta is not None:
        to_save["cta"] = cta

    common.setup_common_training_handlers(
        trainer,
        train_sampler=supervised_train_loader.sampler,
        to_save=to_save,
        save_every_iters=cfg.solver.checkpoint_every,
        output_path=output_path,
        output_names=output_names,
        lr_scheduler=lr_scheduler,
        with_pbars=False,
        clear_cuda_cache=False,
    )

    ProgressBar(persist=False).attach(
        trainer, metric_names="all", event_name=Events.ITERATION_COMPLETED
    )

    unsupervised_train_loader_iter = None
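    # `cycle` is assumed to be an endless-iteration helper over a data loader
    # (e.g. itertools.cycle or a reshuffling equivalent defined elsewhere)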
    if unsup_train_loader is not None:
        unsupervised_train_loader_iter = cycle(unsup_train_loader)

    cta_probe_loader_iter = None
    if cta_probe_loader is not None:
        cta_probe_loader_iter = cycle(cta_probe_loader)

    # Setup handler to prepare data batches
    @trainer.on(Events.ITERATION_STARTED)
    def prepare_batch(e):
        sup_batch = e.state.batch
        e.state.batch = {
            "sup_batch": sup_batch,
        }
        if unsupervised_train_loader_iter is not None:
            unsup_batch = next(unsupervised_train_loader_iter)
            e.state.batch["unsup_batch"] = unsup_batch

        if cta_probe_loader_iter is not None:
            cta_probe_batch = next(cta_probe_loader_iter)
            cta_probe_batch["policy"] = [
                deserialize(p) for p in cta_probe_batch["policy"]
            ]
            e.state.batch["cta_probe_batch"] = cta_probe_batch

    # Setup handler to update EMA model
    @trainer.on(Events.ITERATION_COMPLETED, cfg.ema_decay)
    def update_ema_model(ema_decay):
        # Exponential moving average of the model parameters:
        #   ema_param = ema_decay * ema_param + (1 - ema_decay) * param
        for ema_param, param in zip(ema_model.parameters(), model.parameters()):
            ema_param.data.mul_(ema_decay).add_(param.data, alpha=1.0 - ema_decay)

    # Setup handlers for debugging
    if cfg.debug:

        @trainer.on(Events.STARTED | Events.ITERATION_COMPLETED(every=100))
        @idist.one_rank_only()
        def log_weights_norms():
            wn = []
            ema_wn = []
            for ema_param, param in zip(ema_model.parameters(), model.parameters()):
                wn.append(torch.mean(param.data))
                ema_wn.append(torch.mean(ema_param.data))

            msg = "\n\nWeights norms"
            msg += "\n- Raw model: {}".format(
                to_list_str(torch.tensor(wn[:10] + wn[-10:]))
            )
            msg += "\n- EMA model: {}\n".format(
                to_list_str(torch.tensor(ema_wn[:10] + ema_wn[-10:]))
            )
            logger.info(msg)

            rmn = []
            rvar = []
            ema_rmn = []
            ema_rvar = []
            for m1, m2 in zip(model.modules(), ema_model.modules()):
                if isinstance(m1, nn.BatchNorm2d) and isinstance(m2, nn.BatchNorm2d):
                    rmn.append(torch.mean(m1.running_mean))
                    rvar.append(torch.mean(m1.running_var))
                    ema_rmn.append(torch.mean(m2.running_mean))
                    ema_rvar.append(torch.mean(m2.running_var))

            msg = "\n\nBN buffers"
            msg += "\n- Raw mean: {}".format(to_list_str(torch.tensor(rmn[:10])))
            msg += "\n- Raw var: {}".format(to_list_str(torch.tensor(rvar[:10])))
            msg += "\n- EMA mean: {}".format(to_list_str(torch.tensor(ema_rmn[:10])))
            msg += "\n- EMA var: {}\n".format(to_list_str(torch.tensor(ema_rvar[:10])))
            logger.info(msg)

        # TODO: Need to inspect a bug
        # if idist.get_rank() == 0:
        #     from ignite.contrib.handlers import ProgressBar
        #
        #     profiler = BasicTimeProfiler()
        #     profiler.attach(trainer)
        #
        #     @trainer.on(Events.ITERATION_COMPLETED(every=200))
        #     def log_profiling(_):
        #         results = profiler.get_results()
        #         profiler.print_results(results)

    # Setup validation engine
    metrics = {
        "accuracy": Accuracy(),
    }

    if not (idist.has_xla_support and idist.backend() == idist.xla.XLA_TPU):
        metrics.update({
            "precision": Precision(average=False),
            "recall": Recall(average=False),
        })

    eval_kwargs = dict(
        metrics=metrics,
        prepare_batch=sup_prepare_batch,
        device=idist.device(),
        non_blocking=True,
    )

    evaluator = create_supervised_evaluator(model, **eval_kwargs)
    ema_evaluator = create_supervised_evaluator(ema_model, **eval_kwargs)

    def log_results(epoch, max_epochs, metrics, ema_metrics):
        msg1 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in metrics.items()]
        )
        msg2 = "\n".join(
            ["\t{:16s}: {}".format(k, to_list_str(v)) for k, v in ema_metrics.items()]
        )
        logger.info(
            "\nEpoch {}/{}\nRaw:\n{}\nEMA:\n{}\n".format(epoch, max_epochs, msg1, msg2)
        )
        if cta is not None:
            logger.info("\n" + stats(cta))

    @trainer.on(
        Events.EPOCH_COMPLETED(every=cfg.solver.validate_every)
        | Events.STARTED
        | Events.COMPLETED
    )
    def run_evaluation():
        evaluator.run(test_loader)
        ema_evaluator.run(test_loader)
        log_results(
            trainer.state.epoch,
            trainer.state.max_epochs,
            evaluator.state.metrics,
            ema_evaluator.state.metrics,
        )

    # setup TB logging
    if idist.get_rank() == 0:
        tb_logger = common.setup_tb_logging(
            output_path,
            trainer,
            optimizers=optimizer,
            evaluators={"validation": evaluator, "ema validation": ema_evaluator},
            log_every_iters=15,
        )
        if cfg.online_exp_tracking.wandb:
            from ignite.contrib.handlers import WandBLogger

            wb_dir = Path("/tmp/output-fixmatch-wandb")
            if not wb_dir.exists():
                wb_dir.mkdir()

            _ = WandBLogger(
                project="fixmatch-pytorch",
                name=cfg.name,
                config=cfg,
                sync_tensorboard=True,
                dir=wb_dir.as_posix(),
                reinit=True,
            )

    resume_from = cfg.solver.resume_from
    if resume_from is not None:
        resume_from = list(Path(resume_from).rglob("training_checkpoint*.pt*"))
        if len(resume_from) > 0:
            # get latest
            checkpoint_fp = max(resume_from, key=lambda p: p.stat().st_mtime)
            assert checkpoint_fp.exists(), "Checkpoint '{}' is not found".format(
                checkpoint_fp.as_posix()
            )
            logger.info("Resume from a checkpoint: {}".format(checkpoint_fp.as_posix()))
            checkpoint = torch.load(checkpoint_fp.as_posix())
            Checkpoint.load_objects(to_load=to_save, checkpoint=checkpoint)

    @trainer.on(Events.COMPLETED)
    def release_all_resources():
        nonlocal unsupervised_train_loader_iter, cta_probe_loader_iter

        if idist.get_rank() == 0:
            tb_logger.close()

        if unsupervised_train_loader_iter is not None:
            unsupervised_train_loader_iter = None

        if cta_probe_loader_iter is not None:
            cta_probe_loader_iter = None

    return trainer
def get_handlers(
    config: Any,
    model: Module,
    trainer: Engine,
    evaluator: Engine,
    metric_name: str,
    es_metric_name: str,
    train_sampler: Optional[DistributedSampler] = None,
    to_save: Optional[Mapping] = None,
    lr_scheduler: Optional[LRScheduler] = None,
    output_names: Optional[Iterable[str]] = None,
    **kwargs: Any,
) -> Union[Tuple[Checkpoint, EarlyStopping, Timer], Tuple[None, None, None]]:
    """Get best model, earlystopping, timer handlers.

    Parameters
    ----------
    config
        Config object for setting up handlers

    `config` has to contain
    - `output_dir`: output path to indicate where to_save objects are stored
    - `save_every_iters`: saving iteration interval
    - `n_saved`: number of best models to store
    - `log_every_iters`: logging interval for iteration progress bar and `GpuInfo` if true
    - `with_pbars`: show two progress bars
    - `with_pbar_on_iters`: show iteration-wise progress bar
    - `stop_on_nan`: Stop the training if engine output contains NaN/inf values
    - `clear_cuda_cache`: clear cuda cache every end of epoch
    - `with_gpu_stats`: log GPU information: used memory and GPU utilization percentages
    - `patience`: number of events to wait without improvement before stopping the training
    - `limit_sec`: maximum training time in seconds before the training is terminated

    model
        best model to save
    trainer
        the engine used for training
    evaluator
        the engine used for evaluation
    metric_name
        evaluation metric to save the best model
    es_metric_name
        evaluation metric to early stop the model
    train_sampler
        distributed training sampler to call `set_epoch`
    to_save
        objects to save during training
    lr_scheduler
        learning rate scheduler as native torch LRScheduler or ignite’s parameter scheduler
    output_names
        list of names associated with `trainer`'s process_function output dictionary
    kwargs
        keyword arguments passed to Checkpoint handler

    Returns
    -------
    best_model_handler, es_handler, timer_handler
    """

    best_model_handler, es_handler, timer_handler = None, None, None

    # https://pytorch.org/ignite/contrib/engines.html#ignite.contrib.engines.common.setup_common_training_handlers
    # kwargs can be passed to save the model based on training stats
    # like score_name, score_function
    common.setup_common_training_handlers(
        trainer=trainer,
        train_sampler=train_sampler,
        to_save=to_save,
        lr_scheduler=lr_scheduler,
        output_names=output_names,
        output_path=config.output_dir / 'checkpoints',
        save_every_iters=config.save_every_iters,
        n_saved=config.n_saved,
        log_every_iters=config.log_every_iters,
        with_pbars=config.with_pbars,
        with_pbar_on_iters=config.with_pbar_on_iters,
        stop_on_nan=config.stop_on_nan,
        clear_cuda_cache=config.clear_cuda_cache,
        with_gpu_stats=config.with_gpu_stats,
        **kwargs,
    )
    {% if save_best_model_by_val_score %}

    # https://pytorch.org/ignite/contrib/engines.html#ignite.contrib.engines.common.save_best_model_by_val_score
    best_model_handler = common.save_best_model_by_val_score(
        output_path=config.output_dir / 'checkpoints',
        evaluator=evaluator,
        model=model,
        metric_name=metric_name,
        n_saved=config.n_saved,
        trainer=trainer,
        tag='eval',
    )
    {% endif %}
    {% if add_early_stopping_by_val_score %}

    # https://pytorch.org/ignite/contrib/engines.html#ignite.contrib.engines.common.add_early_stopping_by_val_score
    es_handler = common.add_early_stopping_by_val_score(
        patience=config.patience,
        evaluator=evaluator,
        trainer=trainer,
        metric_name=es_metric_name,
    )
    {% endif %}
    {% if setup_timer %}

    # https://pytorch.org/ignite/handlers.html#ignite.handlers.Timer
    # measure the average time to process a single batch of samples
    # The events used for that are ITERATION_STARTED and ITERATION_COMPLETED;
    # replace them with the events you want to measure
    timer_handler = Timer(average=True)
    timer_handler.attach(
        engine=trainer,
        start=Events.EPOCH_STARTED,
        resume=Events.ITERATION_STARTED,
        pause=Events.ITERATION_COMPLETED,
        step=Events.ITERATION_COMPLETED,
    )
    {% endif %}
    {% if setup_timelimit %}

    # Training will terminate if the training time exceeds `limit_sec`.
    trainer.add_event_handler(
        Events.ITERATION_COMPLETED, TimeLimit(limit_sec=config.limit_sec)
    )
    {% endif %}
    return best_model_handler, es_handler, timer_handler
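
A hedged usage sketch of `get_handlers` (assuming the template above has been rendered with the three optional blocks enabled; the dummy config values, model, engines and the "accuracy" metric name are illustrative only):

from pathlib import Path
from types import SimpleNamespace

import torch.nn as nn
from ignite.engine import Engine

# Illustrative config namespace carrying the attributes listed in the docstring above
config = SimpleNamespace(
    output_dir=Path("./output"),
    save_every_iters=1000,
    n_saved=2,
    log_every_iters=100,
    with_pbars=False,
    with_pbar_on_iters=False,
    stop_on_nan=True,
    clear_cuda_cache=False,
    with_gpu_stats=False,
    patience=5,
    limit_sec=3600,
)

# Dummy objects, only to show the call signature
model = nn.Linear(4, 2)
trainer = Engine(lambda engine, batch: {"train loss": 0.0})
evaluator = Engine(lambda engine, batch: None)

best_model_handler, es_handler, timer_handler = get_handlers(
    config,
    model=model,
    trainer=trainer,
    evaluator=evaluator,
    metric_name="accuracy",     # read from evaluator.state.metrics by the checkpoint handler
    es_metric_name="accuracy",  # read from evaluator.state.metrics by the early-stopping handler
    to_save={"model": model, "trainer": trainer},
    output_names=["train loss"],
)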