Code example #1
def test(modelPath):
    # testDataset = CamVid()
    testDataset = cattleDataset()
    testLoader = DataLoader(testDataset,
                            batch_size=N,
                            shuffle=True,
                            num_workers=num_workers,
                            drop_last=True)
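    # note: drop_last=True skips a final batch smaller than N, so a few test samples may go unused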

    eNet = ENet(C).to(device)
    eNet.load_state_dict(torch.load(modelPath, map_location=device))
    eNet.eval()  # switch BatchNorm/Dropout to inference behaviour for testing
    for e in range(4):
        batch_avg_EER = 0  # accumulator kept from the source snippet; never updated below
        for batchID, batchData in enumerate(testLoader):
            inputs, labels = batchData['image'], batchData['semantic']
            inputs, labels = inputs.to(device), labels.to(device)

            with torch.no_grad():
                outputs = eNet(inputs.float())

            # visually check the prediction for the first sample in the batch
            annotation = outputs[0]  # (C, H, W) class scores for the first sample
            # collapse the class dimension into a per-pixel label map
            Annot = annotation.argmax(0).cpu().numpy()
            # print(Annot, np.max(Annot), np.min(Annot))
            showImage(Annot)
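
The snippet above relies on module-level globals and a showImage helper defined elsewhere in the source file. A minimal sketch of plausible definitions follows; the names come from the snippet, but every value and the showImage body are assumptions for illustration:

import torch
import matplotlib.pyplot as plt

N = 4            # batch size (assumed value)
C = 2            # number of classes (assumed value)
num_workers = 2  # DataLoader worker processes (assumed value)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def showImage(labelMap):
    # hypothetical helper: display a per-pixel label map
    plt.imshow(labelMap)
    plt.show()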
Code example #2
File: main_MIDL.py Project: wyk0517/SizeLoss_WSS
def runTraining():
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    # Batch size MUST be 1 for training in weakly-/semi-supervised learning if we want to impose size constraints.
    batch_size = 1
    batch_size_val = 1
    lr = 0.0005
    epoch = 1000

    root_dir = './ACDC-2D-All'
    model_dir = 'model'

    transform = transforms.Compose([transforms.ToTensor()])

    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset(
        'train',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=False)

    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

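    # lower/upper bounds (in pixels) on the predicted-region size, enforced by the size loss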
    minVal = 97.9
    maxVal = 1722.6
    minSize = torch.FloatTensor(1)
    minSize.fill_(np.int64(minVal).item())
    maxSize = torch.FloatTensor(1)
    maxSize.fill_(np.int64(maxVal).item())

    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    num_classes = 2

    netG = ENet(1, num_classes)

    netG.apply(weights_init)
    softMax = nn.Softmax(dim=1)  # softmax over the class dimension
    Dice_loss = computeDiceOneHotBinary()

    modelName = 'WeaklySupervised_CE-2_b'

    print(' Model name: {}'.format(modelName))
    partial_ce = Partial_CE()
    mil_loss = MIL_Loss()
    size_loss = Size_Loss()

    if torch.cuda.is_available():
        netG.cuda()
        softMax.cuda()
        Dice_loss.cuda()

    optimizerG = torch.optim.Adam(netG.parameters(), lr=lr, betas=(0.5, 0.999))

    BestDice, BestEpoch = 0, 0

    dBAll = []
    Losses = []

    annotatedPixels = 0
    totalPixels = 0

    print(" ~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    print(' --------- Params: ---------')
    print(' - Lower bound: {}'.format(minVal))
    print(' - Upper bound: {}'.format(maxVal))
    for i in range(epoch):
        netG.train()
        lossVal = []
        lossVal1 = []

        totalImages = len(train_loader)
        for j, data in enumerate(train_loader):
            image, labels, weak_labels, img_names = data

            # prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            optimizerG.zero_grad()
            netG.zero_grad()

            MRI = to_var(image)
            Segmentation = to_var(labels)
            weakAnnotations = to_var(weak_labels)

            segmentation_prediction = netG(MRI)

            annotatedPixels = annotatedPixels + weak_labels.sum()
            totalPixels = totalPixels + weak_labels.shape[2] * weak_labels.shape[3]
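            # sharpen the softmax with a low temperature before computing the partial-CE and size losses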
            temperature = 0.1
            predClass_y = softMax(segmentation_prediction / temperature)
            Segmentation_planes = getOneHot_Encoded_Segmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # lossCE_numpy = partial_ce(segmentation_prediction, Segmentation_planes, weakAnnotations)
            lossCE_numpy = partial_ce(predClass_y, Segmentation_planes,
                                      weakAnnotations)

            # sizeLoss_val = size_loss(segmentation_prediction, Segmentation_planes, Variable(minSize), Variable(maxSize))
            sizeLoss_val = size_loss(predClass_y, Segmentation_planes,
                                     Variable(minSize), Variable(maxSize))

            # MIL_Loss_val = mil_loss(predClass_y, Segmentation_planes)

            # Dice (used ONLY to report the Dice score; this Dice-loss version does not work for training)
            DicesN, DicesB = Dice_loss(segmentation_prediction_ones,
                                       Segmentation_planes)
            DiceN = DicesToDice(DicesN)
            DiceB = DicesToDice(DicesB)

            Dice_score = (DiceB + DiceN) / 2

            # Choose between the different models
            # lossG = lossCE_numpy + MIL_Loss_val
            lossG = lossCE_numpy + sizeLoss_val
            # lossG = lossCE_numpy
            # lossG = sizeLoss_val

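            # retain_graph=True is kept from the source project; it is only needed if this graph is backpropagated through again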
            lossG.backward(retain_graph=True)
            optimizerG.step()

            lossVal.append(lossG.item())
            lossVal1.append(lossCE_numpy.item())

            printProgressBar(
                j + 1,
                totalImages,
                prefix="[Training] Epoch: {} ".format(i),
                length=15,
                suffix=" Mean Dice: {:.4f}, Dice1: {:.4f} ".format(
                    Dice_score.item(), DiceB.item()))

        deepSupervision = False
        printProgressBar(
            totalImages,
            totalImages,
            done=f"[Training] Epoch: {i}, LossG: {np.mean(lossVal):.4f}, lossCE: {np.mean(lossVal1):.4f}")

        Losses.append(np.mean(lossVal))
        d1, sizeGT, sizePred = inference(netG, temperature, val_loader,
                                         batch_size, i, deepSupervision,
                                         modelName, minVal, maxVal)

        dBAll.append(d1)

        directory = 'Results/Statistics/MIDL/' + modelName
        if not os.path.exists(directory):
            os.makedirs(directory)

        np.save(os.path.join(directory, modelName + '_Losses.npy'), Losses)
        np.save(os.path.join(directory, modelName + '_dBAll.npy'), dBAll)

        currentDice = d1

        print(" [VAL] DSC: (1): {:.4f} ".format(d1))
        # saveImagesSegmentation(netG, val_loader_save_imagesPng, batch_size_val_savePng, i, 'test', False)

        if currentDice > BestDice:
            BestDice = currentDice
            if not os.path.exists(model_dir):
                os.makedirs(model_dir)
            torch.save(netG,
                       os.path.join(model_dir, "Best_" + modelName + ".pkl"))

        # periodically reset the learning rate to its base value (likely the intended schedule)
        if i % (BestEpoch + 10) == 0:
            for param_group in optimizerG.param_groups:
                param_group['lr'] = lr
Code example #3
    # (the snippet starts mid-way through building image_datasets: one dataset per
    #  phase, constructed with HALF_CROP=int(args.crop_shift); the opening of the
    #  dict comprehension is truncated in the source)
                              HALF_CROP=int(args.crop_shift))
        for x in phase_range
    }
    dataloaders = {
        x: torch.utils.data.DataLoader(image_datasets[x],
                                       batch_size=BATCH_SIZE,
                                       shuffle=True,
                                       num_workers=8)
        for x in phase_range
    }
    dataset_sizes = {x: len(image_datasets[x]) for x in phase_range}

    # Model
    # build every candidate model up front, then select one by name
    NAME2MODEL = {
        "ENet": ENet(in_channels=1, num_classes=args.num_classes),
        "UNet": UNet(in_channels=1, out_channels=args.num_classes),
        "Original": UNet_Original(in_channels=1, out_channels=args.num_classes),
        "Original_with_BatchNorm": UNet_Original_with_BatchNorm(in_channels=1, out_channels=args.num_classes),
        "Baseline": UNet_Baseline(in_channels=1, out_channels=args.num_classes),
        "Dilated": UNet_Dilated(in_channels=1, out_channels=args.num_classes),
        "ProgressiveDilated": UNet_ProgressiveDilated(in_channels=1, out_channels=args.num_classes),
    }
    model = NAME2MODEL[args.model]
Code example #4
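# NOTE: this example uses the TensorFlow 1.x API (tf.placeholder, tf.Session);
# running it under TensorFlow 2 would require the tf.compat.v1 shims.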
def run():
    num_classes = 3
    image_shape = IMG_SIZE
    data_dir = TRAINING_DIR
    runs_dir = './runs'
    epochs = 2
    batch_size = 1
    learning_rate = 1e-5

    net_input = tf.placeholder(tf.float32,
                               shape=[None, image_shape[0], image_shape[1], 3],
                               name="net_input")
    net_output = tf.placeholder(
        tf.float32,
        shape=[None, image_shape[0], image_shape[1], num_classes],
        name="net_output")

    logits, probabilities = ENet(net_input,
                                 num_classes,
                                 batch_size=batch_size,
                                 is_training=True,
                                 reuse=None,
                                 num_initial_blocks=1,
                                 stage_two_repeat=2,
                                 skip_connections=False)

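    # flatten per-pixel class probabilities to (num_pixels, num_classes) for the loss;
    # note the tensor is named 'logits' but actually holds softmax probabilities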
    network = tf.reshape(probabilities, (-1, num_classes), name='logits')
    loss = custom_loss(network, net_output)

    # annotations = tf.reshape(annotations, shape=[batch_size, image_shape[0], image_shape[1]])
    # annotations_ohe = tf.one_hot(annotations, num_classes, axis=-1)

    # use the learning_rate defined above (the source hard-coded 1e-4 here)
    opt = tf.train.AdamOptimizer(learning_rate).minimize(
        loss, var_list=tf.trainable_variables())

    with tf.Session() as sess:

        # Create function to get batches
        get_batches_fn = helper.gen_batch_function(data_dir, RGB_DIR, SEG_DIR,
                                                   image_shape)

        init_op = tf.global_variables_initializer()

        saver = tf.train.Saver()

        # Runs training
        sess.run(init_op)
        train_nn(sess, epochs, batch_size, get_batches_fn, opt, loss,
                 net_input, net_output, learning_rate)

        # Save the trained model
        today = datetime.datetime.now().strftime("%Y-%m-%d-%H%M")
        save_dir = os.path.join(SAVE_MODEL_DIR, today)
        helper.save_model(sess, net_input, network, save_dir)

        print("SavedModel saved at {}".format(save_dir))

        test_dir = TEST_DIR
        helper.save_inference_samples(runs_dir, test_dir, sess, image_shape,
                                      network, net_input)
Code example #5
def train(modelPath):
    # trainDataset = CamVid()
    trainDataset = cattleDataset()
    trainLoader = DataLoader(trainDataset,
                             batch_size=N,
                             shuffle=True,
                             num_workers=num_workers,
                             drop_last=True)
    print(trainLoader)
    # instantiate the ENet and move it to the target device
    efficientNet = ENet(C).to(device)
    if restore:
        print("++" * 10, '\n')
        pretrained_dict = torch.load(modelPath)
        model_dict = efficientNet.state_dict()
        ## drop keys that do not belong to model_dict
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        ## update the matching keys of model_dict
        model_dict.update(pretrained_dict)
        ## load the merged state dict
        efficientNet.load_state_dict(model_dict)
        print("[INFO:] Loaded pretrained model successfully!\n")

    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(efficientNet.parameters(), lr=lr, momentum=0.9)
    os.makedirs(checkpointDir, exist_ok=True)
    efficientNet.train()
    iteration = 0
    for e in range(epochs):
        totalLoss = 0
        for batchID, batchData in enumerate(trainLoader):
            # get the inputs; data is a list of [image, semantic]
            inputs, labels = batchData['image'], batchData['semantic']
            inputs, labels = inputs.to(device), labels.to(device)
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = efficientNet(inputs.float())

            #########################################
            #annotation = outputs[0,:,:,:].squeeze(0)
            ## combine to one dimension
            #Annot = annotation.data.max(0)[1].cpu().numpy()
            #print(Annot, np.max(Annot), np.min(Annot))
            #plt.imshow(Annot)
            #input()
            #########################################

            loss = criterion(outputs, labels.long())
            loss.backward()
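            # clip the gradient norm at 3.0 to guard against exploding gradients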
            torch.nn.utils.clip_grad_norm_(efficientNet.parameters(), 3.0)
            optimizer.step()

            # print statistics
            totalLoss += loss.item()
            iteration += 1

            if (batchID + 1) % logInterval == 0:
                mesg = "{0}\tEpoch:{1}[{2}/{3}],Iteration:{4}\tLoss:{5:.4f}\tTLoss:{6:.4f}\t\n".format(
                    time.ctime(), e + 1, batchID + 1,
                    len(trainDataset) // N, iteration, loss.item(),
                    totalLoss / (batchID + 1))
                print(mesg)
                if logFile is not None:
                    with open(logFile, 'a') as f:
                        f.write(mesg)
    # save the model (moved to the CPU, deliberately without calling eval() first)
    # efficientNet.eval().cpu()
    print("====not eval() before save====")
    efficientNet.cpu()
    save_model_filename = "final_epoch_" + str(e + 1) + "_batch_id_" + str(
        batchID + 1) + ".pth"
    save_model_path = os.path.join(checkpointDir, save_model_filename)
    torch.save(efficientNet.state_dict(), save_model_path)

    print("\nDone, trained model saved at", save_model_path)