예제 #1
0
    def __init__(self,
                 input_modalites,
                 output_channels,
                 base_channel,
                 is_pretrain=False):
        """Set up the decoder, optionally borrowing layers from a trained U-Net.

        The three sizing arguments are passed straight to the parent
        constructor. When ``is_pretrain`` is true, a U-Net is rebuilt with the
        same sizes, its weights are restored from a fixed checkpoint path, and
        its up-sampling layers, up-convolutions and output layer are attached
        to this instance. All borrowed parameters are frozen except those of
        the output layer, whose weight is re-initialised.
        """
        super(CenterLSTMDecoder, self).__init__(input_modalites,
                                                output_channels, base_channel)

        if is_pretrain:
            unet = init_U_Net(input_modalites,
                              output_channels,
                              base_channel,
                              softmax=False)
            pretrained_file = 'best_newdata/UNet-p64-b4-newdata-oriinput_best_model.pth.tar'
            unet = WrappedModel(unet)
            saved = torch.load(pretrained_file,
                               map_location=torch.device('cpu'))

            # Freeze every backbone parameter, then restore the stored weights.
            for weight in unet.parameters():
                weight.requires_grad = False
            unet.load_state_dict(saved['model_state_dict'])

            # Attach the decoder-side layers in the same order as before.
            donor = unet.module
            for layer_name in ('up_sample_1', 'up_sample_2', 'up_sample_3',
                               'up_conv1', 'up_conv2', 'up_conv3', 'out'):
                setattr(self, layer_name, getattr(donor, layer_name))

            # Only the output layer stays trainable; its weight starts fresh.
            for weight in self.out.parameters():
                weight.requires_grad = True
            nn.init.kaiming_normal_(self.out.weight,
                                    mode='fan_out',
                                    nonlinearity='leaky_relu')
예제 #2
0
    def __init__(self,
                 input_modalites,
                 output_channels,
                 base_channel,
                 is_pretrain=True):
        """Set up the encoder, optionally borrowing layers from a trained U-Net.

        The three sizing arguments are passed straight to the parent
        constructor. When ``is_pretrain`` is true (the default), a U-Net is
        rebuilt with the same sizes, its weights are restored from a fixed
        checkpoint path, and its down-convolutions and down-sampling layers
        are attached to this instance with all of their parameters frozen.
        """
        super(CenterLSTMEncoder, self).__init__(input_modalites,
                                                output_channels, base_channel)

        if is_pretrain:
            unet = init_U_Net(input_modalites,
                              output_channels,
                              base_channel,
                              softmax=False)
            pretrained_file = 'best_newdata/UNet-p64-b4-newdata-oriinput_best_model.pth.tar'
            unet = WrappedModel(unet)
            saved = torch.load(pretrained_file,
                               map_location=torch.device('cpu'))

            # Freeze every backbone parameter, then restore the stored weights.
            for weight in unet.parameters():
                weight.requires_grad = False
            unet.load_state_dict(saved['model_state_dict'])

            # Attach the encoder-side layers in the same order as before.
            donor = unet.module
            for layer_name in ('down_conv1', 'down_conv2', 'down_conv3',
                               'down_sample_1', 'down_sample_2',
                               'down_sample_3'):
                setattr(self, layer_name, getattr(donor, layer_name))
예제 #3
0
def predict_use(args):
    """Restore the best checkpoint of a model and run prediction with it.

    Sizes and paths come from ``config.yaml``; the checkpoint file is
    ``<root>/<best_model_path>/<model_name>_best_model.pth.tar``.

    Args:
        args: object exposing ``model_name`` and ``patient_path`` attributes.
    """
    model_name = args.model_name
    patient_path = args.patient_path

    settings = load_config('config.yaml')

    hyper = settings['PARAMETERS']
    input_modalites = int(hyper['input_modalites'])
    output_channels = int(hyper['output_channels'])
    base_channels = int(hyper['base_channels'])
    patience = int(hyper['patience'])

    locations = settings['PATH']
    ROOT = locations['root']
    best_model_dir = os.path.join(ROOT, locations['best_model_path'])

    if torch.cuda.is_available():
        device = 'cuda:0'
    else:
        device = 'cpu'

    # Build an empty network on the target device before restoring weights.
    net = init_U_Net(input_modalites, output_channels, base_channels)
    net.to(device)

    # load_ckp takes an optimizer and a scheduler as arguments; everything it
    # returns except the first element (the network) is discarded here.
    optimizer = optim.Adam(net.parameters())
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)
    checkpoint_file = os.path.join(best_model_dir,
                                   model_name + '_best_model.pth.tar')
    restored = load_ckp(checkpoint_file, net, optimizer, scheduler)
    net, _, _, _, _, _ = restored

    # Run prediction for the requested patient and save the mask.
    predict(net, model_name, patient_path, ROOT, save_mask=True)
