def run(conf: DictConfig, local_rank=0, distributed=False):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.general.seed)

    if distributed:
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        num_replicas = 1
        torch.cuda.set_device(conf.general.gpu)
    device = torch.device('cuda')
    loader_args = dict()
    master_node = rank == 0

    if master_node:
        print(conf.pretty())
    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args = dict(rank=rank, num_replicas=num_replicas)

    train_dl = create_train_loader(conf.data, **loader_args)

    if epoch_length < 1:
        epoch_length = len(train_dl)

    metric_names = list(conf.logging.stats)
    metrics = create_metrics(metric_names, device if distributed else None)

    G = instantiate(conf.model.G).to(device)
    D = instantiate(conf.model.D).to(device)
    G_loss = instantiate(conf.loss.G).to(device)
    D_loss = instantiate(conf.loss.D).to(device)
    G_opt = instantiate(conf.optim.G, G.parameters())
    D_opt = instantiate(conf.optim.D, D.parameters())
    G_ema = None

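    # Keep a frozen copy of the generator on the master node for weight smoothing (EMA);
    # it is initialised from G's weights and handed to the training closures as G_ema.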
    if master_node and conf.G_smoothing.enabled:
        G_ema = instantiate(conf.model.G)
        if not conf.G_smoothing.use_cpu:
            G_ema = G_ema.to(device)
        G_ema.load_state_dict(G.state_dict())
        G_ema.requires_grad_(False)

    to_save = {
        'G': G,
        'D': D,
        'G_loss': G_loss,
        'D_loss': D_loss,
        'G_opt': G_opt,
        'D_opt': D_opt,
        'G_ema': G_ema
    }

    if master_node and conf.logging.model:
        logging.info(G)
        logging.info(D)

    if distributed:
        ddp_kwargs = dict(device_ids=[local_rank], output_device=local_rank)
        G = torch.nn.parallel.DistributedDataParallel(G, **ddp_kwargs)
        D = torch.nn.parallel.DistributedDataParallel(D, **ddp_kwargs)

    train_options = {
        'train': dict(conf.train),
        'snapshot': dict(conf.snapshots),
        'smoothing': dict(conf.G_smoothing),
        'distributed': distributed
    }
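    # Gradient accumulation setup: bs_dl is the number of samples consumed per loader
    # iteration across all replicas, bs_eff is the desired effective batch size, and
    # upd_interval is how many loader iterations to accumulate before each optimizer step.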
    bs_dl = int(conf.data.loader.batch_size) * num_replicas
    bs_eff = conf.train.batch_size
    if bs_eff % bs_dl:
        raise ValueError(
            "Effective batch size must be divisible by the data-loader batch size "
            "multiplied by the number of devices in use"
        )  # holds until a dedicated batch size for the master node is introduced
    upd_interval = max(bs_eff // bs_dl, 1)
    train_options['train']['update_interval'] = upd_interval
    if epoch_length < len(train_dl):
        # ideally epoch_length would be tied to the effective batch size only,
        # but the ignite trainer counts data-loader iterations, so scale it by the update interval
        epoch_length *= upd_interval

    train_loop, sample_images = create_train_closures(G,
                                                      D,
                                                      G_loss,
                                                      D_loss,
                                                      G_opt,
                                                      D_opt,
                                                      G_ema=G_ema,
                                                      device=device,
                                                      options=train_options)
    trainer = create_trainer(train_loop, metrics, device, num_replicas)
    to_save['trainer'] = trainer

    every_iteration = Events.ITERATION_COMPLETED
    trainer.add_event_handler(every_iteration, TerminateOnNan())

    cp = conf.checkpoints
    pbar = None

    if master_node:
        log_freq = conf.logging.iter_freq
        log_event = Events.ITERATION_COMPLETED(every=log_freq)
        pbar = ProgressBar(persist=False)
        trainer.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
        trainer.add_event_handler(log_event, log_iter, pbar, log_freq)
        trainer.add_event_handler(Events.EPOCH_COMPLETED, log_epoch)
        pbar.attach(trainer, metric_names=metric_names)
        setup_checkpoints(trainer, to_save, epoch_length, conf)
        setup_snapshots(trainer, sample_images, conf)

    if 'load' in cp.keys() and cp.load is not None:
        if master_node:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp,
                                      pbar)
        Checkpoint.load_objects(to_load=to_save,
                                checkpoint=torch.load(cp.load,
                                                      map_location=device))

    try:
        trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length)
    except Exception:
        import traceback
        logging.error(traceback.format_exc())
    if pbar is not None:
        pbar.close()
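# A minimal launcher sketch (not part of the original source) showing how `run` above could
# be invoked per process under `torchrun`; the config path and field names are assumptions.
import os

import torch.distributed as dist
from omegaconf import OmegaConf


def _launch():
    conf = OmegaConf.load("config/train.yaml")  # hypothetical config file
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    distributed = int(os.environ.get("WORLD_SIZE", 1)) > 1
    if distributed:
        # torchrun sets MASTER_ADDR/MASTER_PORT/RANK/WORLD_SIZE, so env:// init works
        dist.init_process_group(backend="nccl")
    run(conf, local_rank=local_rank, distributed=distributed)


if __name__ == "__main__":
    _launch()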
Example #2
def main(parser_args):
    """Main function to create trainer engine, add handlers to train and validation engines.
    Then runs train engine to perform training and validation.

    Args:
        parser_args (dict): parsed arguments
    """
    dataloader_train, dataloader_validation = get_dataloaders(parser_args)
    criterion = nn.CrossEntropyLoss()

    unet = SphericalUNet(parser_args.pooling_class, parser_args.n_pixels,
                         parser_args.depth, parser_args.laplacian_type,
                         parser_args.kernel_size)
    # unet = torch.jit.script(unet)
    unet, device = init_device(parser_args.device, unet)
    lr = parser_args.learning_rate
    optimizer = optim.Adam(unet.parameters(), lr=lr)
    print("Trainable parameters:",
          sum(p.numel() for p in unet.parameters() if p.requires_grad))

    def trainer(engine, batch):
        """Train Function to define train engine.
        Called for every batch of the train engine, for each epoch.

        Args:
            engine (ignite.engine): train engine
            batch (:obj:`torch.utils.data.dataloader`): batch from train dataloader

        Returns:
            :obj:`torch.tensor` : train loss for that batch and epoch
        """
        unet.train()
        optimizer.zero_grad()

        data, labels = batch.x, batch.y
        labels = labels.to(device)
        data = data.to(device)
        output = unet(data)

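        # Flatten the vertex dimension and turn one-hot labels into class indices
        # so the (B * V, C) logits can be fed to CrossEntropyLoss.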
        B, V, C = output.shape
        B_labels, V_labels, C_labels = labels.shape
        output = output.view(B * V, C)
        labels = labels.view(B_labels * V_labels, C_labels).max(1)[1]

        loss = criterion(output, labels)
        loss.backward()
        optimizer.step()
        return {'loss': loss.item()}

    writer = SummaryWriter(parser_args.tensorboard_path)

    engine_train = Engine(trainer)

    RunningAverage(output_transform=lambda x: x['loss']).attach(
        engine_train, 'loss')

    def prepare_batch(batch, device, non_blocking):
        """Prepare batch for training: pass to a device with options.

        """
        return (
            convert_tensor(batch.x, device=device, non_blocking=non_blocking),
            convert_tensor(batch.y, device=device, non_blocking=non_blocking),
        )

    engine_validate = create_supervised_evaluator(
        model=unet,
        metrics={"AP": EpochMetric(average_precision_compute_fn)},
        device=device,
        output_transform=validate_output_transform,
        prepare_batch=prepare_batch)

    engine_train.add_event_handler(
        Events.EPOCH_STARTED,
        lambda x: print("Starting Epoch: {}".format(x.state.epoch)))
    engine_train.add_event_handler(Events.ITERATION_COMPLETED,
                                   TerminateOnNan())

    @engine_train.on(Events.EPOCH_COMPLETED)
    def epoch_validation(engine):
        """Handler to run the validation engine at the end of the train engine's epoch.

        Args:
            engine (ignite.engine): train engine
        """
        print("beginning validation epoch")
        engine_validate.run(dataloader_validation)

    reduce_lr_plateau = ReduceLROnPlateau(
        optimizer,
        mode=parser_args.reducelronplateau_mode,
        factor=parser_args.reducelronplateau_factor,
        patience=parser_args.reducelronplateau_patience,
    )

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def update_reduce_on_plateau(engine):
        """Handler to reduce the learning rate on plateau at the end of the validation engine's epoch

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        reduce_lr_plateau.step(mean_average_precision)

    @engine_validate.on(Events.EPOCH_COMPLETED)
    def save_epoch_results(engine):
        """Handler to save the metrics at the end of the validation engine's epoch

        Args:
            engine (ignite.engine): validation engine
        """
        ap = engine.state.metrics["AP"]
        mean_average_precision = np.mean(ap[1:])
        print("Average precisions:", ap)
        print("mAP:", mean_average_precision)
        writer.add_scalars(
            "metrics",
            {
                "mean average precision (AR+TC)": mean_average_precision,
                "AR average precision": ap[2],
                "TC average precision": ap[1]
            },
            engine_train.state.epoch,
        )
        writer.flush()  # keep the writer open so later epochs append to the same event file

    step_scheduler = StepLR(optimizer,
                            step_size=parser_args.steplr_step_size,
                            gamma=parser_args.steplr_gamma)
    scheduler = create_lr_scheduler_with_warmup(
        step_scheduler,
        warmup_start_value=parser_args.warmuplr_warmup_start_value,
        warmup_end_value=parser_args.warmuplr_warmup_end_value,
        warmup_duration=parser_args.warmuplr_warmup_duration,
    )
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, scheduler)

    earlystopper = EarlyStopping(
        patience=parser_args.earlystopping_patience,
        score_function=lambda x: -x.state.metrics["AP"][1],
        trainer=engine_train)
    engine_validate.add_event_handler(Events.EPOCH_COMPLETED, earlystopper)

    add_tensorboard(engine_train,
                    optimizer,
                    unet,
                    log_dir=parser_args.tensorboard_path)

    pbar = ProgressBar()
    pbar.attach(engine_train, metric_names=['loss'])

    engine_train.run(dataloader_train, max_epochs=parser_args.n_epochs)

    pbar.close()
    torch.save(unet.state_dict(),
               parser_args.model_save_path + "unet_state.pt")
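# Sketch (assumption, not from the original source) of the two helpers referenced in
# `create_supervised_evaluator` above: a per-class average-precision compute function for
# EpochMetric, and an output transform that flattens (B, V, C) tensors like the train step.
import numpy as np
import torch
from sklearn.metrics import average_precision_score


def average_precision_compute_fn(y_preds, y_targets):
    """Return one average precision value per class (assumes one-hot targets)."""
    y_true = y_targets.cpu().numpy()
    y_score = torch.softmax(y_preds, dim=-1).cpu().numpy()
    return np.array([
        average_precision_score(y_true[:, c], y_score[:, c])
        for c in range(y_score.shape[-1])
    ])


def validate_output_transform(x, y, y_pred):
    """Flatten (B, V, C) predictions and labels to (B * V, C) before metric accumulation."""
    B, V, C = y_pred.shape
    return y_pred.view(B * V, C), y.view(B * V, C)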
Example #3
def run(conf: DictConfig, local_rank=0, distributed=False):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.seed)

    if distributed:
        rank = dist.get_rank()
        num_replicas = dist.get_world_size()
        torch.cuda.set_device(local_rank)
    else:
        rank = 0
        num_replicas = 1
        torch.cuda.set_device(conf.gpu)
    device = torch.device('cuda')
    loader_args = dict(mean=conf.data.mean, std=conf.data.std)
    master_node = rank == 0

    if master_node:
        print(conf.pretty())
    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args["rank"] = rank
        loader_args["num_replicas"] = num_replicas

    train_dl = create_train_loader(conf.data.train, **loader_args)
    valid_dl = create_val_loader(conf.data.val, **loader_args)

    if epoch_length < 1:
        epoch_length = len(train_dl)

    model = instantiate(conf.model).to(device)
    model_ema, update_ema = setup_ema(conf,
                                      model,
                                      device=device,
                                      master_node=master_node)
    optim = build_optimizer(conf.optim, model)

    scheduler_kwargs = dict()
    if "schedule.OneCyclePolicy" in conf.lr_scheduler["class"]:
        scheduler_kwargs["cycle_steps"] = epoch_length
    lr_scheduler: Scheduler = instantiate(conf.lr_scheduler, optim,
                                          **scheduler_kwargs)

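    # Optional mixed precision through NVIDIA Apex: amp.initialize wraps the model and
    # optimizer, and `amp` itself is added to the checkpoint dict below so its
    # loss-scaling state can be restored on resume.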
    use_amp = False
    if conf.use_apex:
        import apex
        from apex import amp
        logging.debug("Nvidia's Apex package is available")

        model, optim = amp.initialize(model, optim, **conf.amp)
        use_amp = True
        if master_node:
            logging.info("Using AMP with opt_level={}".format(
                conf.amp.opt_level))
    else:
        apex, amp = None, None

    to_save = dict(model=model, optim=optim)
    if use_amp:
        to_save["amp"] = amp
    if model_ema is not None:
        to_save["model_ema"] = model_ema

    if master_node and conf.logging.model:
        logging.info(model)

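    # Wrap the model for distributed training: prefer Apex DDP when available and
    # optionally convert BatchNorm layers to synchronized BN across replicas.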
    if distributed:
        sync_bn = conf.distributed.sync_bn
        if apex is not None:
            if sync_bn:
                model = apex.parallel.convert_syncbn_model(model)
            model = apex.parallel.distributed.DistributedDataParallel(
                model, delay_allreduce=True)
        else:
            if sync_bn:
                model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
            model = torch.nn.parallel.DistributedDataParallel(
                model, device_ids=[local_rank], output_device=local_rank)

    upd_interval = conf.optim.step_interval
    ema_interval = conf.smoothing.interval_it * upd_interval
    clip_grad = conf.optim.clip_grad

    _handle_batch_train = build_process_batch_func(conf.data,
                                                   stage="train",
                                                   device=device)
    _handle_batch_val = build_process_batch_func(conf.data,
                                                 stage="val",
                                                 device=device)

    def _update(eng: Engine, batch: Batch) -> FloatDict:
        model.train()
        batch = _handle_batch_train(batch)
        losses: Dict = model(*batch)
        stats = {k: v.item() for k, v in losses.items()}
        loss = losses["loss"]
        del losses

        if use_amp:
            with amp.scale_loss(loss, optim) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

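        # Gradient accumulation: the optimizer and per-iteration LR schedule step every
        # `upd_interval` loader iterations, and the EMA weights every `ema_interval`.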
        it = eng.state.iteration
        if not it % upd_interval:
            if clip_grad > 0:
                params = amp.master_params(
                    optim) if use_amp else model.parameters()
                torch.nn.utils.clip_grad_norm_(params, clip_grad)
            optim.step()
            optim.zero_grad()
            lr_scheduler.step_update(it)

            if not it % ema_interval:
                update_ema()

            eng.state.lr = optim.param_groups[0]["lr"]

        return stats

    calc_map = conf.validate.calc_map
    min_score = conf.validate.get("min_score", -1)

    model_val = model
    if conf.train.skip and model_ema is not None:
        model_val = model_ema.to(device)

    def _validate(eng: Engine, batch: Batch) -> FloatDict:
        model_val.eval()
        images, targets = _handle_batch_val(batch)

        with torch.no_grad():
            out: Dict = model_val(images, targets)

        pred_boxes = out.pop("detections")
        stats = {k: v.item() for k, v in out.items()}

        if calc_map:
            pred_boxes = pred_boxes.detach().cpu().numpy()
            true_boxes = targets['bbox'].cpu().numpy()
            img_scale = targets['img_scale'].cpu().numpy()
            # yxyx -> xyxy
            true_boxes = true_boxes[:, :, [1, 0, 3, 2]]
            # xyxy -> xywh
            true_boxes[:, :, [2, 3]] -= true_boxes[:, :, [0, 1]]
            # scale downsized boxes to match predictions on a full-sized image
            true_boxes *= img_scale[:, None, None]

            scores = []
            for i in range(len(images)):
                mask = pred_boxes[i, :, 4] >= min_score
                s = calculate_image_precision(true_boxes[i],
                                              pred_boxes[i, mask, :4],
                                              thresholds=IOU_THRESHOLDS,
                                              form='coco')
                scores.append(s)
            stats['map'] = np.mean(scores)

        return stats

    train_metric_names = list(conf.logging.out.train)
    train_metrics = create_metrics(train_metric_names,
                                   device if distributed else None)

    val_metric_names = list(conf.logging.out.val)
    if calc_map:
        from utils.metric import calculate_image_precision, IOU_THRESHOLDS
        val_metric_names.append('map')
    val_metrics = create_metrics(val_metric_names,
                                 device if distributed else None)

    trainer = build_engine(_update, train_metrics)
    evaluator = build_engine(_validate, val_metrics)
    to_save['trainer'] = trainer

    every_iteration = Events.ITERATION_COMPLETED
    trainer.add_event_handler(every_iteration, TerminateOnNan())

    if distributed:
        dist_bn = conf.distributed.dist_bn
        if dist_bn in ["reduce", "broadcast"]:
            from timm.utils import distribute_bn

            @trainer.on(Events.EPOCH_COMPLETED)
            def _distribute_bn_stats(eng: Engine):
                reduce = dist_bn == "reduce"
                if master_node:
                    logging.info("Distributing BN stats...")
                distribute_bn(model, num_replicas, reduce)

    sampler = train_dl.sampler
    if isinstance(sampler, (CustomSampler, DistributedSampler)):

        @trainer.on(Events.EPOCH_STARTED)
        def _set_epoch(eng: Engine):
            sampler.set_epoch(eng.state.epoch - 1)

    @trainer.on(Events.EPOCH_COMPLETED)
    def _scheduler_step(eng: Engine):
        # engine.state.epoch starts from 1, so there is no need to add 1 here
        ep = eng.state.epoch
        lr_scheduler.step(ep)

    cp = conf.checkpoints
    pbar, pbar_vis = None, None

    if master_node:
        log_interval = conf.logging.interval_it
        log_event = Events.ITERATION_COMPLETED(every=log_interval)
        pbar = ProgressBar(persist=False)
        pbar.attach(trainer, metric_names=train_metric_names)
        pbar.attach(evaluator, metric_names=val_metric_names)

        for engine, name in zip([trainer, evaluator], ['train', 'val']):
            engine.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
            engine.add_event_handler(log_event,
                                     log_iter,
                                     pbar,
                                     interval_it=log_interval,
                                     name=name)
            engine.add_event_handler(Events.EPOCH_COMPLETED,
                                     log_epoch,
                                     name=name)

        setup_checkpoints(trainer, to_save, epoch_length, conf)

    if 'load' in cp.keys() and cp.load is not None:
        if master_node:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp,
                                      pbar)
        resume_from_checkpoint(to_save, cp, device=device)
        state = trainer.state
        # the epoch counter starts from 1
        lr_scheduler.step(state.epoch - 1)
        state.max_epochs = epochs

    @trainer.on(Events.EPOCH_COMPLETED(every=conf.validate.interval_ep))
    def _run_validation(eng: Engine):
        if distributed:
            torch.cuda.synchronize(device)
        evaluator.run(valid_dl)

    skip_train = conf.train.skip
    if master_node and conf.visualize.enabled:
        vis_eng = evaluator if skip_train else trainer
        setup_visualizations(vis_eng,
                             model,
                             valid_dl,
                             device,
                             conf,
                             force_run=skip_train)

    try:
        if skip_train:
            evaluator.run(valid_dl)
        else:
            trainer.run(train_dl, max_epochs=epochs, epoch_length=epoch_length)
    except Exception:
        import traceback
        logging.error(traceback.format_exc())

    for pb in [pbar, pbar_vis]:
        if pb is not None:
            pb.close()
Example #4
def run(conf: DictConfig):
    epochs = conf.train.epochs
    epoch_length = conf.train.epoch_length
    torch.manual_seed(conf.general.seed)

    dist_conf = conf.distributed
    local_rank = dist_conf.local_rank
    backend = dist_conf.backend
    distributed = backend is not None
    use_tpu = conf.tpu.enabled

    if use_tpu:
        rank = xm.get_ordinal()
        num_replicas = xm.xrt_world_size()
        device = xm.xla_device()
    else:
        if distributed:
            rank = dist.get_rank()
            num_replicas = dist.get_world_size()
            torch.cuda.set_device(local_rank)
        else:
            rank = 0
            num_replicas = 1
            torch.cuda.set_device(conf.general.gpu)
        device = torch.device('cuda')

    if rank == 0:
        print(conf.pretty())

    if num_replicas > 1:
        epoch_length = epoch_length // num_replicas
        loader_args = dict(rank=rank, num_replicas=num_replicas)
    else:
        loader_args = dict()

    train_dl = create_train_loader(conf.data.train,
                                   epoch_length=epoch_length,
                                   **loader_args)
    valid_dl = create_val_loader(conf.data.val, **loader_args)
    train_sampler = train_dl.sampler

    if epoch_length < 1:
        epoch_length = len(train_dl)

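    # On TPU, wrap both loaders in ParallelLoader so batches are prefetched to the XLA
    # device; per-device iterators are obtained later via per_device_loader().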
    if use_tpu:
        train_dl = pl.ParallelLoader(train_dl, [device])
        valid_dl = pl.ParallelLoader(valid_dl, [device])

    model = instantiate(conf.model).to(device)
    if distributed:
        model = DistributedDataParallel(model,
                                        device_ids=[local_rank],
                                        output_device=local_rank)
        model.to_y = model.module.to_y
    if rank == 0 and conf.logging.model:
        print(model)

    loss = instantiate(conf.loss)
    optim = instantiate(conf.optimizer,
                        filter(lambda x: x.requires_grad, model.parameters()))
    metrics = create_metrics(loss.keys(), device if distributed else None)
    build_trainer_fn = create_tpu_trainer if use_tpu else create_trainer
    trainer = build_trainer_fn(model, loss, optim, device, conf, metrics)
    evaluator = create_evaluator(model, loss, device, metrics)

    every_iteration = Events.ITERATION_COMPLETED

    if 'lr_scheduler' in conf.keys():
        # TODO: total_steps is not quite right; this only works for the one-cycle scheduler
        lr_scheduler = instantiate(conf.lr_scheduler,
                                   optim,
                                   total_steps=epoch_length)
        trainer.add_event_handler(every_iteration,
                                  lambda _: lr_scheduler.step())

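        # OneCycleLR raises an error once stepped past total_steps, so its initial state
        # is reloaded at each epoch boundary to restart the cycle.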
        if isinstance(lr_scheduler, torch.optim.lr_scheduler.OneCycleLR):
            initial_state = lr_scheduler.state_dict()
            trainer.add_event_handler(
                Events.ITERATION_COMPLETED(every=epoch_length),
                lambda _: lr_scheduler.load_state_dict(initial_state))
    else:
        lr_scheduler = None

    trainer.add_event_handler(every_iteration, TerminateOnNan())

    cp = conf.train.checkpoints
    to_save = {
        'trainer': trainer,
        'model': model.module if distributed else model,
        'optimizer': optim,
        'lr_scheduler': lr_scheduler
    }
    save_path = cp.get('base_dir', os.getcwd())

    if rank == 0:
        log_freq = conf.logging.iter_freq
        log_event = Events.ITERATION_COMPLETED(every=log_freq)
        pbar = ProgressBar(persist=False)

        for engine, name in zip([trainer, evaluator], ['train', 'val']):
            engine.add_event_handler(Events.EPOCH_STARTED, on_epoch_start)
            engine.add_event_handler(log_event, log_iter, trainer, pbar, name,
                                     log_freq)
            engine.add_event_handler(Events.EPOCH_COMPLETED, log_epoch,
                                     trainer, name)
            pbar.attach(engine, metric_names=loss.keys())

        if 'load' in cp.keys() and cp.load:
            logging.info("Resume from a checkpoint: {}".format(cp.load))
            trainer.add_event_handler(Events.STARTED, _upd_pbar_iter_from_cp,
                                      pbar)

        logging.info("Saving checkpoints to {}".format(save_path))

    if rank == 0 or use_tpu:
        max_cp = max(int(cp.get('max_checkpoints', 1)), 1)
        Saver = TpuDiskSaver if use_tpu else DiskSaver
        save = Saver(save_path, create_dir=True, require_empty=True)
        make_checkpoint = Checkpoint(to_save, save, n_saved=max_cp)
        cp_iter = cp.interval_iteration
        cp_epoch = cp.interval_epoch
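        # Save every `cp_iter` iterations and/or every `cp_epoch` epochs; the epoch-level
        # handler is skipped when it would duplicate an iteration-level save at the epoch boundary.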
        if cp_iter > 0:
            save_event = Events.ITERATION_COMPLETED(every=cp_iter)
            trainer.add_event_handler(save_event, make_checkpoint)
        if cp_epoch > 0:
            if cp_iter < 1 or epoch_length % cp_iter:
                save_event = Events.EPOCH_COMPLETED(every=cp_epoch)
                trainer.add_event_handler(save_event, make_checkpoint)

    if 'load' in cp.keys() and cp.load:
        Checkpoint.load_objects(to_load=to_save,
                                checkpoint=torch.load(cp.load,
                                                      map_location=device))

    assert train_sampler is not None
    trainer.add_event_handler(
        Events.EPOCH_STARTED,
        lambda e: train_sampler.set_epoch(e.state.epoch - 1))

    def run_validation(e: Engine):
        if distributed:
            torch.cuda.synchronize(device)
        if use_tpu:
            xm.rendezvous('validate_{}'.format(e.state.iteration))
            valid_it = valid_dl.per_device_loader(device)
            evaluator.run(valid_it, epoch_length=len(valid_dl))
        else:
            evaluator.run(valid_dl)

    eval_event = Events.EPOCH_COMPLETED(every=conf.validate.interval)
    trainer.add_event_handler(eval_event, run_validation)

    try:
        if conf.train.skip:
            evaluator.run(valid_dl)
        else:
            loader = train_dl
            if use_tpu:
                # need to catch StopIteration before ignite, otherwise it will crash
                loader = iter(_regenerate(train_dl, device))
            trainer.run(loader, max_epochs=epochs, epoch_length=epoch_length)
    except Exception:
        import traceback
        print(traceback.format_exc())
    if rank == 0:
        pbar.close()
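
# A minimal sketch (assumption, not from the original source) of the `_regenerate` helper
# used for the TPU branch above: it re-creates the per-device iterator whenever it is
# exhausted, so ignite never observes a bare StopIteration from the XLA loader.
def _regenerate(parallel_loader, device):
    while True:
        for batch in parallel_loader.per_device_loader(device):
            yield batch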