Exemplo n.º 1
0
    def __init__(self,
                 batch_norm_mode,
                 depth,
                 model_root_channel=8,
                 img_size=256,
                 batch_size=20,
                 n_channel=1,
                 n_class=2):
        """Build the TF1 graph: placeholders, network, loss, metric and summaries.

        Args:
            batch_norm_mode: stored as ``self.batch_mode``; presumably read by
                ``neural_net()`` (not visible here) -- TODO confirm.
            depth: stored as ``self.depth_n``; presumably the network depth
                used by ``neural_net()`` -- TODO confirm.
            model_root_channel: stored as ``self.model_channel``.
            img_size: height and width of the square ``X`` / ``Y`` placeholders.
            batch_size: stored as ``self.batch_size``; not used in this
                constructor (both placeholders have a ``None`` batch dimension).
            n_channel: number of channels of the input placeholder ``X``.
            n_class: number of channels of the label placeholder ``Y``. The
                ``tf.split(..., [1, 1], 3)`` calls below only work when this
                (and the last dimension of the logits) is exactly 2.
        """

        # Fed at session-run time; presumably consumed by neural_net() for
        # dropout and batch-norm mode -- NOTE(review): confirm.
        self.drop_rate = tf.placeholder(tf.float32)
        self.training = tf.placeholder(tf.bool)

        self.batch_size = batch_size
        self.model_channel = model_root_channel
        self.batch_mode = batch_norm_mode
        self.depth_n = depth

        # Input images: (batch, img_size, img_size, n_channel), NHWC layout.
        self.X = tf.placeholder(tf.float32,
                                [None, img_size, img_size, n_channel],
                                name='X')
        # Target masks: (batch, img_size, img_size, n_class).
        self.Y = tf.placeholder(tf.float32,
                                [None, img_size, img_size, n_class],
                                name='Y')

        self.logits = self.neural_net()

        # Split the softmax probabilities along axis 3 into two single-channel
        # maps: channel 0 is treated as foreground, channel 1 as background.
        self.foreground_predicted, self.background_predicted = tf.split(
            tf.nn.softmax(self.logits), [1, 1], 3)

        # Same channel split for the ground-truth labels.
        self.foreground_truth, self.background_truth = tf.split(
            self.Y, [1, 1], 3)

        with tf.name_scope('Loss'):
            # # Cross_Entropy
            # self.loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=self.logits, labels=self.Y))

            # # Dice_Loss
            # NOTE(review): dice_loss receives the raw logits, while the
            # prediction maps above use softmax(logits). Whether
            # utils.dice_loss applies its own activation is not visible
            # here -- confirm.
            self.loss = utils.dice_loss(output=self.logits, target=self.Y)

            # # Focal_Loss
            # self.loss=utils.focal_loss(output=self.logits, target=self.Y, use_class=True, gamma=2, smooth=1e-8)

        with tf.name_scope('Metrics'):
            # Named "accuracy", but computed as utils.mean_iou on the
            # foreground channel only.
            self.accuracy = utils.mean_iou(self.foreground_predicted,
                                           self.foreground_truth)

        # TensorBoard scalar summaries for the loss and the IoU metric.
        tf.summary.scalar('loss', self.loss)
        tf.summary.scalar('accuracy', self.accuracy)
Exemplo n.º 2
0
def main():
    """Smoke-test the 'unet' model with one forward/backward pass.

    Builds the model with ``factory('unet')`` and feeds it a random batch of
    8 three-channel 256x256 images. Computes ``utils.dice_loss`` against
    random binary masks, backpropagates, then prints the model and its logits.
    Runs on the GPU when CUDA is available and falls back to the CPU otherwise.
    """
    model = factory('unet')
    from utils import dice_loss

    inputs = torch.randn(8, 3, 256, 256)
    # random_(0, 2) draws integers from [0, 2), i.e. binary {0, 1} masks.
    # (random_(1) draws from [0, 1) and therefore always produced all zeros.)
    labels = torch.LongTensor(8, 256, 256).random_(0, 2).type(torch.FloatTensor)

    model = model.train()
    x = torch.autograd.Variable(inputs)
    y = torch.autograd.Variable(labels)
    # Fall back to the CPU so the smoke test also runs without a GPU.
    if torch.cuda.is_available():
        model = model.cuda()
        x = x.cuda()
        y = y.cuda()

    logits = model.forward(x)

    loss = dice_loss(logits, y)
    loss.backward()

    print(type(model))
    print(model)

    print('logits')
    print(logits)
Exemplo n.º 3
0
def _loss_value(loss):
    """Return the Python float held by a one-element loss tensor.

    ``loss.data[0]`` only works on PyTorch < 0.4, where losses are 1-element
    tensors. From 0.4 on, losses are 0-dim and indexing them raises.
    Flattening to 1-D first works on both old and new versions.
    """
    return float(loss.data.view(-1)[0])


