Example #1
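Both functions on this page are excerpted from larger training repositories, so the project-local helpers (`utils`, `dataset`, `DotDict`, `build_model`, `load_config_data`, and the `MODEL_TYPE_*` constants) are not shown here. The third-party imports this first example relies on would look roughly like this sketch:

import os
import pprint

import numpy as np
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from adamp import AdamP  # pip install adamp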
def train(experiment_name, distributed=False, continue_epoch=-1):
    model_str = experiment_name

    cfg = load_config_data(experiment_name)
    pprint.pprint(cfg)

    model_type = cfg["model_params"]["model_type"]
    train_params = DotDict(cfg["train_params"])

    checkpoints_dir = f"./checkpoints/{model_str}"
    tensorboard_dir = f"./tensorboard/{model_type}/{model_str}"
    oof_dir = f"./oof/{model_str}"
    os.makedirs(checkpoints_dir, exist_ok=True)
    os.makedirs(tensorboard_dir, exist_ok=True)
    os.makedirs(oof_dir, exist_ok=True)
    print("\n", experiment_name, "\n")

    logger = SummaryWriter(log_dir=tensorboard_dir)

    scaler = torch.cuda.amp.GradScaler()

    with utils.timeit_context("load train"):
        dataset_train = dataset.LyftDatasetPrerendered(
            dset_name=dataset.LyftDataset.DSET_TRAIN_XXL, cfg_data=cfg)

    with utils.timeit_context("load validation"):
        dataset_valid = dataset.LyftDatasetPrerendered(
            dset_name=dataset.LyftDataset.DSET_VALIDATION, cfg_data=cfg)

    batch_size = dataset_train.dset_cfg["batch_size"]

    data_loaders = {
        "train": DataLoader(dataset_train,
                            num_workers=16,
                            shuffle=True,
                            batch_size=batch_size),
        "val": DataLoader(dataset_valid,
                          shuffle=False,
                          num_workers=16,
                          batch_size=dataset_valid.dset_cfg["batch_size"]),
    }

    model_info = DotDict(cfg["model_params"])
    model = build_model(model_info, cfg)
    model = model.cuda()

    model.train()

    initial_lr = float(train_params.initial_lr)
    if train_params.optimizer == "adamp":
        optimizer = AdamP(model.parameters(), lr=initial_lr)
    elif train_params.optimizer == "adam":
        optimizer = optim.Adam(model.parameters(), lr=initial_lr)
    elif train_params.optimizer == "sgd":
        if model_type == MODEL_TYPE_REGRESSION_MULTI_MODE_EMB:
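            # Per-parameter-group learning rates: embedding layers train at
            # 20x the base LR, the non-backbone head at 2x, the backbone at 1x.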
            optimizer = optim.SGD(
                [
                    {
                        "params": [
                            v for n, v in model.named_parameters()
                            if not n.startswith("emb.")
                            and not n.startswith("backbone.")
                        ],
                        "lr":
                        initial_lr * 2,
                    },
                    {
                        "params": model.backbone.parameters(),
                        "lr": initial_lr
                    },
                    {
                        "params": model.emb.parameters(),
                        "lr": initial_lr * 20
                    },
                ],
                lr=initial_lr,
                momentum=0.9,
                nesterov=True,
            )
        else:
            optimizer = optim.SGD(model.parameters(),
                                  lr=initial_lr,
                                  momentum=0.9,
                                  nesterov=True)
    else:
        raise RuntimeError("Invalid optimiser" + train_params.optimizer)

    if continue_epoch > 0:
        checkpoint = torch.load(f"{checkpoints_dir}/{continue_epoch:03}.pt")
        model.load_state_dict(checkpoint["model_state_dict"])
        optimizer.load_state_dict(checkpoint["optimizer_state_dict"])

    nb_epochs = train_params.nb_epochs
    if train_params.scheduler == "steps":
        scheduler = optim.lr_scheduler.MultiStepLR(
            optimizer,
            milestones=train_params.optimiser_milestones,
            gamma=0.2,
            last_epoch=continue_epoch,
        )
    elif train_params.scheduler == "CosineAnnealingLR":
        scheduler = optim.lr_scheduler.CosineAnnealingLR(
            optimizer,
            T_max=nb_epochs,
            eta_min=initial_lr / 1000,
            last_epoch=continue_epoch,
        )
    elif train_params.scheduler == "CosineAnnealingWarmRestarts":
        scheduler = utils.CosineAnnealingWarmRestarts(
            optimizer,
            T_0=train_params.scheduler_period,
            T_mult=train_params.get('scheduler_t_mult', 1),
            eta_min=initial_lr / 1000.0,
            last_epoch=-1)
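        # last_epoch=-1 resets the scheduler state, so when resuming we
        # replay one step per completed epoch to restore the current LR.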
        for _ in range(continue_epoch + 1):
            scheduler.step()
    else:
        raise RuntimeError("Invalid scheduler name")

    grad_clip_value = train_params.get("grad_clip", 2.0)
    print("grad clip:", grad_clip_value)

    print(f"Num training agents: {len(dataset_train)}, "
          f"validation agents: {len(dataset_valid)}")

    for epoch_num in range(continue_epoch + 1, nb_epochs + 1):
        for phase in ["train", "val"]:
            model.train(phase == "train")
            epoch_loss_segmentation = []
            epoch_loss_regression = []
            epoch_loss_regression_aux = []
            data_loader = data_loaders[phase]

            optimizer.zero_grad()

            if phase == "train":
                nb_steps_per_epoch = train_params.epoch_size // batch_size
                data_iter = tqdm(
                    utils.LoopIterable(data_loader,
                                       max_iters=nb_steps_per_epoch),
                    total=nb_steps_per_epoch,
                    ncols=250,
                )
            else:
                if epoch_num % 2 > 0:  # validate only on even epochs for speed
                    continue

                data_iter = tqdm(data_loader, ncols=250)

            for data in data_iter:
                with torch.set_grad_enabled(phase == "train"):
                    # torch.set_anomaly_enabled(True)
                    inputs = data["image"].float().cuda()
                    # agent_state = data["agent_state"].float().cuda()
                    agent_state = None
                    target_availabilities = data["target_availabilities"].cuda()

                    targets = data["target_positions"].cuda()

                    pos_scale = 1.0

                    optimizer.zero_grad()

                    loss_segmentation = 0
                    loss_regression = 0
                    loss_regression_aux = 0

                    if model_type == MODEL_TYPE_ATTENTION:
                        all_agents_state = data["all_agents_state"].float(
                        ).cuda()
                        image_blocks_positions_agent = data[
                            "image_blocks_positions_agent"].cuda()

                        with torch.cuda.amp.autocast():
                            pred, confidences = model(
                                inputs, image_blocks_positions_agent,
                                all_agents_state)

                            loss_regression = utils.pytorch_neg_multi_log_likelihood_batch(
                                gt=targets.float() * pos_scale,
                                pred=pred.float() * pos_scale,
                                confidences=confidences.float(),
                                avails=target_availabilities.float(),
                            )

                    if model_type == MODEL_TYPE_REGRESSION_MULTI_MODE_WITH_OTHER_AGENTS_INPUTS:
                        all_agents_state = data["all_agents_state"].float(
                        ).cuda()

                        with torch.cuda.amp.autocast():
                            pred, confidences = model(inputs, all_agents_state)

                            loss_regression = utils.pytorch_neg_multi_log_likelihood_batch(
                                gt=targets.float(),
                                pred=pred.float(),
                                confidences=confidences.float(),
                                avails=target_availabilities.float(),
                            )

                    if model_type == MODEL_TYPE_REGRESSION_MULTI_MODE:
                        with torch.cuda.amp.autocast():
                            pred, confidences = model(inputs, agent_state)

                            loss_regression = utils.pytorch_neg_multi_log_likelihood_batch_from_log_sm(
                                gt=targets.float() * pos_scale,
                                pred=pred.float() * pos_scale,
                                confidences=confidences.float(),
                                avails=target_availabilities.float(),
                            )

                    if model_type == MODEL_TYPE_REGRESSION_MULTI_MODE_AUX_OUT:
                        with torch.cuda.amp.autocast():
                            pred, confidences, pred_aux, confidences_aux = model(
                                inputs, agent_state,
                                data["image_4x"].float().cuda())

                            loss_regression = utils.pytorch_neg_multi_log_likelihood_batch_from_log_sm(
                                gt=targets.float(),
                                pred=pred.float(),
                                confidences=confidences.float(),
                                avails=target_availabilities.float(),
                            )

                            loss_regression_aux = utils.pytorch_neg_multi_log_likelihood_batch_from_log_sm(
                                gt=targets.float(),
                                pred=pred_aux.float(),
                                confidences=confidences_aux.float(),
                                avails=target_availabilities.float(),
                            )

                    if model_type == MODEL_TYPE_REGRESSION_MULTI_MODE_I4X:
                        with torch.cuda.amp.autocast():
                            pred, confidences = model(
                                inputs, agent_state,
                                data["image_4x"].float().cuda())

                            loss_regression = utils.pytorch_neg_multi_log_likelihood_batch(
                                gt=targets.float() * pos_scale,
                                pred=pred.float() * pos_scale,
                                confidences=confidences.float(),
                                avails=target_availabilities.float(),
                            )

                    if model_type == MODEL_TYPE_REGRESSION_MULTI_MODE_WITH_MASKS:
                        with torch.cuda.amp.autocast():
                            pred, confidences = model(
                                inputs, agent_state,
                                data["other_agents_masks"].float().cuda())

                            loss_regression = utils.pytorch_neg_multi_log_likelihood_batch(
                                gt=targets.float() * pos_scale,
                                pred=pred.float() * pos_scale,
                                confidences=confidences.float(),
                                avails=target_availabilities.float(),
                            )

                    if model_type == MODEL_TYPE_REGRESSION_MULTI_MODE_EMB:
                        with torch.cuda.amp.autocast():
                            pred, confidences = model(
                                inputs, agent_state,
                                data["corners"].float().cuda())

                            loss_regression = utils.pytorch_neg_multi_log_likelihood_batch(
                                gt=targets.float() * pos_scale,
                                pred=pred.float() * pos_scale,
                                confidences=confidences.float(),
                                avails=target_availabilities.float(),
                            )
                    elif model_type == MODEL_TYPE_SEGMENTATION:
                        target_mask = data["output_mask"].cuda()
                        l2_cls, l1_cls = model(inputs, agent_state)
                        bce_logits = torch.nn.functional.binary_cross_entropy_with_logits
                        loss_segmentation = (bce_logits(l2_cls, target_mask) * 1000 +
                                             bce_logits(l1_cls, target_mask) * 100)
                    elif model_type == MODEL_TYPE_SEGMENTATION_AND_REGRESSION:
                        target_mask = data["output_mask"].cuda()
                        segmentation, pred, confidences = model(
                            inputs, agent_state)
                        loss_segmentation = torch.nn.functional.binary_cross_entropy_with_logits(
                            segmentation, target_mask) * 1000

                        loss_regression = utils.pytorch_neg_multi_log_likelihood_batch(
                            gt=targets.float() * pos_scale,
                            pred=pred.float() * pos_scale,
                            confidences=confidences.float(),
                            avails=target_availabilities.float(),
                        )

                    loss = loss_segmentation + loss_regression + loss_regression_aux

                    if phase == "train":
                        scaler.scale(loss).backward()

                        # Unscales the gradients of optimizer's assigned params in-place
                        scaler.unscale_(optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                       grad_clip_value)
                        # optimizer's gradients are already unscaled, so scaler.step does not unscale them,
                        # although it still skips optimizer.step() if the gradients contain infs or NaNs.
                        scaler.step(optimizer)
                        scaler.update()

                    if phase == "val":
                        # save predictions visualisation
                        pass

                    epoch_loss_segmentation.append(float(loss_segmentation))
                    epoch_loss_regression.append(float(loss_regression))
                    epoch_loss_regression_aux.append(
                        float(loss_regression_aux))
                    loss_segmentation = None
                    loss_regression = None
                    loss_regression_aux = None
                    del loss

                    data_iter.set_description(
                        f"{epoch_num} {phase[0]}"
                        f" Loss r {np.mean(epoch_loss_regression):1.4f} "
                        f" r aux {np.mean(epoch_loss_regression_aux):1.4f} "
                        f"s {np.mean(epoch_loss_segmentation):1.4f}")

            logger.add_scalar(f"loss_{phase}", np.mean(epoch_loss_regression),
                              epoch_num)
            if epoch_loss_segmentation[-1] > 0:
                logger.add_scalar(f"loss_segmentation_{phase}",
                                  np.mean(epoch_loss_segmentation), epoch_num)

            if epoch_loss_regression_aux[-1] > 0:
                logger.add_scalar(f"loss_regression_aux_{phase}",
                                  np.mean(epoch_loss_regression_aux),
                                  epoch_num)

            if phase == "train":
                logger.add_scalar("lr", optimizer.param_groups[0]["lr"],
                                  epoch_num)
            logger.flush()

            if phase == "train":
                scheduler.step()
                if (epoch_num % train_params.save_period == 0
                        or epoch_num == nb_epochs):
                    torch.save(
                        {
                            "epoch": epoch_num,
                            "model_state_dict": model.module.state_dict()
                            if distributed else model.state_dict(),
                            "optimizer_state_dict": optimizer.state_dict(),
                        },
                        f"{checkpoints_dir}/{epoch_num:03}.pt",
                    )
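A minimal launch sketch for this function; the CLI flags below are illustrative, assuming `load_config_data` resolves the experiment name to a config file:

if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("experiment", help="experiment/config name")
    parser.add_argument("--continue-epoch", type=int, default=-1,
                        help="checkpoint epoch to resume from; -1 starts fresh")
    cli = parser.parse_args()
    train(cli.experiment, continue_epoch=cli.continue_epoch)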
Example #2
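As in the first example, the repository-local pieces (`loader`, `Logger`, `Unet2D`, `Unet2D_norm`, `DiceLoss`, `ClDice`, `train`, `validate`, `save_checkpoint`, `draw_curve`, `check_best_pth`, `main_test`, and the global `args` namespace) are defined elsewhere in the project; the external imports would be approximately:

import os
import pickle
import shutil

import torch
import torch.nn as nn
import torch.optim as optim
import torch.backends.cudnn as cudnn
from adamp import AdamP  # pip install adamp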
def main():
    # save input stats for later use

    if args.server == 'server_A':
        work_dir = os.path.join('/data1/JM/lung_segmentation', args.exp)
    elif args.server == 'server_B':
        work_dir = os.path.join('/data1/workspace/JM_gen/lung-seg-back-up',
                                args.exp)
    elif args.server == 'server_D':
        work_dir = os.path.join(
            '/daintlab/home/woans0104/workspace/'
            'lung-seg-back-up', args.exp)
    else:
        raise ValueError('Unknown server: {}'.format(args.server))
    print(work_dir)
    os.makedirs(work_dir, exist_ok=True)

    # copy this file to work dir to keep training configuration
    shutil.copy(__file__, os.path.join(work_dir, 'main.py'))
    with open(os.path.join(work_dir, 'args.pkl'), 'wb') as f:
        pickle.dump(args, f)

    source_dataset, target_dataset1, target_dataset2 \
        = loader.dataset_condition(args.source_dataset)

    # 1.load_dataset
    train_loader_source, test_loader_source = loader.get_loader(
        server=args.server,
        dataset=source_dataset,
        train_size=args.train_size,
        aug_mode=args.aug_mode,
        aug_range=args.aug_range,
        batch_size=args.batch_size,
        work_dir=work_dir)

    train_loader_target1, _ = loader.get_loader(server=args.server,
                                                dataset=target_dataset1,
                                                train_size=1,
                                                aug_mode=False,
                                                aug_range=args.aug_range,
                                                batch_size=1,
                                                work_dir=work_dir)

    train_loader_target2, _ = loader.get_loader(server=args.server,
                                                dataset=target_dataset2,
                                                train_size=1,
                                                aug_mode=False,
                                                aug_range=args.aug_range,
                                                batch_size=1,
                                                work_dir=work_dir)

    test_data_li = [
        test_loader_source, train_loader_target1, train_loader_target2
    ]

    trn_logger = Logger(os.path.join(work_dir, 'train.log'))
    trn_raw_logger = Logger(os.path.join(work_dir, 'train_raw.log'))
    val_logger = Logger(os.path.join(work_dir, 'validation.log'))

    # 2.model_select
    #model_seg = select_model(args.arch)

    if args.arch == 'unet':
        model_seg = Unet2D(in_shape=(1, 256, 256))
    elif args.arch == 'unet_norm':
        model_seg = Unet2D_norm(in_shape=(1, 256, 256),
                                nomalize_con=args.nomalize_con,
                                affine=args.affine,
                                group_channel=args.group_channel,
                                weight_std=args.weight_std)

    else:
        raise ValueError('Unsupported network: {}'.format(args.arch))

    model_seg = model_seg.cuda()

    # 3.gpu select
    model_seg = nn.DataParallel(model_seg)
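    # let cuDNN benchmark and cache the fastest conv algorithms, which helps
    # here because the input size is fixed at 256x256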
    cudnn.benchmark = True

    # 4.optim

    if args.optim == 'adam':
        optimizer_seg = torch.optim.Adam(model_seg.parameters(),
                                         lr=args.lr,
                                         weight_decay=args.weight_decay,
                                         eps=args.eps)

    elif args.optim == 'adamp':
        optimizer_seg = AdamP(model_seg.parameters(),
                              lr=args.lr,
                              weight_decay=args.weight_decay,
                              eps=args.eps)
    elif args.optim == 'sgd':
        optimizer_seg = torch.optim.SGD(model_seg.parameters(),
                                        lr=args.lr,
                                        momentum=0.9,
                                        weight_decay=args.weight_decay)
    else:
        raise ValueError('Unsupported optimizer: {}'.format(args.optim))

    # lr decay
    lr_schedule = args.lr_schedule
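    # args.lr_schedule holds the decay milestones plus, as its last entry,
    # the total epoch count (used below as the training loop bound)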
    lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer_seg,
                                                  milestones=lr_schedule[:-1],
                                                  gamma=0.1)

    # 5.loss

    if args.loss_function == 'bce':
        criterion = nn.BCELoss()
    elif args.loss_function == 'bce_logit':
        criterion = nn.BCEWithLogitsLoss()
    elif args.loss_function == 'dice':
        criterion = DiceLoss()
    elif args.loss_function == 'Cldice':
        bce = nn.BCEWithLogitsLoss().cuda()
        dice = DiceLoss().cuda()
        criterion = ClDice(bce, dice, alpha=1, beta=1)
    else:
        raise ValueError(
            'Unsupported loss function: {}'.format(args.loss_function))

    criterion = criterion.cuda()

    ###############################################################################

    # train

    best_iou = 0
    try:
        if args.train_mode:
            for epoch in range(lr_schedule[-1]):

                train(model=model_seg,
                      train_loader=train_loader_source,
                      epoch=epoch,
                      criterion=criterion,
                      optimizer=optimizer_seg,
                      logger=trn_logger,
                      sublogger=trn_raw_logger)

                iou = validate(model=model_seg,
                               val_loader=test_loader_source,
                               epoch=epoch,
                               criterion=criterion,
                               logger=val_logger)
                print('validation_result ************************************')

                lr_scheduler.step()

                if args.val_size == 0:
                    is_best = True
                else:
                    is_best = iou > best_iou
                best_iou = max(iou, best_iou)
                save_checkpoint(
                    {
                        'epoch': epoch + 1,
                        'state_dict': model_seg.state_dict(),
                        'optimizer': optimizer_seg.state_dict()
                    },
                    is_best,
                    work_dir,
                    filename='checkpoint.pth')

        print("train end")
    except RuntimeError as e:
        print('error message: {}'.format(e))

        import ipdb
        ipdb.set_trace()

    draw_curve(work_dir, trn_logger, val_logger)
    # reload the model from the last saved .pth
    check_best_pth(work_dir)

    # test
    if args.test_mode:
        print('Test mode ...')
        main_test(model=model_seg, test_loader=test_data_li, args=args)
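`save_checkpoint` is defined elsewhere in the repository; a minimal sketch of this common PyTorch pattern (save the latest state, copy it aside when it is the best so far) might look like:

def save_checkpoint(state, is_best, work_dir, filename='checkpoint.pth'):
    # always write the most recent state
    checkpoint_path = os.path.join(work_dir, filename)
    torch.save(state, checkpoint_path)
    # keep a separate copy of the best-scoring checkpoint
    if is_best:
        shutil.copyfile(checkpoint_path,
                        os.path.join(work_dir, 'model_best.pth'))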