예제 #4
0
    def __init__(self,
                 input_modalites,
                 output_channels,
                 base_channel,
                 num_layers,
                 num_connects,
                 pad_method='pad',
                 conv_type='plain',
                 softmax=True,
                 is_pretrain=True):
        """Set up the shortcut-LSTM body on top of a (pretrained) U-Net.

        ``input_modalites``, ``output_channels``, ``base_channel``,
        ``pad_method`` and ``softmax`` are forwarded to the parent constructor;
        all constructor arguments except ``is_pretrain`` are also stored on
        the instance. When ``is_pretrain`` is true (the default), a U-Net is
        rebuilt with ``softmax=False``, restored from a fixed checkpoint path,
        frozen, and its encoder, bridge and decoder layers are attached to this
        instance (the output layer is not). In every case each ``up_conv``
        layer is finally replaced by an ``nn.Sequential`` holding only the
        first three entries of its ``block``.
        """
        super(ShortcutLSTMBody,
              self).__init__(input_modalites, output_channels, base_channel,
                             pad_method, softmax)

        # Arguments also given to the parent constructor.
        self.input_modalites = input_modalites
        self.output_channels = output_channels
        self.base_channel = base_channel
        self.pad_method = pad_method
        self.softmax = softmax
        # Arguments that are only stored here.
        self.num_layers = num_layers
        self.conv_type = conv_type
        self.num_connects = num_connects

        if is_pretrain:
            unet = init_U_Net(input_modalites,
                              output_channels,
                              base_channel,
                              softmax=False)
            pretrained_file = 'best_newdata/UNet-p64-b4-newdata-oriinput_best_model.pth.tar'
            unet = WrappedModel(unet)
            saved = torch.load(pretrained_file,
                               map_location=torch.device('cpu'))

            # Freeze every backbone parameter, then restore the stored weights.
            for weight in unet.parameters():
                weight.requires_grad = False
            unet.load_state_dict(saved['model_state_dict'])

            # Attach encoder, bridge and decoder layers in the original order.
            donor = unet.module
            for layer_name in ('down_conv1', 'down_conv2', 'down_conv3',
                               'down_sample_1', 'down_sample_2',
                               'down_sample_3', 'bridge', 'up_sample_1',
                               'up_sample_2', 'up_sample_3', 'up_conv1',
                               'up_conv2', 'up_conv3'):
                setattr(self, layer_name, getattr(donor, layer_name))

        # Keep only the first three entries of each up-convolution block.
        for layer_name in ('up_conv1', 'up_conv2', 'up_conv3'):
            leading = list(getattr(self, layer_name).block)[:3]
            setattr(self, layer_name, nn.Sequential(*leading))