def train(NetG, NetD, optimizerG, optimizerD, dataloader, epoch):
    """Run one adversarial training epoch for a segmentor (G) and a critic (D).

    For every batch:
      * D step: G is frozen. D is updated to *maximise* the mean absolute
        difference between its responses to ``input * sigmoid(G(input))`` and
        ``input * target`` (``loss_mac`` is the negated difference). It also
        backpropagates the gradient penalty from
        ``utils.calc_gradient_penalty``.
      * G step: D is frozen. G is updated to minimise that same difference
        plus ``utils.dice_loss`` between its sigmoid output and the target.

    Epoch averages of the losses are printed at the end. Relies on the
    module-level names ``use_cuda``, ``Variable``, ``F``, ``utils`` and
    ``conf``.

    Args:
        NetG: segmentation network; its raw output is passed through a sigmoid.
        NetD: critic network applied to masked images.
        optimizerG: optimizer for NetG's parameters.
        optimizerD: optimizer for NetD's parameters.
        dataloader: iterable of ``(input, target)`` batches that supports ``len()``.
        epoch: current epoch number, used only in the progress message.
    """
    total_dice = 0
    total_g_loss = 0
    total_g_loss_dice = 0
    total_g_loss_bce = 0
    total_d_loss = 0
    total_d_loss_penalty = 0
    NetG.train()
    NetD.train()

    try:
        for data in dataloader:
            # ---- train D: freeze G, unfreeze D ----
            optimizerD.zero_grad()
            NetD.zero_grad()
            for p in NetG.parameters():
                p.requires_grad = False
            for p in NetD.parameters():
                p.requires_grad = True

            inputs, target = Variable(data[0]), Variable(data[1])
            inputs = inputs.float()
            target = target.float()

            if use_cuda:
                inputs = inputs.cuda()
                target = target.cuda()

            # detach(): the D step must not backpropagate into G.
            output = F.sigmoid(NetG(inputs)).detach()

            # inputs/target are already on the GPU when use_cuda, so the
            # masked products are too; no extra .cuda() call is needed.
            input_img = inputs.clone()
            output_masked = input_img * output
            result = NetD(output_masked)

            target_masked = input_img * target
            target_D = NetD(target_masked)
            loss_mac = -torch.mean(torch.abs(result - target_D))
            loss_mac.backward()

            # D net gradient penalty
            batch_size = target_masked.size(0)
            gradient_penalty = utils.calc_gradient_penalty(NetD, target_masked,
                                                           output_masked,
                                                           batch_size, use_cuda,
                                                           inputs.shape)
            gradient_penalty.backward()
            optimizerD.step()

            # ---- train G: unfreeze G, freeze D ----
            optimizerG.zero_grad()
            NetG.zero_grad()
            for p in NetG.parameters():
                p.requires_grad = True
            for p in NetD.parameters():
                p.requires_grad = False

            output = F.sigmoid(NetG(inputs))

            target_dice = target.view(-1).long()
            output_dice = output.view(-1)
            loss_dice = utils.dice_loss(output_dice, target_dice)

            output_masked = input_img * output
            result = NetD(output_masked)

            target_G = NetD(target_masked)
            # Adversarial term (reported as "bce" in the epoch summary).
            loss_G = torch.mean(torch.abs(result - target_G))
            loss_G_joint = loss_G + loss_dice
            loss_G_joint.backward()
            optimizerG.step()

            dice_value = _loss_value(loss_dice)
            total_dice += 1 - dice_value
            total_g_loss += _loss_value(loss_G_joint)
            total_g_loss_dice += dice_value
            total_g_loss_bce += _loss_value(loss_G)
            total_d_loss += _loss_value(loss_mac)
            total_d_loss_penalty += _loss_value(gradient_penalty)
    finally:
        # Always unfreeze both networks, even if an exception aborts the
        # epoch, so the caller is never left with frozen parameters.
        for p in NetG.parameters():
            p.requires_grad = True
        for p in NetD.parameters():
            p.requires_grad = True

    # Guard against an empty dataloader (it would otherwise divide by zero).
    size = max(len(dataloader), 1)

    epoch_dice = total_dice / size
    epoch_g_loss = total_g_loss / size
    epoch_g_loss_dice = total_g_loss_dice / size
    epoch_g_loss_bce = total_g_loss_bce / size

    epoch_d_loss = total_d_loss / size
    epoch_d_loss_penalty = total_d_loss_penalty / size

    print_format = [
        epoch, conf.epochs, epoch_dice * 100, epoch_g_loss, epoch_g_loss_dice,
        epoch_g_loss_bce, epoch_d_loss, epoch_d_loss_penalty
    ]
    print('===> Training step {}/{} \tepoch_dice: {:.5f}'
          '\tepoch_g_loss: {:.5f} \tepoch_g_loss_dice: {:.5f}'
          '\tepoch_g_loss_bce: {:.5f} \tepoch_d_loss: {:.5f}'
          '\tepoch_d_loss_penalty: {:.5f}'.format(*print_format))
