Esempio n. 1
0
def val(model, criterion, config, data_config, n_channel, logger, writer, global_step, device):
    """Run one validation pass with deep supervision and log to TensorBoard.

    Accumulates the summed (not averaged) total loss and the layer-0 dice
    loss over every 2D batch of every 3D batch, then logs last-batch scalars
    plus example image/mask grids to ``writer`` at ``global_step``.

    Returns:
        tuple: ``(val_loss, val_dice_loss, val_attn_loss_dict)``; the dict
        maps criterion layer keys (e.g. ``'layer_0'``) to their accumulated
        ``"epoch_attn_loss_dice"`` values.
    """
    model.eval()
    dataloader_3d = create_loader_3d(data_config, 'val')
    val_loss = 0
    val_dice_loss = 0
    val_attn_loss_dict = {}
    n_batch_3d = len(dataloader_3d)
    with torch.no_grad():
        with tqdm(total=n_batch_3d, desc="Validation execution.", unit='batch') as pbar:
            for idx_3d, batch_3d in enumerate(dataloader_3d):
                # One 2D slice loader is built per 3D volume batch.
                dataloader_2d = create_loader_2d(batch_3d, data_config, 'val')
                for idx, batch_2d in enumerate(dataloader_2d):
                    img = batch_2d['img'].to(device=device, dtype=torch.float32)  # [N, n_channel, H, W]
                    mask_gt = batch_2d['mask'].to(device=device, dtype=torch.float32)  # [N, H, W]
                    mask_flag = batch_2d['mask_flag'].to(device=device, dtype=torch.float32)
                    attention_map_out = model(img)
                    # Deep-supervised loss over all output layers; 'layer_0'
                    # is treated below as the primary output.
                    loss_attn_map, loss_dict_attn_map = criterion(pred=attention_map_out, target=mask_gt,
                                                                  target_roi_weight=mask_flag, deep_supervision=True,
                                                                  need_sigmoid=False,
                                                                  layer_weight=config['loss']['attention_loss_weight'])
                    loss = loss_attn_map
                    loss_scalar = loss.detach().item()
                    loss_dice_scalar = loss_dict_attn_map['layer_0']["dice_loss"]
                    val_loss += loss_scalar
                    val_dice_loss += loss_dice_scalar

                    # Accumulate per-layer dice losses across all batches.
                    for key, value in loss_dict_attn_map.items():
                        val_attn_loss_dict.setdefault(key, dict())
                        val_attn_loss_dict[key].setdefault("epoch_attn_loss_dice", 0)
                        val_attn_loss_dict[key]["epoch_attn_loss_dice"] += value["dice_loss"]

                    pbar.set_postfix(**{'loss (batch)': loss_scalar, 'loss_dice': loss_dice_scalar})
                # NOTE(review): loss_scalar / loss_dice_scalar here hold only
                # the LAST 2D batch's values, and are undefined (NameError) if
                # a loader is empty — confirm the val set is never empty.
                logger.info(
                    f"\tBatch: {idx_3d}/{n_batch_3d}, Loss: {loss_scalar}, Dice_loss: {loss_dice_scalar}")
                pbar.update()
            # Everything below logs tensors/scalars from the last batch only.
            writer.add_scalar('Loss_val/val', loss_scalar, global_step)
            writer.add_scalar('Loss_val/val_dice', loss_dice_scalar, global_step)
            for key, value in loss_dict_attn_map.items():
                writer.add_scalar(f'Loss_val/val_dice/attention_{key}', value["dice_loss"], global_step)

            # Middle slice of each stacked-slice input sample.
            writer.add_images('val/images', torch.unsqueeze(img[:, n_channel // 2], 1), global_step)
            # NOTE(review): masks_gt drops the last two channels UNconditionally
            # and "issue"/"air" names are always appended, while masks_pred
            # below is conditional on with_issue_air_mask — verify both are
            # intended when that flag is False.
            writer.add_images('val/masks_gt', torch.sum(mask_gt[:, :-2], dim=1, keepdim=True), global_step)
            for r_i, roi_name in enumerate((data_config['dataset']['3d']['roi_names'] + ["issue", "air"])):
                writer.add_images(f'val/masks_{roi_name}_gt', mask_gt[:, r_i:r_i + 1], global_step)
                writer.add_images(f'val/masks_{roi_name}_pred', attention_map_out[0][:, r_i:r_i + 1], global_step)
            if data_config['dataset']['3d']['with_issue_air_mask']:
                writer.add_images('val/masks_pred',
                                  torch.sum(attention_map_out[0][:, :-2], dim=1, keepdim=True), global_step)
            else:
                writer.add_images('val/masks_pred',
                                  torch.sum(attention_map_out[0], dim=1, keepdim=True), global_step)
            # Per-layer predictions from deep supervision.
            for l_i, attn_map_single in enumerate(attention_map_out):
                for r_i, roi_name in enumerate(data_config['dataset']['3d']['roi_names'] + ["issue", "air"]):
                    writer.add_images(f'val/masks_{roi_name}_pred_layer_{l_i}',
                                      attn_map_single[:, r_i:r_i + 1], global_step)
    model.train()
    return val_loss, val_dice_loss, val_attn_loss_dict
Esempio n. 2
0
def val(model, criterion, roi_names, data_config, n_channel, logger, writer, global_step, device):
    """Run one validation pass (single-output model, no deep supervision).

    Accumulates the summed (not averaged) total loss and dice loss over
    every 2D batch of every 3D batch, then logs last-batch scalars and
    example image/mask grids to ``writer`` at ``global_step``.

    Returns:
        tuple: ``(val_loss, val_dice_loss)``.
    """
    model.eval()
    dataloader_3d = create_loader_3d(data_config, 'val')
    val_loss = 0
    val_dice_loss = 0
    n_batch_3d = len(dataloader_3d)
    with torch.no_grad():
        with tqdm(total=n_batch_3d, desc="Validation execution.", unit='batch') as pbar:
            for idx_3d, batch_3d in enumerate(dataloader_3d):
                # One 2D slice loader is built per 3D volume batch.
                dataloader_2d = create_loader_2d(batch_3d, data_config, 'val')
                for idx, batch_2d in enumerate(dataloader_2d):
                    img = batch_2d['img'].to(device=device, dtype=torch.float32)  # [N, n_channel, H, W]
                    mask_gt = batch_2d['mask'].to(device=device, dtype=torch.float32)  # [N, H, W]
                    mask_flag = batch_2d['mask_flag'].to(device=device, dtype=torch.float32)
                    out = model(img)
                    loss, loss_dict = criterion(pred=out, target=mask_gt, target_roi_weight=mask_flag,
                                                need_sigmoid=False, for_val=True)
                    loss_scalar = loss.detach().item()
                    loss_dice_scalar = loss_dict["dice_loss"]
                    val_loss += loss_scalar
                    val_dice_loss += loss_dice_scalar

                    pbar.set_postfix(**{'loss (batch)': loss_scalar, 'loss_dice': loss_dice_scalar})
                # NOTE(review): loss_scalar / loss_dice_scalar hold only the
                # LAST 2D batch's values and raise NameError if a loader is
                # empty — confirm the validation set is never empty.
                logger.info(f"\tBatch: {idx_3d}/{n_batch_3d}, Loss: {loss_scalar}, Dice_loss: {loss_dice_scalar}")
                pbar.update()
            # Scalars/images below come from the last batch only.
            writer.add_scalar('Loss_val/val', loss_scalar, global_step)
            writer.add_scalar('Loss_val/val_dice', loss_dice_scalar, global_step)
            # Middle slice of each stacked-slice input sample.
            writer.add_images('val/images', torch.unsqueeze(img[:, n_channel // 2], 1), global_step)
            for r_i, roi_name in enumerate(roi_names):
                writer.add_images(f'val/masks_{roi_name}_gt', mask_gt[:, r_i:r_i + 1], global_step)
                writer.add_images(f'val/masks_{roi_name}_pred', out[0][:, r_i:r_i + 1], global_step)
            # Combined-mask summaries: exclude the trailing issue/air channels,
            # or the trailing background channel, depending on dataset config.
            if data_config['dataset']['3d']['with_issue_air_mask']:
                writer.add_images('val/masks_gt', torch.sum(mask_gt[:, :-2], dim=1, keepdim=True),
                                  global_step)
                writer.add_images('val/masks_pred',
                                  torch.sum(out[0][:, :-2], dim=1, keepdim=True),
                                  global_step)
            elif data_config['dataset']['3d']['with_background']:
                writer.add_images('val/masks_gt', torch.sum(mask_gt[:, :-1], dim=1, keepdim=True), global_step)
                writer.add_images('val/masks_pred',
                                  torch.sum(out[0][:, :-1], dim=1, keepdim=True), global_step)
            else:
                writer.add_images('val/masks_gt', torch.sum(mask_gt, dim=1, keepdim=True), global_step)
                writer.add_images('val/masks_pred',
                                  torch.sum(out[0], dim=1, keepdim=True), global_step)

    model.train()
    return val_loss, val_dice_loss
Esempio n. 3
0
def pred_with_model(model, ckpt_dir, ckpt_fn, pred_save_dir, config,
                    data_config, roi_names, device):
    """Load a checkpoint and write per-ROI binary mask predictions as NRRD.

    Each 3D batch is asserted to contain exactly one patient. Its slices are
    predicted in 2D, thresholded at 0.5, re-assembled into a boolean volume
    of shape (N_roi, D, H, W), and saved as ``{pid}_{roi}.nrrd`` under
    ``pred_save_dir``.

    Fixes over the previous revision:
      * ``roi_names`` was extended with ["issue", "air"] INSIDE the batch
        loop, growing once per 3D batch and mutating the caller's list; it is
        now extended once, on a copy.
      * predictions are moved to CPU before ``.numpy()`` so CUDA runs work.
      * the directory created now matches ``pred_save_dir`` (previously
        ``config['pred_save_dir']`` was created while files were written to
        the ``pred_save_dir`` argument).
    """
    model = load_checkpoint_model(model, ckpt_dir, ckpt_fn, device)

    # Extend the ROI list exactly once, on a copy, so the caller's list and
    # any shared config list stay untouched.
    roi_names = list(roi_names)
    if data_config['dataset']['3d']['with_issue_air_mask']:
        roi_names += ["issue", "air"]
    os.makedirs(pred_save_dir, exist_ok=True)

    ###################################################################################
    # Dataset
    ###################################################################################
    dataloader_3d = create_loader_3d(data_config, 'pred')
    n_batch_3d = len(dataloader_3d)
    with torch.no_grad():
        with tqdm(total=n_batch_3d, desc="Predicting", unit='batch') as pbar:
            for batch_3d in dataloader_3d:
                assert len(batch_3d) == 1, 'len(batch_3d) in pred.py must be set to 1.'
                pid = batch_3d[0]['pid']
                # Boolean volume collecting thresholded slices -> (N_roi, D, H, W)
                mask_pred_all = torch.zeros_like(
                    batch_3d[0]['mask'], dtype=torch.bool).to(device)
                dataloader_2d = create_loader_2d(batch_3d, data_config, 'pred')
                for idx, batch_2d in enumerate(dataloader_2d):
                    img = batch_2d['img'].to(
                        device=device,
                        dtype=torch.float32)  # [N, n_channel, H, W]
                    target_slice = batch_2d['target_slice']  # shape -> [N]
                    attention_map_out = model(
                        img)  # shape -> (N, C, H, W), C is N_roi
                    # Primary output is layer 0; move ROI channels first.
                    mask_pred = attention_map_out[0]
                    mask_pred = torch.transpose(mask_pred, 0,
                                                1)  # shape -> (C, N, H, W)
                    # Threshold at 0.5 and place slices into the volume.
                    mask_pred_all[:, target_slice] = mask_pred.gt(0.5)

                for i, roi in enumerate(roi_names):
                    nrrd.write(
                        os.path.join(pred_save_dir, f'{pid}_{roi}.nrrd'),
                        mask_pred_all[i].cpu().numpy().astype(np.uint8))
                pbar.update()
Esempio n. 4
0
def main():
    """Entry point: train SegTransformer with deep supervision.

    Parses ``--config``, builds the model/criterion/optimizer/scheduler,
    optionally restores a checkpoint (full, encoder-only, or decoder-only),
    then runs the epoch loop with periodic TensorBoard logging, periodic
    checkpointing, and periodic validation; the checkpoint with the lowest
    validation dice loss is kept as ``best_ckpt_*``.

    Fixes over the previous revision:
      * ``roi_names`` is a copy of the config list — the old in-place ``+=``
        mutated the list inside ``data_config``, leaking 'issue_mask'/
        'air_mask'/'bg' into every later reader of
        ``data_config['dataset']['3d']['roi_names']``.
      * the ``else`` branch of the deep-supervision log said "using deep
        supervision" in both branches; it now reports "without".
    """
    parser = argparse.ArgumentParser(description='SegTransformer training')
    parser.add_argument('--config', type=str, required=True)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = load_config_yaml(args.config)
    data_config = load_config_yaml(config['data_config'])

    now = datetime.datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M")
    os.makedirs(config['logging_dir'], exist_ok=True)
    logging_path = os.path.join(config['logging_dir'],
                                f'logging_train_{date_time}.txt')
    logger = create_logger(logging_path, stdout=False)

    ###################################################################################
    # construct net
    ###################################################################################
    n_channel = data_config['dataset']['2d']['n_slice']
    # Copy so extending roi_names below does not mutate data_config's list.
    roi_names = list(data_config['dataset']['3d']['roi_names'])
    n_class = len(roi_names)
    if data_config['dataset']['3d']['with_issue_air_mask']:
        n_class += 2
        roi_names += ['issue_mask', 'air_mask']
    if data_config['dataset']['3d']['with_background']:
        n_class += 1
        roi_names += ['bg']
    start_channel = int(config['start_channel'])
    logger.info(
        f'create model with n_channel={n_channel}, start_channel={start_channel}, n_class={n_class}'
    )

    model = SegTransformer(n_channel=n_channel,
                           start_channel=start_channel,
                           n_class=n_class,
                           input_dim=(512, 512),
                           nhead=4,
                           normalization='bn',
                           activation='relu',
                           num_groups=None,
                           with_background=data_config['dataset']['3d']
                           ['with_background']).to(device)

    logger.info(f"model_dir: {config['ckpt_dir']}")

    ###################################################################################
    # criterion, optimizer, scheduler
    ###################################################################################
    criterion = Criterion(config)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config['lr'],
                                  weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=config['step_size'])
    if config['deep_supervision']:
        logger.info('Train model using deep supervision')
    else:
        # Fixed: previously logged the same "using deep supervision" message.
        logger.info('Train model without deep supervision')

    ###################################################################################
    # SummaryWriter
    ###################################################################################
    logger.info("Creating writer")
    # NOTE(review): the "BS_" tag carries n_epoch, not a batch size — string
    # kept as-is for run-directory name compatibility.
    writer = SummaryWriter(comment=f"LR_{config['lr']}_BS_{config['n_epoch']}")

    ###################################################################################
    # train setup
    ###################################################################################
    global_step = 0
    best_loss = np.inf
    epoch_start = 0

    ###################################################################################
    # load previous model
    ###################################################################################
    if config['load_checkpoint']:
        logger.info(
            f'Loading model from {os.path.join(config["ckpt_dir"], config["ckpt_fn"])}...'
        )
        model, optimizer, scheduler, epoch_start, global_step = load_checkpoint(
            model, optimizer, scheduler, config['ckpt_dir'], config['ckpt_fn'],
            device)
    elif config['load_checkpoint_encoder']:
        logger.info(
            f'Loading encoder from {os.path.join(config["ckpt_dir"], config["ckpt_fn"])}...'
        )
        model.encoder = load_checkpoint_encoder(model.encoder,
                                                ckpt_dir=config['ckpt_dir'],
                                                ckpt_fn=config['ckpt_fn'],
                                                device=device)
        if config['freeze_encoder']:
            logger.info('Freeze encoder')
            freeze(model.encoder)
    elif config['load_checkpoint_decoder']:
        logger.info(
            f'Loading decoder from {os.path.join(config["ckpt_dir"], config["ckpt_fn"])}...'
        )
        model.decoder = load_checkpoint_decoder(model.decoder,
                                                ckpt_dir=config['ckpt_dir'],
                                                ckpt_fn=config['ckpt_fn'],
                                                device=device)
        if config['freeze_decoder']:
            logger.info('Freeze decoder')
            freeze(model.decoder)

    ###################################################################################
    # parallel model and data
    ###################################################################################
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = torch.nn.DataParallel(model)

    ###################################################################################
    # Dataset
    ###################################################################################
    dataloader_3d = create_loader_3d(data_config, 'train')
    ###################################################################################
    # train
    ###################################################################################
    logger.info(f'Starting training from epoch: {epoch_start}')
    for epoch in range(epoch_start, config['n_epoch']):
        logger.info(f"Epoch: {epoch}/{config['n_epoch']}")
        epoch_loss = 0
        epoch_loss_dice = 0
        epoch_attn_loss_dict = {}
        n_batch_3d = len(dataloader_3d)
        with tqdm(total=n_batch_3d,
                  desc=f"Epoch {epoch + 1}/{config['n_epoch']}",
                  unit='batch') as pbar:
            for batch_3d in dataloader_3d:
                # A 2D slice loader is built per 3D volume batch.
                dataloader_2d = create_loader_2d(batch_3d, data_config,
                                                 'train')
                n_batch_2d = len(dataloader_2d)
                for idx, batch_2d in enumerate(dataloader_2d):
                    img = batch_2d['img'].to(
                        device=device,
                        dtype=torch.float32)  # [N, n_channel, H, W]
                    mask_gt = batch_2d['mask'].to(
                        device=device, dtype=torch.float32)  # [N, H, W]
                    attention_map_out = model(img)
                    mask_flag = batch_2d['mask_flag'].to(device=device,
                                                         dtype=torch.float32)

                    # Deep-supervised loss over all output layers; 'layer_0'
                    # is the primary output.
                    loss_attn_map, loss_dict_attn_map = criterion(
                        pred=attention_map_out,
                        target=mask_gt,
                        target_roi_weight=mask_flag,
                        deep_supervision=True,
                        need_sigmoid=False,
                        layer_weight=config['loss']['attention_loss_weight'])
                    optimizer.zero_grad()
                    loss = loss_attn_map
                    loss.backward()
                    torch.nn.utils.clip_grad_value_(model.parameters(), 0.01)
                    optimizer.step()

                    global_step += 1
                    loss_scalar = loss.detach().item()
                    loss_dice_scalar = loss_dict_attn_map['layer_0'][
                        'dice_loss']
                    epoch_loss += loss_scalar
                    epoch_loss_dice += loss_dice_scalar

                    # Accumulate per-layer dice losses for the epoch summary.
                    for key, value in loss_dict_attn_map.items():
                        epoch_attn_loss_dict.setdefault(key, dict())
                        epoch_attn_loss_dict[key].setdefault(
                            "epoch_attn_loss_dice", 0)
                        epoch_attn_loss_dict[key][
                            "epoch_attn_loss_dice"] += value["dice_loss"]

                    postfix_dict = {
                        'loss (batch)': loss_scalar,
                        'loss_dice': loss_dice_scalar,
                        'global_step': global_step
                    }
                    pbar.set_postfix(**postfix_dict)
                    # Periodic scalar logging.
                    if (global_step + 1) % (
                            config['write_summary_loss_batch_step']) == 0:
                        postfix_dict.update({
                            f'loss_attention_{key}': value["dice_loss"]
                            for key, value in loss_dict_attn_map.items()
                        })
                        print(postfix_dict)
                        logger.info(
                            f"\tBatch: {idx}/{n_batch_2d}, Loss: {loss_scalar}, Dice_loss: {loss_dice_scalar}"
                        )
                        writer.add_scalar('Loss_train/train', loss_scalar,
                                          global_step)
                        writer.add_scalar('Loss_train/train_dice',
                                          loss_dice_scalar, global_step)

                        for key, value in loss_dict_attn_map.items():
                            writer.add_scalar(
                                f'Loss_train/train_dice/attention_{key}',
                                value["dice_loss"], global_step)

                    # Periodic image logging (current batch only).
                    if (global_step +
                            1) % (config['write_summary_2d_batch_step']) == 0:
                        # Middle slice of each stacked-slice input sample.
                        writer.add_images(
                            'train/images',
                            torch.unsqueeze(img[:, n_channel // 2], 1),
                            global_step)

                        for r_i, roi_name in enumerate(roi_names):
                            writer.add_images(f'train/masks_{roi_name}_gt',
                                              mask_gt[:, r_i:r_i + 1],
                                              global_step)
                            writer.add_images(
                                f'train/masks_{roi_name}_pred',
                                attention_map_out[0][:,
                                                     r_i:r_i + 1], global_step)
                        # Combined-mask summaries: skip trailing issue/air or
                        # background channels depending on dataset config.
                        if data_config['dataset']['3d']['with_issue_air_mask']:
                            writer.add_images(
                                'train/masks_gt',
                                torch.sum(mask_gt[:, :-2], dim=1,
                                          keepdim=True), global_step)
                            writer.add_images(
                                'train/masks_pred',
                                torch.sum(attention_map_out[0][:, :-2],
                                          dim=1,
                                          keepdim=True), global_step)
                        elif data_config['dataset']['3d']['with_background']:
                            writer.add_images(
                                'train/masks_gt',
                                torch.sum(mask_gt[:, :-1], dim=1,
                                          keepdim=True), global_step)
                            writer.add_images(
                                'train/masks_pred',
                                torch.sum(attention_map_out[0][:, :-1],
                                          dim=1,
                                          keepdim=True), global_step)
                        else:
                            writer.add_images(
                                'train/masks_gt',
                                torch.sum(mask_gt, dim=1, keepdim=True),
                                global_step)
                            writer.add_images(
                                'train/masks_pred',
                                torch.sum(attention_map_out[0],
                                          dim=1,
                                          keepdim=True), global_step)
                        # Per-layer predictions from deep supervision.
                        for l_i, attn_map_single in enumerate(
                                attention_map_out):
                            for r_i, roi_name in enumerate(roi_names):
                                writer.add_images(
                                    f'train/masks_{roi_name}_pred_layer_{l_i}',
                                    attn_map_single[:,
                                                    r_i:r_i + 1], global_step)
                pbar.update()

            scheduler.step()
            # log epoch loss
            if (epoch + 1) % config['logging_epoch_step'] == 0:
                writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
                writer.add_scalar('Loss_epoch_train/train', epoch_loss, epoch)
                writer.add_scalar('Loss_epoch_train/train_dice',
                                  epoch_loss_dice, epoch)
                for key, value in epoch_attn_loss_dict.items():
                    writer.add_scalar(
                        f'Loss_epoch_train/train_dice/attention_{key}',
                        value["epoch_attn_loss_dice"], epoch)
                logger.info(
                    f"Epoch: {epoch}/{config['n_epoch']}, Train Loss: {epoch_loss}, Train Loss DSC: {epoch_loss_dice}"
                )

                os.makedirs(config['ckpt_dir'], exist_ok=True)
                save_checkpoint(model=model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                epoch=epoch,
                                global_step=global_step,
                                ckpt_dir=config['ckpt_dir'],
                                ckpt_fn=f'ckpt_{date_time}_Epoch_{epoch}.ckpt')

            # validation and save model
            if (epoch + 1) % config['val_model_epoch_step'] == 0:
                val_loss, val_dice_loss, val_attn_loss_dict = val(
                    model, criterion, config, roi_names, data_config,
                    n_channel, logger, writer, global_step, device)
                writer.add_scalar('Loss_epoch_val/val', val_loss, epoch)
                writer.add_scalar('Loss_epoch_val/val_dice', val_dice_loss,
                                  epoch)
                for key, value in val_attn_loss_dict.items():
                    writer.add_scalar(
                        f'Loss_epoch_val/val_dice/attention_{key}',
                        value["epoch_attn_loss_dice"], epoch)

                logger.info(
                    f"Epoch: {epoch}/{config['n_epoch']}, Validation Loss: {val_loss}, Validation Loss Dice: {val_dice_loss}"
                )

                # Keep only the single best checkpoint by validation dice.
                if best_loss > val_dice_loss:
                    best_loss = val_dice_loss
                    for filename in glob.glob(
                            os.path.join(config['ckpt_dir'], "best_ckpt*")):
                        os.remove(filename)
                    save_checkpoint(
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epoch=epoch,
                        global_step=global_step,
                        ckpt_dir=config['ckpt_dir'],
                        ckpt_fn=f'best_ckpt_{date_time}_epoch_{epoch}.ckpt')

        # NOTE(review): model.module assumes DataParallel wrapping; on a
        # single-GPU/CPU run this raises AttributeError — confirm unfreeze
        # epochs are only configured for multi-GPU runs.
        if config['freeze_encoder'] and config[
                'unfreeze_encoder_epoch'] is not None:
            if epoch >= int(config['unfreeze_encoder_epoch']):
                unfreeze(model.module.encoder)
                config['unfreeze_encoder_epoch'] = None
                logger.info(f'Unfreeze encoder at {epoch}')
        if config['freeze_decoder'] and config[
                'unfreeze_decoder_epoch'] is not None:
            if epoch >= int(config['unfreeze_decoder_epoch']):
                unfreeze(model.module.decoder)
                config['unfreeze_decoder_epoch'] = None
                logger.info(f'Unfreeze decoder at {epoch}')
    writer.close()
Esempio n. 5
0
def main():
    parser = argparse.ArgumentParser(description='SegTransformer training')
    parser.add_argument('--config', type=str, required=True)
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    config = load_config_yaml(args.config)
    data_config = load_config_yaml(config['data_config'])

    now = datetime.datetime.now()
    date_time = now.strftime("%Y-%m-%d-%H-%M")
    os.makedirs(config['logging_dir'], exist_ok=True)
    logging_path = os.path.join(config['logging_dir'],
                                f'logging_train_{date_time}.txt')
    logger = create_logger(logging_path, stdout=False)

    ###################################################################################
    # construct net
    ###################################################################################
    n_channel = data_config['dataset']['2d']['n_slice']
    n_class = len(data_config['dataset']['3d']['roi_names'])
    if data_config['dataset']['3d']['with_issue_air_mask']:
        n_class += 2
    start_channel = int(config['start_channel'])
    logger.info(
        f'create model with n_channel={n_channel}, start_channel={start_channel}, n_class={n_class}'
    )

    model = SegTransformer(
        n_channel=n_channel,
        start_channel=start_channel,
        n_class=n_class,
        deep_supervision=config["deep_supervision"]).to(device)

    logger.info(f"model_dir: {config['ckpt_dir']}")

    ###################################################################################
    # criterion, optimizer, scheduler
    ###################################################################################
    criterion = Criterion(config)
    optimizer = torch.optim.AdamW(model.parameters(),
                                  lr=config['lr'],
                                  weight_decay=config['weight_decay'])
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=config['step_size'])
    if config['deep_supervision']:
        logger.info('Train model using deep supervision')
    else:
        logger.info('Train model using deep supervision')

    ###################################################################################
    # SummaryWriter
    ###################################################################################
    logger.info("Creating writer")
    writer = SummaryWriter(comment=f"LR_{config['lr']}_BS_{config['n_epoch']}")

    ###################################################################################
    # train setup
    ###################################################################################
    global_step = 0
    best_loss = np.inf
    epoch_start = 0

    ###################################################################################
    # load previous model
    ###################################################################################
    if config['load_checkpoint']:
        logger.info(
            f'Loading model from {os.path.join(config["ckpt_dir"], config["ckpt_fn"])}...'
        )
        model, optimizer, scheduler, epoch_start, global_step = load_checkpoint(
            model, optimizer, scheduler, config['ckpt_dir'], config['ckpt_fn'],
            device)
    elif config['load_checkpoint_encoder']:
        logger.info(
            f'Loading encoder from {os.path.join(config["ckpt_dir"], config["ckpt_fn"])}...'
        )
        model.encoder = load_checkpoint_encoder(model.encoder,
                                                ckpt_dir=config['ckpt_dir'],
                                                ckpt_fn=config['ckpt_fn'],
                                                device=device)
        if config['freeze_encoder']:
            logger.info('Freeze encoder')
            freeze(model.encoder)
    elif config['load_checkpoint_decoder']:
        logger.info(
            f'Loading decoder from {os.path.join(config["ckpt_dir"], config["ckpt_fn"])}...'
        )
        model.decoder = load_checkpoint_decoder(model.decoder,
                                                ckpt_dir=config['ckpt_dir'],
                                                ckpt_fn=config['ckpt_fn'],
                                                device=device)
        if config['freeze_decoder']:
            logger.info('Freeze decoder')
            freeze(model.decoder)

    ###################################################################################
    # parallel model and data
    ###################################################################################
    if torch.cuda.device_count() > 1:
        print("Let's use", torch.cuda.device_count(), "GPUs!")
        # dim = 0 [30, xxx] -> [10, ...], [10, ...], [10, ...] on 3 GPUs
        model = torch.nn.DataParallel(model)

    ###################################################################################
    # Dataset
    ###################################################################################
    dataloader_3d = create_loader_3d(data_config, 'train')
    ###################################################################################
    # train
    ###################################################################################
    logger.info(f'Starting training from epoch: {epoch_start}')
    for epoch in range(epoch_start, config['n_epoch']):
        logger.info(f"Epoch: {epoch}/{config['n_epoch']}")
        epoch_loss = 0
        epoch_loss_focal = 0
        epoch_loss_dice = 0
        n_batch_3d = len(dataloader_3d)
        with tqdm(total=n_batch_3d,
                  desc=f"Epoch {epoch + 1}/{config['n_epoch']}",
                  unit='batch') as pbar:
            for batch_3d in dataloader_3d:
                dataloader_2d = create_loader_2d(batch_3d, data_config,
                                                 'train')
                n_batch_2d = len(dataloader_2d)
                for idx, batch_2d in enumerate(dataloader_2d):
                    img = batch_2d['img'].to(
                        device=device,
                        dtype=torch.float32)  # [N, n_channel, H, W]
                    mask_gt = batch_2d['mask'].to(
                        device=device, dtype=torch.float32)  # [N, H, W]
                    mask_pred = model(img)
                    mask_flag = batch_2d['mask_flag'].to(device=device,
                                                         dtype=torch.float32)

                    loss, loss_dict = criterion(pred=mask_pred,
                                                target=mask_gt,
                                                target_roi_weight=mask_flag)
                    optimizer.zero_grad()
                    loss.backward()
                    torch.nn.utils.clip_grad_value_(model.parameters(), 0.01)
                    optimizer.step()

                    global_step += 1
                    loss_scalar = loss_dict["loss"]
                    loss_focal_scalar = loss_dict["focal_loss"]
                    loss_dice_scalar = loss_dict["dice_loss"]
                    epoch_loss += loss_scalar
                    epoch_loss_focal += loss_focal_scalar
                    epoch_loss_dice += loss_dice_scalar

                    pbar.set_postfix(
                        **{
                            'loss (batch)': loss_scalar,
                            'loss_focal': loss_focal_scalar,
                            'loss_dice': loss_dice_scalar,
                            'global_step': global_step
                        })

                    if (global_step + 1) % (
                            config['write_summary_loss_batch_step']) == 0:
                        logger.info(
                            f"\tBatch: {idx}/{n_batch_2d}, Loss: {loss_scalar}, Focal_loss: {loss_focal_scalar}, Dice_loss: {loss_dice_scalar}"
                        )
                        writer.add_scalar('Loss/train', loss_scalar,
                                          global_step)
                        writer.add_scalar('Loss/train_focal',
                                          loss_focal_scalar, global_step)
                        writer.add_scalar('Loss/train_dice', loss_dice_scalar,
                                          global_step)
                    if (global_step +
                            1) % (config['write_summary_2d_batch_step']) == 0:
                        writer.add_images(
                            'train/images',
                            torch.unsqueeze(img[:, n_channel // 2], 1),
                            global_step)
                        writer.add_images(
                            'train/gt_masks',
                            torch.sum(mask_gt, dim=1, keepdim=True),
                            global_step)
                        writer.add_images(
                            'train/pred_masks',
                            torch.sum(mask_pred[0] > 0, dim=1, keepdim=True) >=
                            1, global_step)
                        writer.add_images(
                            'train/pred_masks_raw',
                            torch.sum(mask_pred[0], dim=1, keepdim=True),
                            global_step)
                pbar.update()

            scheduler.step()
            # log epoch loss
            if (epoch + 1) % config['logging_epoch_step'] == 0:
                writer.add_scalar('lr', optimizer.param_groups[0]['lr'], epoch)
                writer.add_scalar('Loss_epoch/train', epoch_loss, epoch)
                writer.add_scalar('Loss_epoch/train_focal', epoch_loss_focal,
                                  epoch)
                writer.add_scalar('Loss_epoch/train_dice', epoch_loss_dice,
                                  epoch)
                logger.info(
                    f"Epoch: {epoch}/{config['n_epoch']}, Train Loss: {epoch_loss}, Train Loss BCE: {epoch_loss_focal}, Train Loss DSC: {epoch_loss_dice}"
                )

            # validation and save model
            if (epoch + 1) % config['val_model_epoch_step'] == 0:
                val_loss, val_focal_loss, val_dice_loss = val(
                    model, criterion, data_config, n_channel, logger, writer,
                    global_step, device)
                writer.add_scalar('Loss_epoch/val', val_loss, epoch)
                writer.add_scalar('Loss_epoch/val_focal', val_focal_loss,
                                  epoch)
                writer.add_scalar('Loss_epoch/val_dice', val_dice_loss, epoch)
                logger.info(
                    f"Epoch: {epoch}/{config['n_epoch']}, Validation Loss: {val_loss}, Validation Loss Focal: {val_focal_loss}, Validation Loss Dice: {val_dice_loss}"
                )

                os.makedirs(config['ckpt_dir'], exist_ok=True)
                save_checkpoint(model=model,
                                optimizer=optimizer,
                                scheduler=scheduler,
                                epoch=epoch,
                                global_step=global_step,
                                ckpt_dir=config['ckpt_dir'],
                                ckpt_fn=f'ckpt_{date_time}_Epoch_{epoch}.ckpt')

                if best_loss > val_loss:
                    best_loss = val_loss
                    for filename in glob.glob(
                            os.path.join(config['ckpt_dir'], "best_ckpt*")):
                        os.remove(filename)
                    save_checkpoint(
                        model=model,
                        optimizer=optimizer,
                        scheduler=scheduler,
                        epoch=epoch,
                        global_step=global_step,
                        ckpt_dir=config['ckpt_dir'],
                        ckpt_fn=f'best_ckpt_{date_time}_epoch_{epoch}.ckpt')

        if config['freeze_encoder'] and config[
                'unfreeze_encoder_epoch'] is not None:
            if epoch >= int(config['unfreeze_encoder_epoch']):
                unfreeze(model.module.encoder)
                config['unfreeze_encoder_epoch'] = None
                logger.info(f'Unfreeze encoder at {epoch}')
        if config['freeze_decoder'] and config[
                'unfreeze_decoder_epoch'] is not None:
            if epoch >= int(config['unfreeze_decoder_epoch']):
                unfreeze(model.module.decoder)
                config['unfreeze_decoder_epoch'] = None
                logger.info(f'Unfreeze decoder at {epoch}')
    writer.close()
# Esempio n. 6
# 0
def val(model, criterion, data_config, n_channel, logger, writer, global_step,
        device):
    """Run one validation pass and return the summed validation losses.

    Iterates the 3D validation loader, builds a 2D slice loader for each 3D
    batch, and accumulates the total / focal / dice losses over every 2D
    batch. The scalar losses and sample images of the last processed 2D
    batch are written to TensorBoard under the 'val/*' tags.

    NOTE: the returned losses are sums over all 2D batches, not averages;
    comparing them across epochs is consistent only while the validation
    set is fixed.

    Args:
        model: network under evaluation; switched to eval() mode here and
            restored to train() mode before returning.
        criterion: callable returning (loss_tensor, loss_dict); loss_dict
            must contain scalar entries 'loss', 'focal_loss', 'dice_loss'.
        data_config: configuration forwarded to the loader factories.
        n_channel: number of input channels; the middle channel is logged
            as the representative image.
        logger: application logger.
        writer: TensorBoard SummaryWriter.
        global_step: step index used for the TensorBoard scalars/images.
        device: torch device the input tensors are moved to.

    Returns:
        Tuple (val_loss, val_focal_loss, val_dice_loss) of summed losses.
    """
    model.eval()
    dataloader_3d = create_loader_3d(data_config, 'val')
    val_loss = 0
    val_focal_loss = 0
    val_dice_loss = 0
    n_batch_3d = len(dataloader_3d)
    # Capture the tensors/scalars of the last processed 2D batch so the
    # trailing TensorBoard logging cannot raise NameError when a loader
    # yields no batches (previously the loop variables were referenced
    # directly after the loops and were unbound for empty loaders).
    last_batch = None
    with torch.no_grad():
        with tqdm(total=n_batch_3d, desc="Validation execution.",
                  unit='batch') as pbar:
            for idx_3d, batch_3d in enumerate(dataloader_3d):
                dataloader_2d = create_loader_2d(batch_3d, data_config, 'val')
                for idx, batch_2d in enumerate(dataloader_2d):
                    img = batch_2d['img'].to(
                        device=device,
                        dtype=torch.float32)  # [N, n_channel, H, W]
                    mask_gt = batch_2d['mask'].to(
                        device=device, dtype=torch.float32)  # [N, H, W]
                    mask_flag = batch_2d['mask_flag'].to(device=device,
                                                         dtype=torch.float32)
                    mask_pred = model(img)
                    loss, loss_dict = criterion(pred=mask_pred,
                                                target=mask_gt,
                                                target_roi_weight=mask_flag)
                    loss_scalar = loss.item()
                    loss_focal_scalar = loss_dict["focal_loss"]
                    loss_dice_scalar = loss_dict["dice_loss"]
                    val_loss += loss_dict["loss"]
                    val_focal_loss += loss_focal_scalar
                    val_dice_loss += loss_dice_scalar
                    last_batch = (img, mask_gt, mask_pred, loss_scalar,
                                  loss_focal_scalar, loss_dice_scalar)

                    pbar.set_postfix(
                        **{
                            'loss (batch)': loss_scalar,
                            'loss_focal': loss_focal_scalar,
                            'loss_dice': loss_dice_scalar
                        })
                # Log the losses of the most recent 2D batch; skip if no
                # batch has been processed yet.
                if last_batch is not None:
                    logger.info(
                        f"\tBatch: {idx_3d}/{n_batch_3d}, Loss: {last_batch[3]}, Focal_loss: {last_batch[4]}, Dice_loss: {last_batch[5]}"
                    )
                pbar.update()
            if last_batch is not None:
                (img, mask_gt, mask_pred, loss_scalar, loss_focal_scalar,
                 loss_dice_scalar) = last_batch
                writer.add_scalar('Loss/val', loss_scalar, global_step)
                writer.add_scalar('Loss/val_focal', loss_focal_scalar,
                                  global_step)
                writer.add_scalar('Loss/val_dice', loss_dice_scalar,
                                  global_step)
                # Middle input channel as the representative image.
                writer.add_images('val/images',
                                  torch.unsqueeze(img[:, n_channel // 2], 1),
                                  global_step)
                writer.add_images('val/gt_masks',
                                  torch.sum(mask_gt, dim=1, keepdim=True),
                                  global_step)
                # Binarized prediction: any channel positive -> foreground.
                writer.add_images(
                    'val/pred_masks',
                    torch.sum(mask_pred[0] > 0, dim=1, keepdim=True) >= 1,
                    global_step)
                writer.add_images('val/pred_masks_raw',
                                  torch.sum(mask_pred[0], dim=1, keepdim=True),
                                  global_step)
    # Restore training mode for the caller's next training epoch.
    model.train()
    return val_loss, val_focal_loss, val_dice_loss