Example #1
def train(flags, model, train_loader, val_loader, loss_fn, score_fn, device, callbacks, is_distributed):
    rank = get_rank()
    world_size = get_world_size()
    torch.backends.cudnn.benchmark = flags.cudnn_benchmark
    torch.backends.cudnn.deterministic = flags.cudnn_deterministic

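    # Optimizer, optional multi-step LR decay schedule, and the AMP gradient scaler.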
    optimizer = get_optimizer(model.parameters(), flags)
    if flags.lr_decay_epochs:
        scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                         milestones=flags.lr_decay_epochs,
                                                         gamma=flags.lr_decay_factor)
    scaler = GradScaler()

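    # Move the model and loss to the device; optionally convert BatchNorm layers to
    # SyncBatchNorm and wrap the model in DistributedDataParallel for multi-GPU runs.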
    model.to(device)
    loss_fn.to(device)
    if flags.normalization == "syncbatchnorm":
        model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
    if is_distributed:
        model = torch.nn.parallel.DistributedDataParallel(model,
                                                          device_ids=[flags.local_rank],
                                                          output_device=flags.local_rank)

    stop_training = False
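    # Main epoch loop: LR warmup during the first epochs, MLPerf block/epoch logging,
    # and per-epoch re-shuffling of the distributed sampler.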
    for epoch in range(1, flags.epochs + 1):
        cumulative_loss = []
        if epoch <= flags.lr_warmup_epochs and flags.lr_warmup_epochs > 0:
            lr_warmup(optimizer, flags.init_learning_rate, flags.learning_rate, epoch, flags.lr_warmup_epochs)
        mllog_start(key=CONSTANTS.BLOCK_START, sync=True,
                    metadata={CONSTANTS.FIRST_EPOCH_NUM: epoch, CONSTANTS.EPOCH_COUNT: 1})
        mllog_start(key=CONSTANTS.EPOCH_START, metadata={CONSTANTS.EPOCH_NUM: epoch}, sync=True)
        if is_distributed:
            train_loader.sampler.set_epoch(epoch)
            # val_loader.sampler.set_epoch(epoch)

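        # Batch loop with gradient accumulation: the optimizer steps once every flags.ga_steps micro-batches.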
        optimizer.zero_grad()
        accumulated_steps = 0
        for i, batch in enumerate(tqdm(train_loader, disable=(rank != 0) or not flags.verbose)):
            image, label = batch
            image, label = image.to(device), label.to(device)
            for callback in callbacks:
                callback.on_batch_start()

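            # Mixed-precision forward pass; the loss is divided by the number of
            # accumulation steps so the effective update matches a larger batch.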
            with autocast(enabled=flags.amp):
                output = model(image)
                loss_value = loss_fn(output, label)
                loss_value /= flags.ga_steps

            if flags.amp:
                scaler.scale(loss_value).backward()
            else:
                loss_value.backward()

            accumulated_steps += 1
            if accumulated_steps % flags.ga_steps == 0:
                if flags.amp:
                    scaler.step(optimizer)
                    scaler.update()
                else:
                    optimizer.step()

                optimizer.zero_grad()
                accumulated_steps = 0

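            # Reduce (average) the loss across ranks for logging.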
            loss_value = reduce_tensor(loss_value, world_size).detach().cpu().numpy()
            cumulative_loss.append(loss_value)

        mllog_end(key=CONSTANTS.EPOCH_STOP, sync=True,
                  metadata={CONSTANTS.EPOCH_NUM: epoch, 'current_lr': optimizer.param_groups[0]['lr']})

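        # Step the LR decay schedule and run periodic evaluation; training stops
        # early once mean Dice reaches the quality threshold.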
        if flags.lr_decay_epochs:
            scheduler.step()
        if ((epoch % flags.evaluate_every) == 0) and epoch >= flags.start_eval_at:
            del output
            mllog_start(key=CONSTANTS.EVAL_START, value=epoch, metadata={CONSTANTS.EPOCH_NUM: epoch}, sync=True)

            eval_metrics = evaluate(flags, model, val_loader, loss_fn, score_fn, device, epoch)
            eval_metrics["train_loss"] = sum(cumulative_loss) / len(cumulative_loss)

            mllog_event(key=CONSTANTS.EVAL_ACCURACY,
                        value={"epoch": epoch, "value": eval_metrics["mean_dice"]},
                        metadata={CONSTANTS.EPOCH_NUM: epoch},
                        sync=False)
            mllog_end(key=CONSTANTS.EVAL_STOP, metadata={CONSTANTS.EPOCH_NUM: epoch}, sync=True)

            for callback in callbacks:
                callback.on_epoch_end(epoch, eval_metrics, model, optimizer)
            model.train()
            if eval_metrics["mean_dice"] >= flags.quality_threshold:
                stop_training = True

        mllog_end(key=CONSTANTS.BLOCK_STOP, sync=True,
                  metadata={CONSTANTS.FIRST_EPOCH_NUM: epoch, CONSTANTS.EPOCH_COUNT: 1})

        if stop_training:
            break

    mllog_end(key=CONSTANTS.RUN_STOP, sync=True,
              metadata={CONSTANTS.STATUS: CONSTANTS.SUCCESS if stop_training else CONSTANTS.ABORTED})
    for callback in callbacks:
        callback.on_fit_end()
Example #2
def main():
    mllog.config(filename=os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'unet3d.log'))
    mllog.config(filename=os.path.join("/results", 'unet3d.log'))
    mllogger = mllog.get_mllogger()
    mllogger.logger.propagate = False
    mllog_start(key=constants.INIT_START)

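    # Parse command-line flags, set up (optionally distributed) execution, and seed each worker for reproducibility.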
    flags = PARSER.parse_args()
    dllogger = get_dllogger(flags)
    local_rank = flags.local_rank
    device = get_device(local_rank)
    is_distributed = init_distributed()
    world_size = get_world_size()
    local_rank = get_rank()
    worker_seeds, shuffling_seeds = setup_seeds(flags.seed, flags.epochs,
                                                device)
    worker_seed = worker_seeds[local_rank]
    seed_everything(worker_seed)
    mllog_event(key=constants.SEED,
                value=flags.seed if flags.seed != -1 else worker_seed,
                sync=False)

    if is_main_process() and flags.verbose:
        mlperf_submission_log()
        mlperf_run_param_log(flags)

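    # Build the training callbacks and the 3D U-Net model (1 input channel, 3 segmentation classes).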
    callbacks = get_callbacks(flags, dllogger, local_rank, world_size)
    flags.seed = worker_seed
    model = Unet3D(1,
                   3,
                   normalization=flags.normalization,
                   activation=flags.activation)

    mllog_end(key=constants.INIT_STOP, sync=True)
    mllog_start(key=constants.RUN_START, sync=True)
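    # Create sharded train/validation loaders and the Dice-based loss and score functions.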
    train_dataloader, val_dataloader = get_data_loaders(flags,
                                                        num_shards=world_size)
    mllog_event(key=constants.GLOBAL_BATCH_SIZE,
                value=flags.batch_size * world_size,
                sync=False)
    loss_fn = DiceCELoss(to_onehot_y=True,
                         use_softmax=True,
                         layout=flags.layout,
                         include_background=flags.include_background)
    score_fn = DiceScore(to_onehot_y=True,
                         use_argmax=True,
                         layout=flags.layout,
                         include_background=flags.include_background)

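    # Dispatch on exec_mode: run full training, or a single evaluation pass.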
    if flags.exec_mode == 'train':
        train(flags,
              model,
              train_dataloader,
              val_dataloader,
              loss_fn,
              score_fn,
              device=device,
              callbacks=callbacks,
              is_distributed=is_distributed)

    elif flags.exec_mode == 'evaluate':
        eval_metrics = evaluate(flags,
                                model,
                                val_dataloader,
                                loss_fn,
                                score_fn,
                                device=device,
                                is_distributed=is_distributed)
        if local_rank == 0:
            for key in eval_metrics.keys():
                print(key, eval_metrics[key])
    else:
        print("Invalid exec_mode.")