Exemplo n.º 4
0
def final_multiscale_roi_align(model, optimizer, scheduler, dataloaders, dataset_sizes, num_epochs=25, viz=False):
    '''
    Train and validate a model whose output is a 28x28 log-softmax map.

    Each epoch runs a 'train' and a 'val' phase over ``dataloaders[phase]``,
    which must yield ``(inputs, label224, label28)`` batches. The per-batch
    loss is ``0.7 * dice + 0.3 * weighted NLL``:
      * the NLL compares the 28x28 log-softmax output with ``label28``;
      * the dice term compares the rounded, 8x-upsampled probability map
        with ``label224``.
    Logs are appended to ``outputs/single_scale_roi_align/logs/train_logs``.
    The checkpoint ``outputs/single_scale_roi_align/model/<model.name>.pth``
    is resumed from automatically when it exists.

    Relies on the module-level names ``device``, ``class_weights``,
    ``get_prob_map28``, ``dice_loss``, ``actual_predicted`` and ``save_model``.

    Args:
        model: network with a ``name`` attribute, returning log-softmax maps.
        optimizer: optimizer for ``model``.
        scheduler: LR scheduler, stepped once per training phase.
        dataloaders: dict with 'train' and 'val' loaders.
        dataset_sizes: dict with the number of samples per phase, used to
            average the per-sample losses.
        num_epochs: number of epochs to run from the starting (possibly
            resumed) epoch.
        viz: if True, save a label-vs-prediction image for every validation
            sample under ``output_tracking/<epoch>/``.

    Planned design (author's notes; not all of it is implemented here):

    First, generate all the patches as 28x28 from the MRI scan, using anchor boxes. Write a Transform for that.
        - For a scan, generate 10 RoIs.
        - The generator returns: MRI_scan224, MRI_label224, plus all anchor boxes of MRI_scan28x28, MRI_label28x28.
        - If the sample has some lesion, return all the RoIs with that lesion, as just the (x1, y1, x2, y2) of the boxes in the 224x224 map.
        - If the sample has no lesion, return 10 RoIs of the lesion-free zone.
    Then perform these actions in the sub-level data transform:
        - Take the 224x224 tensor and the RoIs, and run RoI align to generate the feature-map levels.
        - view(-1, m, n) and shuffle all the samples, for all (m, n) map levels.
        - Run a simple algorithm to get the class (0 or 1) for every patch.

    Second, get RoI maps for the same 28x28 RoI from the CNN feature maps using RoI align, and pass them through the deconv nets.
        - The model() nn.Module has to perform all of this.
        - It runs deconv nets (PyTorch nn modules) for these patch sizes so that every level yields uniform 28x28 maps.
        - Concatenate the 28x28 predicted masks from all feature levels, then apply a small 3x3 conv and a 1x1 conv to produce the final output.
        - Return the 28x28 predictions for every feature level individually, plus the max-class-voting result over these predictions as one mask,
          plus the classification head.

    Third, frame the loss function with the classifier head and the segmentor head.
        - Train the classifier on all samples.
        - Run a simple algorithm to keep only the samples with a non-zero lesion, based on the patch classifier label.
        - Apply a piecewise loss between every patch mask and its prediction.
          Also combine it with a secondary loss function.
    '''
    import traceback

    # out dirs
    base_dir = Path.cwd() / 'outputs' / 'single_scale_roi_align'
    output_tracking_dir = base_dir / 'output_tracking'
    logs_dir = base_dir / 'logs'
    model_dir = base_dir / 'model'

    model = model.to(device)

    # Upsamples the 28x28 probability map to 224x224 to compare it with label224.
    upsampler = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)

    logs_dir.mkdir(parents=True, exist_ok=True)
    model_dir.mkdir(parents=True, exist_ok=True)
    output_tracking_dir.mkdir(parents=True, exist_ok=True)

    since = time.time()
    PATH = str(model_dir / (model.name + '.pth'))
    epo = 1
    loss = None

    if Path(PATH).is_file():
        checkpoint = torch.load(PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epo = checkpoint['epoch']
        loss = checkpoint['loss']
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        print('Resuming from epoch ' + str(epo) + ', LOSS: ', loss)

    best_model_wts = copy.deepcopy(model.state_dict())
    # Lower is better: a validation loss must drop below this value to
    # replace the initial weights as the "best" model.
    best_loss = 3.0
    # Most recent epoch loss. It is initialised here so that the interrupt
    # handler and the final save never reference an unbound name (e.g. an
    # interrupt during the very first phase, or num_epochs == 0).
    epoch_loss = loss

    with open(str(logs_dir / 'train_logs'), 'a') as logs_ptr:
        for epoch in range(epo, epo + num_epochs):
            epoch_str = 'Epoch {}/{}'.format(epoch, epo + num_epochs - 1) + '\n\n'
            print(epoch_str)
            logs_ptr.write(epoch_str)

            print('-' * 10)

            try:
                # Each epoch has a training and validation phase
                for phase in ['train', 'val']:
                    if phase == 'train':
                        model.train()  # Set model to training mode
                    else:
                        model.eval()  # Set model to evaluate mode

                    running_softmax = 0.0
                    running_dice = 0.0

                    for mini_batch, (inputs, label224, label28) in enumerate(dataloaders[phase]):
                        inputs = inputs.to(device)
                        # label224 is (batch_size, 1, 224, 224) per the original comment.
                        label28 = label28.to(device)
                        label224 = label224.to(device)

                        # zero the parameter gradients
                        optimizer.zero_grad()

                        # forward; track history only in the training phase
                        with torch.set_grad_enabled(phase == 'train'):
                            log_softmax_outputs28 = model(inputs)  # shape of pred28 is (batch_size, 2, 28, 28)

                            # squeeze(1) removes only the channel axis. A bare
                            # squeeze() would also drop the batch axis when a
                            # batch holds a single sample.
                            softmax_loss = F.nll_loss(log_softmax_outputs28,
                                                      label28.round().squeeze(1).long(),
                                                      weight=class_weights)

                            softmax_outputs28 = torch.exp(log_softmax_outputs28)
                            torch_pred28_prob = get_prob_map28(softmax_outputs28)
                            torch_pred224_prob = upsampler(torch_pred28_prob)

                            # NOTE(review): torch.round has zero gradient, so
                            # dice_l does not influence the parameter update;
                            # only softmax_loss trains the model.
                            rounded_pred224_prob_for_dice = torch.round(torch_pred224_prob)
                            # format is (batch_size, 1, 224, 224)

                            dice_l = dice_loss(input=rounded_pred224_prob_for_dice, target=label224)

                            total_loss = 0.7 * dice_l + 0.3 * softmax_loss

                            # backward + optimize only if in training phase
                            if phase == 'train':
                                total_loss.backward()
                                optimizer.step()

                        if phase == 'train':
                            step_str = ('{} Step {}- Loss: {:.4f}, Dice Loss: {:.4f}, Softmax Loss: {:.4f}'
                                        .format(phase, mini_batch + 1, total_loss, dice_l, softmax_loss))
                            print(step_str)
                            logs_ptr.write(step_str + '\n')

                        if phase == 'val' and viz:
                            # path for saving the intermediate outputs
                            epoch_tracking_path = output_tracking_dir / str(epoch)
                            epoch_tracking_path.mkdir(parents=True, exist_ok=True)
                            batch_len = label28.size(0)
                            for item in range(batch_len):
                                # .cpu() is required before .numpy() when the
                                # tensors live on a GPU.
                                actual_predicted(label224[item][0].detach().cpu().numpy(),
                                                 rounded_pred224_prob_for_dice[item][0].detach().cpu().numpy(),
                                                 str(epoch_tracking_path / (str(mini_batch * batch_len + item) + '.jpg')))

                        # statistics (per-sample sums, averaged at epoch end)
                        running_dice += dice_l.item() * inputs.size(0)
                        running_softmax += softmax_loss.item() * inputs.size(0)

                    # end of an epoch
                    epoch_dice_l = running_dice / dataset_sizes[phase]
                    epoch_softmax = running_softmax / dataset_sizes[phase]
                    epoch_loss = epoch_dice_l + epoch_softmax

                    if phase == 'train':
                        scheduler.step()

                    loss_str = '\n{} Epoch {}: TotalLoss: {:.4f}   SoftmaxLoss: {:.4f} Dice Loss: {:.4f} \n'.format(
                        phase, epoch, epoch_loss, epoch_softmax, epoch_dice_l) + '\n'
                    print(loss_str)

                    logs_ptr.write(loss_str + '\n')

                    # Keep a copy of the weights with the lowest validation loss.
                    if phase == 'val' and epoch_loss < best_loss:
                        print('Val Dice better than Best Dice')
                        best_loss = epoch_loss
                        best_model_wts = copy.deepcopy(model.state_dict())

            except BaseException as err:
                # Save progress on any interruption. A bare `except:` used to
                # hide real crashes, so report them and exit non-zero; Ctrl+C
                # still exits cleanly with status 0.
                interrupted = isinstance(err, KeyboardInterrupt)
                if not interrupted:
                    traceback.print_exc()
                save_model(epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
                raise SystemExit(0 if interrupted else 1)

            print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val DICE: {:.4f}'.format(best_loss))

    # Save the best weights. epo + num_epochs - 1 is the last epoch run; it
    # equals num_epochs on a fresh run and stays correct after a resume.
    save_model(epo + num_epochs - 1,
               best_model_wts,
               optimizer,
               scheduler, epoch_loss, PATH)
Exemplo n.º 5
0
def experiment3(model, optimizer, scheduler, dataloaders, dataset_sizes, num_epochs=25, viz=False):
    """Train and validate a model that outputs full-resolution 224x224 log-softmax maps.

    Each epoch runs a 'train' and a 'val' phase over ``dataloaders[phase]``,
    which must yield ``(inputs, label224, _)`` batches (the third item is
    ignored). The per-batch loss is ``0.9 * dice + 0.1 * weighted NLL``:
      * the NLL compares the log-softmax output with ``label224``;
      * the dice term compares the per-pixel argmax prediction with ``label224``.
    Logs are appended to ``outputs/experiment3/logs/train_logs``. The
    checkpoint ``outputs/experiment3/model/<model.name>.pth`` is resumed from
    automatically when it exists.

    Relies on the module-level names ``device``, ``class_weights``,
    ``dice_loss``, ``actual_predicted`` and ``save_model``.

    Args:
        model: network with a ``name`` attribute, returning log-softmax maps.
        optimizer: optimizer for ``model``.
        scheduler: LR scheduler, stepped once per training phase.
        dataloaders: dict with 'train' and 'val' loaders.
        dataset_sizes: dict with the number of samples per phase, used to
            average the per-sample losses.
        num_epochs: number of epochs to run from the starting (possibly
            resumed) epoch.
        viz: if True, save a label-vs-prediction image for every validation
            sample under ``output_tracking/<epoch>/``.
    """
    import traceback

    # out dirs
    base_dir = Path.cwd() / 'outputs' / 'experiment3'
    output_tracking_dir = base_dir / 'output_tracking'
    logs_dir = base_dir / 'logs'
    model_dir = base_dir / 'model'

    # Constructed for parity with the 28x28 experiments; not used below.
    upsampler = nn.Upsample(scale_factor=8, mode='bilinear', align_corners=True)

    model = model.to(device)

    logs_dir.mkdir(parents=True, exist_ok=True)
    model_dir.mkdir(parents=True, exist_ok=True)
    output_tracking_dir.mkdir(parents=True, exist_ok=True)

    since = time.time()
    PATH = str(model_dir / (model.name + '.pth'))
    epo = 1
    loss = None

    if Path(PATH).is_file():
        checkpoint = torch.load(PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epo = checkpoint['epoch']
        loss = checkpoint['loss']
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        print('Resuming from epoch ' + str(epo) + ', LOSS: ', loss)

    best_model_wts = copy.deepcopy(model.state_dict())
    # Lower is better: a validation loss must drop below this value to
    # replace the initial weights as the "best" model.
    best_loss = 3.0
    # Most recent epoch loss. It is initialised here so that the interrupt
    # handler and the final save never reference an unbound name.
    epoch_loss = loss

    with open(str(logs_dir / 'train_logs'), 'a') as logs_ptr:
        for epoch in range(epo, epo + num_epochs):
            epoch_str = 'Epoch {}/{}'.format(epoch, epo + num_epochs - 1) + '\n\n'
            print(epoch_str)
            logs_ptr.write(epoch_str)

            print('-' * 10)

            try:
                # Each epoch has a training and validation phase
                for phase in ['train', 'val']:
                    if phase == 'train':
                        model.train()  # Set model to training mode
                    else:
                        model.eval()  # Set model to evaluate mode

                    running_softmax = 0.0
                    running_dice = 0.0

                    for mini_batch, (inputs, label224, _) in enumerate(dataloaders[phase]):
                        inputs = inputs.to(device)
                        # label224 is (batch_size, 1, 224, 224) per the original comment.
                        label224 = label224.to(device)

                        # zero the parameter gradients
                        optimizer.zero_grad()

                        # forward; track history only in the training phase
                        with torch.set_grad_enabled(phase == 'train'):
                            log_softmax_outputs224 = model(inputs)  # shape of pred224 is (batch_size, 2, 224, 224)

                            # squeeze(1) removes only the channel axis. A bare
                            # squeeze() would also drop the batch axis when a
                            # batch holds a single sample.
                            softmax_loss = F.nll_loss(log_softmax_outputs224,
                                                      label224.squeeze(1).long(),
                                                      weight=class_weights)

                            softmax_outputs224 = torch.exp(log_softmax_outputs224)

                            # Per-pixel class index, (batch_size, 1, 224, 224).
                            # NOTE(review): argmax indices carry no gradient, so
                            # dice_l does not influence the parameter update.
                            _, pred224_argmax = torch.max(softmax_outputs224, dim=1, keepdim=True)
                            pred224_argmax = pred224_argmax.float()

                            dice_l = dice_loss(input=pred224_argmax, target=label224)

                            total_loss = 0.9 * dice_l + 0.1 * softmax_loss

                            # backward + optimize only if in training phase
                            if phase == 'train':
                                total_loss.backward()
                                optimizer.step()

                        if phase == 'train':
                            step_str = ('{} Step {}- Loss: {:.4f}, Dice Loss: {:.4f}, Softmax Loss: {:.4f}'
                                        .format(phase, mini_batch + 1, total_loss, dice_l, softmax_loss))
                            print(step_str)
                            logs_ptr.write(step_str + '\n')

                        if phase == 'val' and viz:
                            # path for saving the intermediate outputs
                            epoch_tracking_path = output_tracking_dir / str(epoch)
                            epoch_tracking_path.mkdir(parents=True, exist_ok=True)
                            batch_len = label224.size(0)
                            for item in range(batch_len):
                                # .cpu() is required before .numpy() when the
                                # tensors live on a GPU.
                                actual_predicted(label224[item][0].detach().cpu().numpy(),
                                                 pred224_argmax[item][0].detach().cpu().numpy(),
                                                 str(epoch_tracking_path / (str(mini_batch * batch_len + item) + '.jpg')))

                        # statistics (per-sample sums, averaged at epoch end)
                        running_dice += dice_l.item() * inputs.size(0)
                        running_softmax += softmax_loss.item() * inputs.size(0)

                    # end of an epoch
                    epoch_dice_l = running_dice / dataset_sizes[phase]
                    epoch_softmax = running_softmax / dataset_sizes[phase]
                    epoch_loss = epoch_dice_l + epoch_softmax

                    if phase == 'train':
                        scheduler.step()

                    loss_str = '\n{} Epoch {}: TotalLoss: {:.4f}   SoftmaxLoss: {:.4f} Dice Loss: {:.4f} \n'.format(
                        phase, epoch, epoch_loss, epoch_softmax, epoch_dice_l) + '\n'
                    print(loss_str)

                    logs_ptr.write(loss_str + '\n')

                    # Keep a copy of the weights with the lowest validation loss.
                    if phase == 'val' and epoch_loss < best_loss:
                        print('Val Dice better than Best Dice')
                        best_loss = epoch_loss
                        best_model_wts = copy.deepcopy(model.state_dict())

            except BaseException as err:
                # Save progress on any interruption. A bare `except:` used to
                # hide real crashes, so report them and exit non-zero; Ctrl+C
                # still exits cleanly with status 0.
                interrupted = isinstance(err, KeyboardInterrupt)
                if not interrupted:
                    traceback.print_exc()
                save_model(epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
                raise SystemExit(0 if interrupted else 1)

            print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val DICE: {:.4f}'.format(best_loss))

    # Save the best weights. epo + num_epochs - 1 is the last epoch run; it
    # equals num_epochs on a fresh run and stays correct after a resume.
    save_model(epo + num_epochs - 1,
               best_model_wts,
               optimizer,
               scheduler, epoch_loss, PATH)
Exemplo n.º 6
0
def experiment1(model, optimizer, scheduler, dataloaders, dataset_sizes, num_epochs=25, viz=False):
    """Train and validate a 28x28 model with a regression loss plus a dice loss.

    Each epoch runs a 'train' and a 'val' phase over ``dataloaders[phase]``,
    which must yield ``(inputs, label224, label28)`` batches (label224 is
    unused). With ``p`` the probability map from ``get_prob_map28`` and
    ``y = label28``:
      * ``reg_loss`` is the batch mean of the per-sample sum of
        ``-log(1 - |p - y|)``, divided by 1000;
      * the total loss is ``reg_loss + 0.5 * dice(round(p), round(y))``.
    Logs are appended to ``outputs/experiment1/logs/train_logs``. The
    checkpoint ``outputs/experiment1/model/<model.name>.pth`` is resumed from
    automatically when it exists.

    Relies on the module-level names ``device``, ``get_prob_map28``,
    ``dice_loss``, ``expand_mask``, ``actual_predicted`` and ``save_model``.

    Args:
        model: network with a ``name`` attribute, returning log-softmax maps.
        optimizer: optimizer for ``model``.
        scheduler: LR scheduler, stepped once per training phase.
        dataloaders: dict with 'train' and 'val' loaders.
        dataset_sizes: dict with the number of samples per phase, used to
            average the per-sample losses.
        num_epochs: number of epochs to run from the starting (possibly
            resumed) epoch.
        viz: if True, save a label-vs-prediction image (expanded to 224x224)
            for every validation sample under ``output_tracking/<epoch>/``.
    """
    import traceback

    # Keeps 1 - |p - y| away from 0, where -log() would be infinite and
    # produce inf/NaN losses and gradients.
    eps = 1e-7

    # model to CUDA
    model = model.to(device)

    # out dirs
    base_dir = Path.cwd() / 'outputs' / 'experiment1'
    output_tracking_dir = base_dir / 'output_tracking'
    logs_dir = base_dir / 'logs'
    model_dir = base_dir / 'model'

    logs_dir.mkdir(parents=True, exist_ok=True)
    model_dir.mkdir(parents=True, exist_ok=True)
    output_tracking_dir.mkdir(parents=True, exist_ok=True)

    since = time.time()
    PATH = str(model_dir / (model.name + '.pth'))
    epo = 1
    loss = None

    if Path(PATH).is_file():
        checkpoint = torch.load(PATH)
        model.load_state_dict(checkpoint['model_state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
        epo = checkpoint['epoch']
        loss = checkpoint['loss']
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])

        print('Resuming from epoch ' + str(epo) + ', LOSS: ', loss)

    best_model_wts = copy.deepcopy(model.state_dict())
    # Lower is better: a validation loss must drop below this value to
    # replace the initial weights as the "best" model.
    best_loss = 3.0
    # Most recent epoch loss. It is initialised here so that the interrupt
    # handler and the final save never reference an unbound name.
    epoch_loss = loss

    with open(str(logs_dir / 'train_logs'), 'a') as logs_ptr:
        for epoch in range(epo, epo + num_epochs):
            epoch_str = 'Epoch {}/{}'.format(epoch, epo + num_epochs - 1) + '\n\n'
            print(epoch_str)
            logs_ptr.write(epoch_str)

            print('-' * 10)

            try:
                # Each epoch has a training and validation phase
                for phase in ['train', 'val']:
                    if phase == 'train':
                        model.train()  # Set model to training mode
                    else:
                        model.eval()  # Set model to evaluate mode

                    running_reg = 0.0
                    running_dice = 0.0

                    for mini_batch, (inputs, _, label28) in enumerate(dataloaders[phase]):
                        inputs = inputs.to(device)
                        label28 = label28.to(device)

                        # zero the parameter gradients
                        optimizer.zero_grad()

                        # forward; track history only in the training phase
                        with torch.set_grad_enabled(phase == 'train'):
                            log_softmax_outputs28 = model(inputs)  # shape of pred28 is (batch_size, 2, 28, 28)

                            softmax_outputs28 = torch.exp(log_softmax_outputs28)
                            output28_prob = get_prob_map28(softmax_outputs28)

                            closeness = torch.clamp(1.0 - torch.abs(output28_prob - label28), min=eps)
                            reg_loss = torch.mean(
                                torch.sum(-torch.log(closeness), dim=[1, 2, 3])
                            ) / 1000.0

                            # NOTE(review): torch.round has zero gradient, so
                            # dice_l does not influence the parameter update.
                            dice_l = dice_loss(input=torch.round(output28_prob), target=torch.round(label28))

                            total_loss = reg_loss + 0.5 * dice_l

                            # backward + optimize only if in training phase
                            if phase == 'train':
                                total_loss.backward()
                                optimizer.step()

                        if phase == 'train':
                            step_str = ('{} Step {}- Loss: {:.4f}, Dice Loss: {:.4f}, Reg Loss: {:.4f}'
                                        .format(phase, mini_batch + 1, total_loss, dice_l, reg_loss))
                            print(step_str)
                            logs_ptr.write(step_str + '\n')

                        if phase == 'val' and viz:
                            output28_prob = output28_prob.cpu()
                            label28 = label28.cpu()

                            epoch_tracking_path = output_tracking_dir / str(epoch)
                            epoch_tracking_path.mkdir(parents=True, exist_ok=True)

                            batch_len = label28.size(0)
                            for item in range(batch_len):
                                expanded_output28_prob = expand_mask([[0, 0, 224, 224]],
                                                                     output28_prob[item].detach().numpy(),
                                                                     (224, 224))
                                expanded_label28 = expand_mask([[0, 0, 224, 224]], label28[item].detach().numpy(),
                                                               (224, 224))

                                actual_predicted(expanded_label28[0], expanded_output28_prob[0],
                                                 str(epoch_tracking_path / (str(mini_batch * batch_len + item) + '.jpg')))

                        # statistics (per-sample sums, averaged at epoch end)
                        running_dice += dice_l.item() * inputs.size(0)
                        running_reg += reg_loss.item() * inputs.size(0)

                    # end of an epoch
                    epoch_dice_l = running_dice / dataset_sizes[phase]
                    epoch_reg_loss = running_reg / dataset_sizes[phase]
                    epoch_loss = epoch_dice_l + epoch_reg_loss

                    if phase == 'train':
                        scheduler.step()

                    loss_str = '\n{} Epoch {}: TotalLoss: {:.4f}   RegLoss: {:.4f} Dice Loss: {:.4f} \n'.format(
                        phase, epoch, epoch_loss, epoch_reg_loss, epoch_dice_l) + '\n'
                    print(loss_str)

                    logs_ptr.write(loss_str + '\n')

                    # Keep a copy of the weights with the lowest validation loss.
                    if phase == 'val' and epoch_loss < best_loss:
                        print('Val Dice better than Best Dice')
                        best_loss = epoch_loss
                        best_model_wts = copy.deepcopy(model.state_dict())

            except BaseException as err:
                # Save progress on any interruption. A bare `except:` used to
                # hide real crashes, so report them and exit non-zero; Ctrl+C
                # still exits cleanly with status 0.
                interrupted = isinstance(err, KeyboardInterrupt)
                if not interrupted:
                    traceback.print_exc()
                save_model(epoch, best_model_wts, optimizer, scheduler, epoch_loss, PATH)
                raise SystemExit(0 if interrupted else 1)

            print()

    time_elapsed = time.time() - since
    print('Training complete in {:.0f}m {:.0f}s'.format(
        time_elapsed // 60, time_elapsed % 60))
    print('Best val DICE: {:.4f}'.format(best_loss))

    # Save the best weights. epo + num_epochs - 1 is the last epoch run; it
    # equals num_epochs on a fresh run and stays correct after a resume.
    save_model(epo + num_epochs - 1,
               best_model_wts,
               optimizer,
               scheduler, epoch_loss, PATH)
Exemplo n.º 7
0
def main():
    """Parse CLI options, build an FPN segmentation model and train it.

    The model is ``smp.FPN('se_resnet50', encoder_weights='imagenet',
    classes=2)``. Samples listed in ``<data_path>/train_mask.json`` are split
    in file order: the first ``val_split`` fraction is used for training and
    the remainder for validation. On Ctrl-C the current weights are saved to
    ``<output_dir>/<model>_INTERRUPTED.pth`` before exiting.
    """
    parser = ArgumentParser()
    parser.add_argument('-d',
                        '--data_path',
                        dest='data_path',
                        type=str,
                        default=None,
                        help='path to the data')
    parser.add_argument('-e',
                        '--epochs',
                        dest='epochs',
                        default=20,
                        type=int,
                        help='number of epochs')
    parser.add_argument('-b',
                        '--batch_size',
                        dest='batch_size',
                        default=40,
                        type=int,
                        help='batch size')
    parser.add_argument('-s',
                        '--image_size',
                        dest='image_size',
                        default=256,
                        type=int,
                        help='input image size')
    parser.add_argument('-lr',
                        '--learning_rate',
                        dest='lr',
                        default=0.0001,
                        type=float,
                        help='learning rate')
    parser.add_argument('-wd',
                        '--weight_decay',
                        dest='weight_decay',
                        default=5e-4,
                        type=float,
                        help='weight decay')
    parser.add_argument('-lrs',
                        '--learning_rate_step',
                        dest='lr_step',
                        default=10,
                        type=int,
                        help='learning rate step')
    parser.add_argument('-lrg',
                        '--learning_rate_gamma',
                        dest='lr_gamma',
                        default=0.5,
                        type=float,
                        help='learning rate gamma')
    # NOTE: --model does not select the architecture (FPN is hard-coded
    # below); in this function it only names the interrupt checkpoint.
    parser.add_argument(
        '-m',
        '--model',
        dest='model',
        default='fpn',
    )
    parser.add_argument('-w',
                        '--weight_bce',
                        default=0.5,
                        type=float,
                        help='weight BCE loss')
    parser.add_argument('-l',
                        '--load',
                        dest='load',
                        default=False,
                        help='load file model')
    # Fraction of the samples used for TRAINING; the rest go to validation.
    # type=float is required: without it a value given on the command line
    # arrives as a str and the split arithmetic below breaks.
    parser.add_argument('-v',
                        '--val_split',
                        dest='val_split',
                        default=0.7,
                        type=float,
                        help='train/val split')
    parser.add_argument('-o',
                        '--output_dir',
                        dest='output_dir',
                        default='./output',
                        help='dir to save log and models')
    args = parser.parse_args()
    # Fail fast with a usage message instead of a TypeError from
    # os.path.join(None, ...) after the model has already been built.
    if args.data_path is None:
        parser.error('the following arguments are required: -d/--data_path')

    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, 'train.log'))
    logger.info('Start training with params:')
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)

    # TODO: try a more novel architecture and/or lighter blocks (e.g. UNet(),
    # smp.FPN('mobilenet_v2', ...)) to allow a larger batch_size.
    net = smp.FPN('se_resnet50', encoder_weights='imagenet', classes=2)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.load:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        net.load_state_dict(torch.load(args.load, map_location=device))
    logger.info('Model type: {}'.format(net.__class__.__name__))

    net.to(device)

    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    # Returns the two weighted terms separately as a (bce, dice) tuple;
    # presumably `train` combines and/or logs them -- confirm against `train`.
    criterion = lambda x, y: (args.weight_bce * nn.BCELoss()(x, y),
                              (1. - args.weight_bce) * dice_loss(x, y))
    # A non-positive --learning_rate_step disables LR scheduling.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) \
        if args.lr_step > 0 else None

    train_transforms = Compose([
        Crop(min_size=1 - 1 / 3., min_ratio=1.0, max_ratio=1.0, p=0.5),
        Flip(p=0.05),
        RandomRotate(),
        Pad(max_size=0.6, p=0.25),
        Resize(size=(args.image_size, args.image_size), keep_aspect=True),
        ScaleToZeroOne(),
    ])
    val_transforms = Compose([
        Resize(size=(args.image_size, args.image_size)),
        ScaleToZeroOne(),
    ])

    train_dataset = DetectionDataset(args.data_path,
                                     os.path.join(args.data_path,
                                                  'train_mask.json'),
                                     transforms=train_transforms)
    val_dataset = DetectionDataset(args.data_path,
                                   None,
                                   transforms=val_transforms)

    # Both datasets are carved out of the same annotated list: the head
    # (in file order) trains, the tail validates.
    train_size = int(len(train_dataset) * args.val_split)
    val_dataset.image_names = train_dataset.image_names[train_size:]
    val_dataset.mask_names = train_dataset.mask_names[train_size:]
    train_dataset.image_names = train_dataset.image_names[:train_size]
    train_dataset.mask_names = train_dataset.mask_names[:train_size]
    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True,
                                  drop_last=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                num_workers=4,
                                shuffle=False,
                                drop_last=False)
    logger.info('Number of batches of train/val=%d/%d', len(train_dataloader),
                len(val_dataloader))

    try:
        train(net,
              optimizer,
              criterion,
              scheduler,
              train_dataloader,
              val_dataloader,
              logger=logger,
              args=args,
              device=device)
    except KeyboardInterrupt:
        torch.save(
            net.state_dict(),
            os.path.join(args.output_dir, f'{args.model}_INTERRUPTED.pth'))
        logger.info('Saved interrupt')
        sys.exit(0)
Exemplo n.º 8
0
        optimizer = optim.SGD(net_parallel.parameters(),
                              lr=base_lr,
                              momentum=0.9,
                              weight_decay=0.00004)
        iter_num = 0
        while True:
            for i_batch, sampled_batch in enumerate(dataloader):
                volume_batch, label_batch = sampled_batch[
                    'image'], sampled_batch['label']
                volume_batch, label_batch = volume_batch.cuda(
                ), label_batch.cuda()
                output = net_parallel(volume_batch)

                output = F.sigmoid(output)
                loss = dice_loss(output, label_batch)

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

                iter_num = iter_num + 1
                if iter_num % 10 == 0:
                    print('iteration %d : loss : %f' % (iter_num, loss.item()))
                if iter_num % 5000 == 0:
                    torch.save(
                        net.state_dict(),
                        os.path.join(
                            snapshot_path, snapshot_prefix + '_iteration_' +
                            str(iter_num) + '.pth'))
                if iter_num >= max_iterations:
Exemplo n.º 9
0
    def forward(self, batch_input, task=None):
        """Encode the camera views and produce map and/or detection outputs.

        Args:
            batch_input: dict of batched tensors. ``"image"`` holds the camera
                views (presumably ``(bs, 6, C, H, W)``; the encoder output is
                regrouped below as ``(bs, 6, c, h, w)`` -- TODO confirm).
                ``"road"`` and ``"sem_map"`` are read as targets depending on
                the configured outputs.
            task: unused here; kept for interface compatibility.

        Returns:
            dict with ``"loss"`` (sum of the enabled loss terms) and, depending
            on the configuration flags, ``"road_map"`` (sigmoid output) or
            ``"sem_map"`` (softmax over dim 1), ``"recon_loss"``,
            ``"ts_road_map"`` / ``"ts"`` scores, ``"KLD_loss"`` (variational
            branch), plus whatever ``self.obj_detection_model`` adds.
        """
        batch_output = {}

        self.stage = "finetune"

        views = batch_input["image"]
        bs = views.size(0)
        self.batch_size = bs

        # Run the image encoder over every view of every sample in one pass,
        # then regroup the features into 6 views per sample.
        final_features = self.image_network(views.flatten(0, 1))
        _, c, h, w = final_features.shape
        views = final_features.view(bs, 6, c, h, w)

        batch_output["loss"] = 0

        gen_map = (self.gen_roadmap or self.gen_semantic_map
                   or self.gen_object_map)
        if gen_map or (self.detect_objects
                       and "decoder" in self.blobs_strategy):
            fusion = self.fuse(views)

        if gen_map:
            if "det" in self.model_type:
                # Deterministic branch: decode the fused features directly.
                mapped_image = self.decoder_network(fusion)

                if self.gen_roadmap:
                    batch_output["road_map"] = torch.sigmoid(mapped_image)
                else:
                    batch_output["sem_map"] = F.softmax(mapped_image, dim=1)

                # The loss branches key off self.gen_roadmap (the same flag
                # used above) so they always read the output key just set.
                if self.loss_type == "dice":
                    # .long() changes only the dtype and keeps the tensor on
                    # its device; .type(torch.LongTensor) would move it to CPU.
                    if self.gen_roadmap:
                        batch_output["recon_loss"] = dice_loss(
                            batch_input["road"].long(), mapped_image)
                    else:
                        batch_output["recon_loss"] = dice_loss(
                            batch_input["sem_map"].max(dim=1)[1].long(),
                            mapped_image)
                elif self.loss_type == 'bce':
                    if self.gen_roadmap:
                        batch_output["recon_loss"] = self.criterion(
                            mapped_image, batch_input["road"])
                    else:
                        batch_output["recon_loss"] = self.criterion(
                            mapped_image, batch_input["sem_map"].max(dim=1)[1])
                else:
                    if self.gen_roadmap:
                        batch_output["recon_loss"] = self.criterion(
                            batch_output["road_map"], batch_input["road"])
                    else:
                        batch_output["recon_loss"] = self.criterion(
                            batch_output["sem_map"], batch_input["sem_map"])

                if self.gen_roadmap:
                    batch_output["ts_road_map"] = compute_ts_road_map(
                        batch_output["road_map"], batch_input["road"])
                else:
                    # Pixel accuracy of the arg-max class over dim 1.
                    batch_output["ts_road_map"] = (
                        batch_output["sem_map"].max(dim=1)[1]
                        == batch_input["sem_map"].max(dim=1)[1]).float().mean()

                batch_output["ts"] = batch_output["ts_road_map"]
                batch_output["loss"] += batch_output["recon_loss"]

            else:
                # Variational branch: project to (mu, logvar), sample z and
                # decode from the sample.
                if self.conv_fuse:
                    fusion = self.avg_pool_refine(
                        self.avg_pool(fusion).view(-1, self.d_model))

                mu_logvar = self.z_project(fusion).view(bs, 2, self.latent_dim)
                mu = mu_logvar[:, 0, :]
                logvar = mu_logvar[:, 1, :]

                z = self.reparameterize(mu, logvar)
                z = self.z_refine(z)
                z = self.z_reshape(z).view(bs, 32, 16, 16)

                generated_image = self.decoder_network(z)

                # Functional activations: nn.Sigmoid / nn.Softmax are module
                # classes and cannot be applied by passing the tensor to them.
                if self.gen_roadmap:
                    batch_output["road_map"] = torch.sigmoid(generated_image)
                else:
                    batch_output["sem_map"] = F.softmax(generated_image, dim=1)

                if self.loss_type == "dice":
                    if self.gen_roadmap:
                        batch_output["recon_loss"] = dice_loss(
                            batch_input["road"], batch_output["road_map"])
                    else:
                        # Same call convention as the deterministic branch.
                        batch_output["recon_loss"] = dice_loss(
                            batch_input["sem_map"].max(dim=1)[1].long(),
                            generated_image)
                elif self.loss_type == 'bce':
                    if self.gen_roadmap:
                        batch_output["recon_loss"] = self.criterion(
                            generated_image, batch_input["road"])
                    else:
                        batch_output["recon_loss"] = self.criterion(
                            generated_image,
                            batch_input["sem_map"].max(dim=1)[1])
                else:
                    if self.gen_roadmap:
                        batch_output["recon_loss"] = self.criterion(
                            batch_output["road_map"], batch_input["road"])
                    else:
                        batch_output["recon_loss"] = self.criterion(
                            batch_output["sem_map"], batch_input["sem_map"])

                if self.gen_roadmap:
                    batch_output["ts_road_map"] = compute_ts_road_map(
                        batch_output["road_map"], batch_input["road"])
                else:
                    batch_output["ts_road_map"] = (
                        batch_output["sem_map"].max(dim=1)[1]
                        == batch_input["sem_map"].max(dim=1)[1]).float().mean()

                # KL divergence to a standard normal, summed (not averaged)
                # over the batch and latent dimensions.
                batch_output["KLD_loss"] = -0.5 * torch.sum(1 + logvar -
                                                            mu.pow(2) -
                                                            logvar.exp())
                batch_output["ts"] = batch_output["ts_road_map"]
                batch_output["loss"] += (batch_output["recon_loss"]
                                         + batch_output["KLD_loss"])

        if self.detect_objects:
            if "decoder" in self.blobs_strategy:
                if "var" in self.model_type:
                    # NOTE(review): `z` exists only if the variational map
                    # branch above ran; with map generation disabled (or "det"
                    # also in model_type) this raises NameError -- confirm that
                    # configuration is never used.
                    batch_output = self.obj_detection_model(
                        z, batch_input, batch_output)
                else:
                    batch_output = self.obj_detection_model(
                        fusion, batch_input, batch_output, fusion)
            else:
                batch_output = self.obj_detection_model(
                    batch_input["image"], batch_input, batch_output)

        return batch_output
Exemplo n.º 10
0
    def forward(self, batch_input, task="none"):
        """Generate a road / semantic map from the camera views.

        Args:
            batch_input: dict of batched tensors. ``"image"`` holds the camera
                views (presumably ``(bs, 6, C, H, W)``; the encoder output is
                regrouped below as ``(bs, 6, c, h, w)`` -- TODO confirm).
                ``"road"`` (when ``args.gen_road_map``) or ``"sem_map"``
                (otherwise) is read as the target.
            task: unused here; kept for interface compatibility.

        Returns:
            dict with ``"loss"`` (always 0, see note below), ``"road_map"``
            (sigmoid output, or softmax over dim 1 for semantic maps),
            ``"ts_road_map"`` / ``"ts"`` scores and the supervised generator
            loss ``"GSupLoss"``.
        """
        batch_output = {}

        self.stage = "finetune"

        views = batch_input["image"]
        device = views.device
        bs = views.size(0)
        self.batch_size = bs

        # Run the image encoder over every view of every sample in one pass,
        # then regroup the features into 6 views per sample.
        gen_latent_features = self.image_network(views.flatten(0, 1))
        _, c, h, w = gen_latent_features.shape
        views = gen_latent_features.view(bs, 6, c, h, w)

        # NOTE(review): nothing below adds to "loss"; the adversarial terms
        # that used to be summed into it are disabled. Presumably the caller
        # optimizes "GSupLoss" directly -- confirm before relying on "loss".
        batch_output["loss"] = 0

        if "cond" in self.args.finetune_obj:
            # Conditional variant: fuse the views together with a random code.
            z = torch.randn(bs, self.args.latent_dim).to(device)
            fusion = self.fuse(views, z)
        else:
            fusion = self.fuse(views)

        gen_image = self.decoder_network(fusion)

        if self.args.gen_road_map:
            batch_output["road_map"] = torch.sigmoid(gen_image)
            batch_output["ts_road_map"] = compute_ts_road_map(
                batch_output["road_map"], batch_input["road"])
        else:
            batch_output["road_map"] = F.softmax(gen_image, dim=1)
            # Pixel accuracy of the arg-max class over dim 1.
            batch_output["ts_road_map"] = (
                batch_output["road_map"].max(dim=1)[1]
                == batch_input["sem_map"].max(dim=1)[1]).float().mean()
        batch_output["ts"] = batch_output["ts_road_map"]

        if self.args.road_map_loss == "dice":
            # .long() changes only the dtype and keeps the tensor on its
            # device; .type(torch.LongTensor) would move it to the CPU.
            if self.args.gen_road_map:
                target = batch_input["road"]
            else:
                target = batch_input["sem_map"].max(dim=1)[1]
            batch_output["GSupLoss"] = dice_loss(target.long(), gen_image)
        else:
            if self.args.gen_road_map:
                batch_output["GSupLoss"] = self.criterion(
                    gen_image, batch_input["road"])
            else:
                batch_output["GSupLoss"] = self.criterion(
                    gen_image, batch_input["sem_map"].max(dim=1)[1])

        return batch_output
Exemplo n.º 11
0
def main():
    """Parse CLI options, build a UNet segmentation model and train it.

    Samples listed in ``<data_path>/train_mask.json`` are split in file order:
    the first ``val_split`` fraction is used for training and the remainder
    for validation. On Ctrl-C the current weights are saved to
    ``<output_dir>/INTERRUPTED.pth`` before exiting.
    """
    parser = ArgumentParser()
    parser.add_argument('-d',
                        '--data_path',
                        dest='data_path',
                        type=str,
                        default=None,
                        help='path to the data')
    parser.add_argument('-e',
                        '--epochs',
                        dest='epochs',
                        default=20,
                        type=int,
                        help='number of epochs')
    parser.add_argument('-b',
                        '--batch_size',
                        dest='batch_size',
                        default=40,
                        type=int,
                        help='batch size')
    parser.add_argument('-s',
                        '--image_size',
                        dest='image_size',
                        default=256,
                        type=int,
                        help='input image size')
    parser.add_argument('-lr',
                        '--learning_rate',
                        dest='lr',
                        default=0.0001,
                        type=float,
                        help='learning rate')
    parser.add_argument('-wd',
                        '--weight_decay',
                        dest='weight_decay',
                        default=5e-4,
                        type=float,
                        help='weight decay')
    parser.add_argument('-lrs',
                        '--learning_rate_step',
                        dest='lr_step',
                        default=10,
                        type=int,
                        help='learning rate step')
    parser.add_argument('-lrg',
                        '--learning_rate_gamma',
                        dest='lr_gamma',
                        default=0.5,
                        type=float,
                        help='learning rate gamma')
    parser.add_argument('-m',
                        '--model',
                        dest='model',
                        default='unet',
                        choices=('unet', ))
    parser.add_argument('-w',
                        '--weight_bce',
                        default=0.5,
                        type=float,
                        help='weight BCE loss')
    parser.add_argument('-l',
                        '--load',
                        dest='load',
                        default=False,
                        help='load file model')
    # Fraction of the samples used for TRAINING; the rest go to validation.
    # type=float is required: without it a value given on the command line
    # arrives as a str and the split arithmetic below breaks.
    parser.add_argument('-v',
                        '--val_split',
                        dest='val_split',
                        default=0.8,
                        type=float,
                        help='train/val split')
    parser.add_argument('-o',
                        '--output_dir',
                        dest='output_dir',
                        default='/tmp/logs/',
                        help='dir to save log and models')
    args = parser.parse_args()
    # Fail fast with a usage message instead of a TypeError from
    # os.path.join(None, ...) after the model has already been built.
    if args.data_path is None:
        parser.error('the following arguments are required: -d/--data_path')

    os.makedirs(args.output_dir, exist_ok=True)
    logger = get_logger(os.path.join(args.output_dir, 'train.log'))
    logger.info('Start training with params:')
    for arg, value in sorted(vars(args).items()):
        logger.info("Argument %s: %r", arg, value)

    # TODO: try a more novel architecture and/or lighter blocks (e.g.
    # MobileNet) to allow a larger batch_size.
    # TODO: img_size=256 is rather mediocre; try to optimize the network for
    # at least 512.
    net = UNet()
    logger.info('Model type: {}'.format(net.__class__.__name__))
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if args.load:
        # map_location lets a checkpoint saved on GPU be restored on a
        # CPU-only machine (and vice versa).
        net.load_state_dict(torch.load(args.load, map_location=device))
    net.to(device)

    optimizer = optim.Adam(net.parameters(),
                           lr=args.lr,
                           weight_decay=args.weight_decay)
    # TODO: loss experimentation / class imbalance.
    # Returns the two weighted terms separately as a (bce, dice) tuple;
    # presumably `train` combines and/or logs them -- confirm against `train`.
    criterion = lambda x, y: (args.weight_bce * nn.BCELoss()(x, y),
                              (1. - args.weight_bce) * dice_loss(x, y))
    # TODO: consider ReduceLROnPlateau as a default option.
    # A non-positive --learning_rate_step disables LR scheduling.
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=args.lr_step, gamma=args.lr_gamma) \
        if args.lr_step > 0 else None

    # TODO: richer augmentations (see the albumentations package).
    train_transforms = Compose([
        Crop(min_size=1 - 1 / 3., min_ratio=1.0, max_ratio=1.0, p=0.5),
        Flip(p=0.05),
        Pad(max_size=0.6, p=0.25),
        Resize(size=(args.image_size, args.image_size), keep_aspect=True)
    ])
    # TODO: address class imbalance and data cleansing.
    val_transforms = Resize(size=(args.image_size, args.image_size))

    train_dataset = DetectionDataset(args.data_path,
                                     os.path.join(args.data_path,
                                                  'train_mask.json'),
                                     transforms=train_transforms)
    val_dataset = DetectionDataset(args.data_path,
                                   None,
                                   transforms=val_transforms)

    # Both datasets are carved out of the same annotated list: the head
    # (in file order) trains, the tail validates.
    train_size = int(len(train_dataset) * args.val_split)
    val_dataset.image_names = train_dataset.image_names[train_size:]
    val_dataset.mask_names = train_dataset.mask_names[train_size:]
    train_dataset.image_names = train_dataset.image_names[:train_size]
    train_dataset.mask_names = train_dataset.mask_names[:train_size]

    train_dataloader = DataLoader(train_dataset,
                                  batch_size=args.batch_size,
                                  num_workers=8,
                                  shuffle=True,
                                  drop_last=True)
    val_dataloader = DataLoader(val_dataset,
                                batch_size=args.batch_size,
                                num_workers=4,
                                shuffle=False,
                                drop_last=False)
    logger.info('Length of train/val=%d/%d', len(train_dataset),
                len(val_dataset))
    logger.info('Number of batches of train/val=%d/%d', len(train_dataloader),
                len(val_dataloader))

    try:
        train(net,
              optimizer,
              criterion,
              scheduler,
              train_dataloader,
              val_dataloader,
              logger=logger,
              args=args,
              device=device)
    except KeyboardInterrupt:
        torch.save(net.state_dict(),
                   os.path.join(args.output_dir, 'INTERRUPTED.pth'))
        logger.info('Saved interrupt')
        sys.exit(0)