예제 #5
0
def train(args):
    """Train a U-Net segmentation model as configured in ``config.yaml``.

    Reads paths and hyper-parameters from ``config.yaml``, builds the training
    and validation sets, then runs the epoch loop. Each epoch performs one
    optimisation pass over the training loader, one validation pass, a
    ``ReduceLROnPlateau`` step on the mean training loss, a checkpoint through
    ``save_ckp`` and a rewrite of the JSON training record. Training ends
    after the configured number of epochs, or as soon as ``save_ckp`` returns
    a truthy value.

    Args:
        args: object exposing ``model_name`` (used in every output file name)
            and ``loss_fn`` (one of ``'Dice'``, ``'CrossEntropy'``,
            ``'FocalLoss'``, ``'Dice_CE'``, ``'Dice_FL'``).

    Raises:
        NotImplementedError: if ``args.loss_fn`` is none of the names above.
    """
    torch.cuda.manual_seed(1)
    torch.manual_seed(1)

    # user defined
    model_name = args.model_name
    model_loss_fn = args.loss_fn

    config_file = 'config.yaml'

    config = load_config(config_file)
    data_root = config['PATH']['data_root']
    labels = config['PARAMETERS']['labels']
    root_path = config['PATH']['root']
    model_dir = config['PATH']['model_path']
    best_dir = config['PATH']['best_model_path']

    data_class = config['PATH']['data_class']
    input_modalites = int(config['PARAMETERS']['input_modalites'])
    output_channels = int(config['PARAMETERS']['output_channels'])
    base_channel = int(config['PARAMETERS']['base_channels'])
    crop_size = int(config['PARAMETERS']['crop_size'])
    batch_size = int(config['PARAMETERS']['batch_size'])
    epochs = int(config['PARAMETERS']['epoch'])
    # NOTE: this initial value is overwritten in every epoch before it is used.
    is_best = bool(config['PARAMETERS']['is_best'])
    is_resume = bool(config['PARAMETERS']['resume'])
    patience = int(config['PARAMETERS']['patience'])
    ignore_idx = int(config['PARAMETERS']['ignore_index'])
    early_stop_patience = int(config['PARAMETERS']['early_stop_patience'])

    # build up dirs
    model_path = os.path.join(root_path, model_dir)
    best_path = os.path.join(root_path, best_dir)
    intermidiate_data_save = os.path.join(root_path, 'train_data', model_name)
    train_info_file = os.path.join(intermidiate_data_save,
                                   '{}_train_info.json'.format(model_name))
    log_path = os.path.join(root_path, 'logfiles')

    # exist_ok avoids the check-then-create race, and makedirs also creates
    # any missing parent directory.
    for directory in (model_path, best_path, intermidiate_data_save,
                      log_path):
        os.makedirs(directory, exist_ok=True)

    log_name = model_name + '_' + config['PATH']['log_file']
    logger = logfile(os.path.join(log_path, log_name))
    logger.info('Dataset is loading ...')
    # split dataset
    dir_ = os.path.join(data_root, data_class)
    data_content = train_split(dir_)

    # load training set and validation set
    train_set = data_loader(data_content=data_content,
                            key='train',
                            form='LGG',
                            crop_size=crop_size,
                            batch_size=batch_size,
                            num_works=8)
    n_train = len(train_set)
    train_loader = train_set.load()

    val_set = data_loader(data_content=data_content,
                          key='val',
                          form='LGG',
                          crop_size=crop_size,
                          batch_size=batch_size,
                          num_works=8)

    logger.info('Dataset loading finished!')

    n_val = len(val_set)
    # Number of batches used to average the per-epoch loss and dice scores.
    nb_batches = np.ceil(n_train / batch_size)
    n_total = n_train + n_val
    logger.info(
        '{} images will be used in total, {} for trainning and {} for validation'
        .format(n_total, n_train, n_val))

    net = init_U_Net(input_modalites, output_channels, base_channel)

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    net.to(device)

    if model_loss_fn == 'Dice':
        criterion = DiceLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'CrossEntropy':
        criterion = CrossEntropyLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'FocalLoss':
        criterion = FocalLoss(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_CE':
        criterion = Dice_CE(labels=labels, ignore_index=ignore_idx)
    elif model_loss_fn == 'Dice_FL':
        criterion = Dice_FL(labels=labels, ignore_index=ignore_idx)
    else:
        raise NotImplementedError()

    optimizer = optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                     verbose=True,
                                                     patience=patience)

    net, optimizer = amp.initialize(net, optimizer, opt_level='O1')

    # The apex amp documentation asks for parallel wrappers to be applied only
    # to the model returned by amp.initialize, so the DataParallel wrapping
    # happens here rather than before the optimizer is created.
    if torch.cuda.device_count() > 1:
        logger.info('{} GPUs available.'.format(torch.cuda.device_count()))
        net = nn.DataParallel(net)

    min_loss = float('Inf')
    early_stop_count = 0
    start_epoch = 0
    start_loss = 0
    train_info = {
        'train_loss': [],
        'val_loss': [],
        'BG_acc': [],
        'NET_acc': [],
        'ED_acc': [],
        'ET_acc': []
    }

    if is_resume:
        # save_ckp below is given model_path as its save directory, so the
        # checkpoint to resume from is looked up in that same directory.
        ckp_path = os.path.join(model_path,
                                '{}_model_ckp.pth.tar'.format(model_name))
        try:
            net, optimizer, scheduler, start_epoch, min_loss, start_loss = load_ckp(
                ckp_path, net, optimizer, scheduler)

            # open previous training records
            with open(train_info_file) as f:
                train_info = json.load(f)

            logger.info(
                'Training loss from last time is {}'.format(start_loss) +
                '\n' +
                'Mininum training loss from last time is {}'.format(min_loss))

        except Exception as err:
            # Best effort: fall back to a fresh run, but keep the reason in the
            # log and let KeyboardInterrupt / SystemExit propagate.
            logger.warning(
                'No checkpoint available, strat training from scratch')
            logger.warning('Resume failed with: {!r}'.format(err))

    # start training
    for epoch in range(start_epoch, epochs):

        # setup to train mode
        net.train()
        running_loss = 0
        dice_coeff_bg = 0
        dice_coeff_net = 0
        dice_coeff_ed = 0
        dice_coeff_et = 0

        logger.info('Training epoch {} will begin'.format(epoch + 1))

        with tqdm(total=n_train,
                  desc=f'Epoch {epoch+1}/{epochs}',
                  unit='patch') as pbar:

            for i, data in enumerate(train_loader, 0):
                images, segs = data['image'].to(device), data['seg'].to(device)

                # zero the parameter gradients
                optimizer.zero_grad()

                outputs = net(images)

                loss = criterion(outputs, segs)
                # backward pass goes through amp's loss scaling
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()

                optimizer.step()

                # save the output at the beginning of each epoch to visualize it
                if i == 0:
                    in_images = images.detach().cpu().numpy()[:, 0, ...]
                    in_segs = segs.detach().cpu().numpy()
                    in_pred = outputs.detach().cpu().numpy()
                    heatmap_plot(image=in_images,
                                 mask=in_segs,
                                 pred=in_pred,
                                 name=model_name,
                                 epoch=epoch + 1)

                batch_loss = loss.detach().item()
                running_loss += batch_loss
                dice_score = dice_coe(outputs.detach().cpu(),
                                      segs.detach().cpu())
                dice_coeff_bg += dice_score['BG']
                dice_coeff_ed += dice_score['ED']
                dice_coeff_et += dice_score['ET']
                dice_coeff_net += dice_score['NET']

                # show progress bar
                pbar.set_postfix(
                    **{
                        'Training loss': batch_loss,
                        'Training (avg) accuracy': dice_score['avg']
                    })
                pbar.update(images.shape[0])

            # Validate exactly once per epoch, after the last training batch.
            # Doing this here instead of on a step counter guarantees that
            # val_loss / val_acc are always defined and that the network is
            # never left in eval mode while training batches remain.
            net.eval()
            val_loss, val_acc = validation(net, val_set, criterion, device,
                                           batch_size)

        epoch_loss = running_loss / nb_batches

        train_info['train_loss'].append(epoch_loss)
        train_info['val_loss'].append(val_loss)
        train_info['BG_acc'].append(dice_coeff_bg / nb_batches)
        train_info['NET_acc'].append(dice_coeff_net / nb_batches)
        train_info['ED_acc'].append(dice_coeff_ed / nb_batches)
        train_info['ET_acc'].append(dice_coeff_et / nb_batches)

        # the learning-rate scheduler follows the mean training loss
        scheduler.step(epoch_loss)

        # track the best validation loss and the early-stopping counter
        if min_loss > val_loss:
            min_loss = val_loss
            is_best = True
            early_stop_count = 0
        else:
            is_best = False
            early_stop_count += 1

        state = {
            'epoch': epoch + 1,
            'model_state_dict': net.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'loss': epoch_loss,
            'min_loss': min_loss
        }
        # a truthy return value from save_ckp ends the training (see below)
        verbose = save_ckp(state,
                           is_best,
                           early_stop_count=early_stop_count,
                           early_stop_patience=early_stop_patience,
                           save_model_dir=model_path,
                           best_dir=best_path,
                           name=model_name)

        logger.info('The average training loss for this epoch is {}'.format(
            epoch_loss))
        logger.info(
            'Validation dice loss: {}; Validation (avg) accuracy: {}'.format(
                val_loss, val_acc))
        logger.info('The best validation loss till now is {}'.format(min_loss))

        # save the training info every epoch
        logger.info('Writing the training info into file ...')
        with open(train_info_file, 'w') as fp:
            json.dump(train_info, fp)

        loss_plot(train_info_file, name=model_name)

        if verbose:
            logger.info(
                'The validation loss has not improved for {} epochs, training will stop here.'
                .format(early_stop_patience))
            break

    logger.info('finish training!')