Ejemplo n.º 1
0
def runTraining(args):
    """Train a ``DAF_stack`` model and validate it after every epoch.

    Each epoch performs one pass over the training loader, minimising the sum
    of eight cross-entropy terms (one per segmentation output) plus ``0.25 *``
    the four ``lossSemantic*`` MSE terms and ``0.1 *`` the eight ``lossRec*``
    MSE terms.  After the pass it runs ``inference`` on the validation loader,
    runs the 3D pipeline (``saveImages_for3D``, ``reconstruct3D``,
    ``evaluate3D``), stores the running statistics as ``.npy`` files under
    ``Results/Statistics/<modelName>`` and, when the mean validation Dice is a
    new best above 0.40, saves the weights to ``model/Best_<modelName>.pth``.

    Args:
        args: Object exposing ``batch_size``, ``lr``, ``epochs`` and
            ``modelName`` attributes.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = args.batch_size
    batch_size_val = 1
    batch_size_val_save = 1
    lr = args.lr

    epoch = args.epochs
    root_dir = './DataSet_Challenge/Val_1'
    model_dir = 'model'

    print(' Dataset: {} '.format(root_dir))

    transform = transforms.Compose([
        transforms.ToTensor()
    ])

    mask_transform = transforms.Compose([
        transforms.ToTensor()
    ])

    train_set = medicalDataLoader.MedicalImageDataset('train',
                                                      root_dir,
                                                      transform=transform,
                                                      mask_transform=mask_transform,
                                                      augment=True,
                                                      equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset('val',
                                                    root_dir,
                                                    transform=transform,
                                                    mask_transform=mask_transform,
                                                    equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=4,
                                        shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the DAF Stacked model ~~~~~~~~~~")
    net = DAF_stack()
    print(" Model Name: {}".format(args.modelName))
    print(" Model ot create: DAF_Stacked")

    net.apply(weights_init)

    # The class axis is spelled out instead of relying on the deprecated
    # implicit choice.  The same output tensors are fed to CrossEntropyLoss,
    # which takes the classes on dim 1.
    softMax = nn.Softmax(dim=1)
    CE_loss = nn.CrossEntropyLoss()
    Dice_loss = computeDiceOneHot()
    mseLoss = nn.MSELoss()

    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()

    optimizer = Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)

    BestDice, BestEpoch = 0, 0
    BestDice3D = [0, 0, 0, 0]

    # Per-epoch validation Dice returned by inference(), one list per class.
    d1Val = []
    d2Val = []
    d3Val = []
    d4Val = []

    # Per-epoch mean / std of the values returned by evaluate3D().
    d1Val_3D = []
    d2Val_3D = []
    d3Val_3D = []
    d4Val_3D = []

    d1Val_3D_std = []
    d2Val_3D_std = []
    d3Val_3D_std = []
    d4Val_3D_std = []

    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        net.train()
        lossVal = []

        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image, labels, img_names = data

            # prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            # The optimizer holds every parameter of ``net``, so this also
            # clears the gradients of the network itself.
            optimizer.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)

            ################### Train ###################
            # Network outputs
            (semVector_1_1, semVector_2_1,
             semVector_1_2, semVector_2_2,
             semVector_1_3, semVector_2_3,
             semVector_1_4, semVector_2_4,
             inp_enc0, inp_enc1, inp_enc2, inp_enc3,
             inp_enc4, inp_enc5, inp_enc6, inp_enc7,
             out_enc0, out_enc1, out_enc2, out_enc3,
             out_enc4, out_enc5, out_enc6, out_enc7,
             outputs0, outputs1, outputs2, outputs3,
             outputs0_2, outputs1_2, outputs2_2, outputs3_2) = net(MRI)

            # The reported prediction is the average of the eight outputs.
            segmentation_prediction = (outputs0 + outputs1 + outputs2 + outputs3 + outputs0_2 + outputs1_2 + outputs2_2 + outputs3_2) / 8
            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)

            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            # Cross-entropy loss
            loss0 = CE_loss(outputs0, Segmentation_class)
            loss1 = CE_loss(outputs1, Segmentation_class)
            loss2 = CE_loss(outputs2, Segmentation_class)
            loss3 = CE_loss(outputs3, Segmentation_class)
            loss0_2 = CE_loss(outputs0_2, Segmentation_class)
            loss1_2 = CE_loss(outputs1_2, Segmentation_class)
            loss2_2 = CE_loss(outputs2_2, Segmentation_class)
            loss3_2 = CE_loss(outputs3_2, Segmentation_class)

            lossSemantic1 = mseLoss(semVector_1_1, semVector_2_1)
            lossSemantic2 = mseLoss(semVector_1_2, semVector_2_2)
            lossSemantic3 = mseLoss(semVector_1_3, semVector_2_3)
            lossSemantic4 = mseLoss(semVector_1_4, semVector_2_4)

            lossRec0 = mseLoss(inp_enc0, out_enc0)
            lossRec1 = mseLoss(inp_enc1, out_enc1)
            lossRec2 = mseLoss(inp_enc2, out_enc2)
            lossRec3 = mseLoss(inp_enc3, out_enc3)
            lossRec4 = mseLoss(inp_enc4, out_enc4)
            lossRec5 = mseLoss(inp_enc5, out_enc5)
            lossRec6 = mseLoss(inp_enc6, out_enc6)
            lossRec7 = mseLoss(inp_enc7, out_enc7)

            lossG = loss0 + loss1 + loss2 + loss3 + loss0_2 + loss1_2 + loss2_2 + loss3_2 + 0.25 * (
                lossSemantic1 + lossSemantic2 + lossSemantic3 + lossSemantic4) \
                + 0.1 * (lossRec0 + lossRec1 + lossRec2 + lossRec3 + lossRec4 + lossRec5 + lossRec6 + lossRec7)  # CE_lossG

            # Compute the DSC (only displayed, not part of the optimised loss)
            DicesN, DicesB, DicesW, DicesT, DicesZ = Dice_loss(segmentation_prediction_ones, Segmentation_planes)

            DiceB = DicesToDice(DicesB)
            DiceW = DicesToDice(DicesW)
            DiceT = DicesToDice(DicesT)
            DiceZ = DicesToDice(DicesZ)

            Dice_score = (DiceB + DiceW + DiceT + DiceZ) / 4

            lossG.backward()
            optimizer.step()

            lossVal.append(lossG.cpu().data.numpy())

            printProgressBar(j + 1, totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f}, Dice1: {:.4f} , Dice2: {:.4f}, , Dice3: {:.4f}, Dice4: {:.4f} ".format(
                                 Dice_score.cpu().data.numpy(),
                                 DiceB.data.cpu().data.numpy(),
                                 DiceW.data.cpu().data.numpy(),
                                 DiceT.data.cpu().data.numpy(),
                                 DiceZ.data.cpu().data.numpy(),))

        printProgressBar(totalImages, totalImages,
                         done="[Training] Epoch: {}, LossG: {:.4f}".format(i, np.mean(lossVal)))

        # Save statistics
        modelName = args.modelName
        directory = 'Results/Statistics/' + modelName

        Losses.append(np.mean(lossVal))

        d1, d2, d3, d4 = inference(net, val_loader)

        d1Val.append(d1)
        d2Val.append(d2)
        d3Val.append(d3)
        d4Val.append(d4)

        if not os.path.exists(directory):
            os.makedirs(directory)

        np.save(os.path.join(directory, 'Losses.npy'), Losses)
        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd2Val.npy'), d2Val)
        np.save(os.path.join(directory, 'd3Val.npy'), d3Val)
        # The fourth class history was collected but never written to disk.
        np.save(os.path.join(directory, 'd4Val.npy'), d4Val)

        currentDice = (d1 + d2 + d3 + d4) / 4

        print("[val] DSC: (1): {:.4f} (2): {:.4f}  (3): {:.4f} (4): {:.4f}".format(d1, d2, d3, d4))  # MRI

        currentDice = currentDice.data.numpy()

        # Evaluate on 3D
        saveImages_for3D(net, val_loader_save_images, batch_size_val_save, 1000, modelName, False, False)
        reconstruct3D(modelName, 1000, isBest=False)
        DSC_3D = evaluate3D(modelName)

        mean_DSC3D = np.mean(DSC_3D, 0)
        std_DSC3D = np.std(DSC_3D, 0)

        d1Val_3D.append(mean_DSC3D[0])
        d2Val_3D.append(mean_DSC3D[1])
        d3Val_3D.append(mean_DSC3D[2])
        d4Val_3D.append(mean_DSC3D[3])
        d1Val_3D_std.append(std_DSC3D[0])
        d2Val_3D_std.append(std_DSC3D[1])
        d3Val_3D_std.append(std_DSC3D[2])
        d4Val_3D_std.append(std_DSC3D[3])

        # NOTE: the 3D files are numbered from 0 while the lists (and the 2D
        # files above) are numbered from 1; file names are kept as they were.
        np.save(os.path.join(directory, 'd0Val_3D.npy'), d1Val_3D)
        np.save(os.path.join(directory, 'd1Val_3D.npy'), d2Val_3D)
        np.save(os.path.join(directory, 'd2Val_3D.npy'), d3Val_3D)
        np.save(os.path.join(directory, 'd3Val_3D.npy'), d4Val_3D)

        np.save(os.path.join(directory, 'd0Val_3D_std.npy'), d1Val_3D_std)
        np.save(os.path.join(directory, 'd1Val_3D_std.npy'), d2Val_3D_std)
        np.save(os.path.join(directory, 'd2Val_3D_std.npy'), d3Val_3D_std)
        np.save(os.path.join(directory, 'd3Val_3D_std.npy'), d4Val_3D_std)

        if currentDice > BestDice:
            BestDice = currentDice

            BestEpoch = i

            # Checkpoints are only written once the mean Dice exceeds 0.40.
            if currentDice > 0.40:

                if np.mean(mean_DSC3D) > np.mean(BestDice3D):
                    BestDice3D = mean_DSC3D

                print("###    In 3D -----> MEAN: {}, Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f}   ###".format(np.mean(mean_DSC3D), mean_DSC3D[0], mean_DSC3D[1], mean_DSC3D[2], mean_DSC3D[3]))

                print("~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~")
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                torch.save(net.state_dict(), os.path.join(model_dir, "Best_" + modelName + ".pth"), pickle_module=dill)
                reconstruct3D(modelName, 1000, isBest=True)

        # NOTE(review): the per-class values printed next to "Best Dice" are
        # those of the current epoch, not of BestEpoch.
        print("###                                                       ###")
        print("###    Best Dice: {:.4f} at epoch {} with Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f}   ###".format(BestDice, BestEpoch, d1, d2, d3, d4))
        print("###    Best Dice in 3D: {:.4f} with Dice(1): {:.4f} Dice(2): {:.4f} Dice(3): {:.4f} Dice(4): {:.4f} ###".format(np.mean(BestDice3D), BestDice3D[0], BestDice3D[1], BestDice3D[2], BestDice3D[3]))
        print("###                                                       ###")

        # NOTE(review): this condition is also true at i == 0, so the rate is
        # already halved after the first epoch -- confirm that is intended.
        if i % (BestEpoch + 50) == 0:
            # Halve once per decay event, then apply the same rate to every
            # parameter group (previously it was halved once per group).
            lr = lr * 0.5
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
            print(' ----------  New learning Rate: {}'.format(lr))
Ejemplo n.º 2
0
def runTraining():
    """Train ``UNetG_Dilated_Progressive`` and validate it after every epoch.

    Hyper-parameters, dataset path and model name are hard-coded below.  Each
    epoch performs one pass over the training loader with a cross-entropy
    loss, runs ``inference`` on the validation loader, stores the running
    statistics as ``.npy`` files under ``../Results/Statistics/<modelName>``
    and, when the mean validation Dice is a new best above 0.7, pickles the
    whole model to ``model/Best_<modelName>.pkl`` and exports the validation
    predictions with ``saveImages`` / ``saveImagesAsMatlab``.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = 4
    batch_size_val = 1
    batch_size_val_save = 1
    lr = 0.0001
    epoch = 1000
    root_dir = '../DataSet/Bladder_Aug'
    modelName = 'UNetG_Dilated_Progressive'
    model_dir = 'model'

    # Training-mode switches.  They are constant for the whole run and are
    # read after the batch loop, so they are defined once up front.
    deepSupervision = False
    multiTask = False

    transform = transforms.Compose([transforms.ToTensor()])

    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset(
        'train',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=5,
                                        shuffle=False)

    # Initialize
    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")
    num_classes = 4

    initial_kernels = 32

    # Load network
    netG = UNetG_Dilated_Progressive(1, initial_kernels, num_classes)
    # The class axis is spelled out instead of relying on the deprecated
    # implicit choice.  The same output tensor is fed to CrossEntropyLoss,
    # which takes the classes on dim 1.
    softMax = nn.Softmax(dim=1)
    CE_loss = nn.CrossEntropyLoss()
    Dice_loss = computeDiceOneHot()

    if torch.cuda.is_available():
        netG.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_loss.cuda()

    optimizerG = Adam(netG.parameters(),
                      lr=lr,
                      betas=(0.9, 0.99),
                      amsgrad=False)

    # NOTE(review): this scheduler is created but ``scheduler.step(...)`` is
    # never called, so it has no effect; the learning rate is only touched by
    # the manual override at the end of the epoch loop.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizerG,
                                                           mode='max',
                                                           patience=4,
                                                           verbose=True,
                                                           factor=10**-0.5)

    BestDice, BestEpoch = 0, 0
    # Initialised so the summary print cannot hit an unbound name when no
    # epoch has improved on BestDice yet.
    BestDiceT = 0

    d1Train = []
    d2Train = []
    d3Train = []
    d1Val = []
    d2Val = []
    d3Val = []

    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        netG.train()
        lossVal = []
        lossVal1 = []
        lossVal05 = []
        lossVal025 = []
        lossVal0125 = []

        # Per-batch training Dice of this epoch (same order as Dice1..Dice3
        # in the progress bar).
        d1TrainTemp = []
        d2TrainTemp = []
        d3TrainTemp = []

        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):
            image, labels, img_names = data

            # prevent batchnorm error for batch of size 1
            if image.size(0) != batch_size:
                continue

            # The optimizer holds every parameter of ``netG``, so this also
            # clears the gradients of the network itself.
            optimizerG.zero_grad()
            MRI = to_var(image)
            Segmentation = to_var(labels)

            ################### Train ###################
            if deepSupervision:
                # Deep supervision: extra outputs at reduced resolutions
                segmentation_prediction, seg_3, seg_2, seg_1 = netG(MRI)
            elif multiTask:
                segmentation_prediction, reg_output = netG(MRI)
                # Regression
                feats = getValuesRegression(labels)

                feats_t = torch.from_numpy(feats).float()
                featsVar = to_var(feats_t)

                # NOTE(review): ``MSE_loss`` is not defined in this function;
                # it must exist at module level for this branch to work.
                MSE_loss_val = MSE_loss(reg_output, featsVar)
            else:
                # No deep supervision
                segmentation_prediction = netG(MRI)

            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            CE_lossG = CE_loss(segmentation_prediction, Segmentation_class)

            if deepSupervision:
                imageLabels_05 = resizeTensorMaskInSingleImage(
                    Segmentation_class, 2)
                imageLabels_025 = resizeTensorMaskInSingleImage(
                    Segmentation_class, 4)
                imageLabels_0125 = resizeTensorMaskInSingleImage(
                    Segmentation_class, 8)

                CE_lossG_3 = CE_loss(seg_3, imageLabels_05)
                CE_lossG_2 = CE_loss(seg_2, imageLabels_025)
                CE_lossG_1 = CE_loss(seg_1, imageLabels_0125)

                lossG = CE_lossG + 0.25 * CE_lossG_3 + 0.1 * CE_lossG_2 + 0.1 * CE_lossG_1
            elif multiTask:
                lossG = CE_lossG + 0.000001 * MSE_loss_val
            else:
                lossG = CE_lossG

            # Dice (only displayed / logged, not part of the optimised loss)
            DicesN, DicesB, DicesW, DicesT = Dice_loss(
                segmentation_prediction_ones, Segmentation_planes)
            DiceN = DicesToDice(DicesN)
            DiceB = DicesToDice(DicesB)
            DiceW = DicesToDice(DicesW)
            DiceT = DicesToDice(DicesT)

            Dice_score = (DiceB + DiceW + DiceT) / 3

            lossG.backward()
            optimizerG.step()

            # ``.item()`` replaces ``.data[0]``, which fails on 0-dim tensors
            # in recent PyTorch releases.
            lossVal.append(lossG.item())
            lossVal1.append(CE_lossG.item())

            if deepSupervision:
                lossVal05.append(CE_lossG_3.item())
                lossVal025.append(CE_lossG_2.item())
                lossVal0125.append(CE_lossG_1.item())

            dice_b = DiceB.item()
            dice_w = DiceW.item()
            dice_t = DiceT.item()

            # These lists were never filled before, which made the per-epoch
            # training statistics below fail.
            d1TrainTemp.append(dice_b)
            d2TrainTemp.append(dice_w)
            d3TrainTemp.append(dice_t)

            printProgressBar(
                j + 1,
                totalImages,
                prefix="[Training] Epoch: {} ".format(i),
                length=15,
                suffix=
                " Mean Dice: {:.4f}, Dice1: {:.4f} , Dice2: {:.4f}, , Dice3: {:.4f} "
                .format(Dice_score.item(), dice_b, dice_w, dice_t))

        if not deepSupervision:
            # NOTE: the value printed under "lossMSE" is the mean of
            # lossVal1, i.e. the main cross-entropy term.
            printProgressBar(
                totalImages,
                totalImages,
                done="[Training] Epoch: {}, LossG: {:.4f}, lossMSE: {:.4f}".
                format(i, np.mean(lossVal), np.mean(lossVal1)))
        else:
            printProgressBar(
                totalImages,
                totalImages,
                done=
                "[Training] Epoch: {}, LossG: {:.4f}, Loss4: {:.4f}, Loss3: {:.4f}, Loss2: {:.4f}, Loss1: {:.4f}"
                .format(i, np.mean(lossVal), np.mean(lossVal1),
                        np.mean(lossVal05), np.mean(lossVal025),
                        np.mean(lossVal0125)))

        Losses.append(np.mean(lossVal))

        d1, d2, d3 = inference(netG, val_loader, batch_size, i,
                               deepSupervision)

        d1Val.append(d1)
        d2Val.append(d2)
        d3Val.append(d3)

        # Epoch mean of the per-batch training Dice values.
        d1Train.append(np.mean(d1TrainTemp))
        d2Train.append(np.mean(d2TrainTemp))
        d3Train.append(np.mean(d3TrainTemp))

        mainPath = '../Results/Statistics/' + modelName

        directory = mainPath
        if not os.path.exists(directory):
            os.makedirs(directory)

        ###### Save statistics  ######
        np.save(os.path.join(directory, 'Losses.npy'), Losses)

        np.save(os.path.join(directory, 'd1Val.npy'), d1Val)
        np.save(os.path.join(directory, 'd2Val.npy'), d2Val)
        np.save(os.path.join(directory, 'd3Val.npy'), d3Val)

        np.save(os.path.join(directory, 'd1Train.npy'), d1Train)
        np.save(os.path.join(directory, 'd2Train.npy'), d2Train)
        np.save(os.path.join(directory, 'd3Train.npy'), d3Train)

        currentDice = (d1 + d2 + d3) / 3

        print("[val] DSC: (1): {:.4f} (2): {:.4f}  (3): {:.4f} ".format(
            d1, d2, d3))

        if currentDice > BestDice:
            BestDice = currentDice
            BestDiceT = d1
            BestEpoch = i
            # Checkpoints are only written once the mean Dice exceeds 0.7.
            if currentDice > 0.7:
                print(
                    "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
                )
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                torch.save(
                    netG, os.path.join(model_dir,
                                       "Best_" + modelName + ".pkl"))

                # Save images
                saveImages(netG, val_loader_save_images, batch_size_val_save,
                           i, modelName, deepSupervision)
                saveImagesAsMatlab(netG, val_loader_save_images,
                                   batch_size_val_save, i, modelName,
                                   deepSupervision)

        print("###                                                       ###")
        print("###    Best Dice: {:.4f} at epoch {} with DiceT: {:.4f}    ###".
              format(BestDice, BestEpoch, BestDiceT))
        print("###                                                       ###")

        # This is not as we did it in the MedPhys paper
        # NOTE(review): the condition has no ``== 0``, so it is true whenever
        # ``i`` is NOT a multiple of ``BestEpoch + 20``; and since ``lr`` is
        # never reassigned, the value written is always the initial ``lr / 2``.
        # Left unchanged because the intended schedule is unclear -- confirm.
        if i % (BestEpoch + 20):
            for param_group in optimizerG.param_groups:
                param_group['lr'] = lr / 2
Ejemplo n.º 3
0
    def train(self, targ, inp, mode='online', epochs=500, par=None,
              out=None, Jrout=None, track=False):
        '''
            Main training routine of the model: trains the recurrent
            connections ``self.J`` given a target and an input, and, when no
            readout matrix is supplied, a readout matrix as well. Two training
            modes can be selected (offline, online). The first uses the exact
            likelihood gradient (which is non-local in time, thus not
            biologically plausible), the second is the online approximation of
            the gradient as described in the article.

            Args:
                targ: numpy.array of shape (N, T), where N is the number of neurons
                      in the network and T is the time length of the sequence.

                inp : numpy.array of shape (N, T) that collects the input signal
                      to neurons. Must have the same shape as targ.

            Keywords:
                mode : (default: online) The training mode to use, either 'offline'
                       or 'online'. Any other value leaves self.J untouched.

                epochs: (default: 500) The number of epochs of training.

                par   : (default: None) Optional different dictionary collecting
                        training parameters: {dv, alpha, alpha_rout, beta_ro, offT}.
                        If not provided defaults to the parameter dictionary of
                        the model.

                out   : (default: None) Output target trajectories, numpy.array
                        of shape (K, T), where K is the dimension of the output
                        trajectories. Required when Jrout is None (it is the
                        target the readout is trained on) or when track is True
                        (it is needed to compute the output error).

                Jrout : (default: None) Pre-trained readout connection matrix.
                        If not provided, a novel matrix is built and trained
                        simultaneously with the recurrent connections training.

                track : (default: False) Flag to signal whether to track the
                        evolution of output MSE over training epochs.

            Returns:
                (J_rout, track) when Jrout is None, otherwise only track.
                Here track is a numpy.array of length epochs holding the
                per-epoch output MSE if tracking was requested, else None.

            Raises:
                ValueError: if out is None while it is required (see out).
        '''
        assert targ.shape == inp.shape

        # Fail early with a clear message instead of an AttributeError on
        # ``out.shape`` (or a TypeError deep inside the epoch loop).
        if out is None and (Jrout is None or track):
            raise ValueError("'out' must be provided when Jrout is None "
                             "or track is True")

        par = self.par if par is None else par

        dv = par['dv']

        itau_m = self.itau_m
        itau_s = self.itau_s

        sigm = self._sigm

        alpha = par['alpha']
        alpha_rout = par['alpha_rout']
        beta_ro = par['beta_ro']
        offT = par['offT']

        self.S[:, 0] = targ[:, 0].copy()
        self.S_hat[:, 0] = self.S[:, 0] * itau_s

        Tmax = np.shape(targ)[-1]
        dH = np.zeros((self.N, self.T))

        # From here on ``track`` is either the per-epoch MSE buffer or None.
        track = np.zeros(epochs) if track else None

        # The online mode steps the optimizer once per time step, the offline
        # mode once per epoch, hence the different drop_time values.
        opt_rec = Adam(alpha=alpha, drop=0.9,
                       drop_time=100 * Tmax if mode == 'online' else 100)

        if Jrout is None:
            S_rout = ut.sfilter(targ, itau=beta_ro)
            J_rout = np.random.normal(0., 0.1, size=(out.shape[0], self.N))
            opt = Adam(alpha=alpha_rout, drop=0.9,
                       drop_time=20 * Tmax if mode == 'online' else 20)

        else:
            J_rout = Jrout

        for epoch in trange(epochs, leave=False, desc='Training {}'.format(mode)):
            if Jrout is None:
                # Here we train the readout
                dJrout = (out - J_rout @ S_rout) @ S_rout.T
                J_rout = opt.step(J_rout, dJrout)

            # Here we train the network
            for t in range(Tmax - 1):
                # S_hat[:, 0] was initialised above and is left as is.
                if t > 0:
                    self.S_hat[:, t] = self.S_hat[:, t - 1] * itau_s + targ[:, t] * (1. - itau_s)

                self.H[:, t + 1] = self.H[:, t] * (1. - itau_m) \
                    + itau_m * (self.J @ self.S_hat[:, t] + inp[:, t] + self.h[:, t]) \
                    + self.Jreset @ targ[:, t]

                dH[:, t + 1] = dH[:, t] * (1. - itau_m) + itau_m * self.S_hat[:, t]

                if mode == 'online':
                    dJ = np.outer(targ[:, t + 1] - sigm(self.H[:, t + 1], dv=dv), dH[:, t + 1])
                    self.J = opt_rec.step(self.J, dJ)

                    # Self-connections are kept at zero.
                    np.fill_diagonal(self.J, 0.)

            if mode == 'offline':
                dJ = (targ - sigm(self.H, dv=dv)) @ dH.T
                self.J = opt_rec.step(self.J, dJ)

                # Self-connections are kept at zero.
                np.fill_diagonal(self.J, 0.)

            # Here we track MSE
            if track is not None:
                S_gen = self.compute(inp, init=np.zeros(self.N))
                track[epoch] = np.mean((out - J_rout @ ut.sfilter(S_gen,
                                                itau=beta_ro))[:, offT:]**2.)

        return (J_rout, track) if Jrout is None else track
Ejemplo n.º 4
0
def _print_banner(model_desc, loss_desc, optimizer_desc):
    """Print the header announcing the model, loss and optimizer being trained."""
    stars = '*************************************************************************'
    for _ in range(5):
        print(stars)
    print(model_desc)
    print(loss_desc)
    print(optimizer_desc)
    print(stars)
    print('Training')
    print(stars)


def _is_nan(value):
    """Return True when *value* is NaN (the only value that differs from itself)."""
    return bool(value != value)


def _train_one_epoch(model, optimizer, criterion, features, labels, batch_size):
    """Run one pass over the training set.

    Gradients are reset once per mini-batch, accumulated over the samples of
    that batch and then applied with a single optimizer step.

    Returns:
        (loss_sum, num_correct, last_loss_grad): the summed per-sample loss,
        the number of samples whose arg-max prediction matches the arg-max
        of the label, and the loss gradient of the last processed sample
        (None when there are no samples).
    """
    loss_sum = 0.
    num_correct = 0
    last_loss_grad = None
    num_samples = len(features)

    for start in range(0, num_samples, batch_size):
        # The last batch may be shorter than batch_size.
        current_size = min(batch_size, num_samples - start)
        batch_features = features.narrow(0, start, current_size)
        batch_labels = labels.narrow(0, start, current_size)

        # Clear the parameter gradients once per batch so that the
        # per-sample gradients below really accumulate.
        optimizer.zero_grad()

        for feature, label in zip(batch_features, batch_labels):
            # forward pass to compute the loss
            pred = model.forward(feature)
            loss_sum += criterion.forward(pred, label)

            _, pred_cat = torch.max(pred, 0)
            _, label_cat = torch.max(label, 0)
            if pred_cat == label_cat:
                num_correct += 1

            # backward pass: accumulate the parameter gradients
            last_loss_grad = criterion.backward(pred, label)
            model.backward(last_loss_grad)

        # update parameters with the gradients accumulated over the batch
        optimizer.step()

    return loss_sum, num_correct, last_loss_grad


def _evaluate(model, criterion, features, labels):
    """Forward-only pass over a data set; return (loss_sum, num_correct)."""
    loss_sum = 0.
    num_correct = 0
    for feature, label in zip(features, labels):
        pred = model.forward(feature)
        loss_sum += criterion.forward(pred, label)

        _, pred_cat = torch.max(pred, 0)
        _, label_cat = torch.max(label, 0)
        if pred_cat == label_cat:
            num_correct += 1
    return loss_sum, num_correct


def _print_epoch_summary(epoch, num_epochs, train_loss, train_accuracy,
                         test_loss, test_accuracy):
    """Print the loss/accuracy line for one finished epoch (epoch is 1-based)."""
    print("Epoch: {}/{}..".format(epoch, num_epochs),
          "Training Loss: {:.4f}..".format(train_loss),
          "Training Accuracy: {:.4f}..".format(train_accuracy),
          "Validation/Test Loss: {:.4f}..".format(test_loss),
          "Validation/Test Accuracy: {:.4f}..".format(test_accuracy))


def _plot_classification(model, features, targets, title):
    """Scatter-plot the samples, coloured by class; misclassified ones get colour 2."""
    colors = []
    for feature, target in zip(features, targets):
        pred = model.forward(feature)
        _, pred_cat = torch.max(pred, 0)
        if target.int() == pred_cat.int():
            colors.append(int(target))
        else:
            colors.append(2)

    _, axes = plt.subplots(1, 1, figsize=(6, 6))
    axes.scatter(features[:, 0], features[:, 1], c=colors)
    axes.set_title(title)
    plt.show()


def _build_model_2():
    """Build a freshly initialised Model 2 (also used when training is restarted)."""
    return Sequential([Linear(2, 25), ReLU(), Linear(25, 25), Dropout(p=0.5), SeLU(),
                       Linear(25, 25), Dropout(p=0.5), ReLU(), Linear(25, 2),
                       Sigmoid()])


def main():
    """Train and visualise two small classifiers on generated 2-D data.

    Model 1 is trained with MSE + SGD for 50 epochs, Model 2 with
    cross-entropy + Adam for 25 epochs. After every epoch the model is
    evaluated on the test set, and after training a scatter plot of the test
    predictions is shown. Model 2 is re-initialised and its training
    restarted from epoch 0 whenever the last loss gradient of an epoch is NaN.
    """
    # generate data and translate labels
    train_features, train_targets = generate_all_datapoints_and_labels()
    test_features, test_targets = generate_all_datapoints_and_labels()
    train_labels, test_labels = convert_labels(train_targets), convert_labels(test_targets)

    num_train = len(train_features)
    num_test = len(test_features)
    batch_size = 1

    # ------------------------------ Model 1 ------------------------------
    # NOTE(review): the printed description omits the Dropout layer that the
    # design below contains; confirm which of the two is intended.
    _print_banner('Model: Linear + ReLU + Linear +ReLU + Linear + ReLU + Linear + Tanh',
                  'Loss: MSE',
                  'Optimizer: SGD')

    my_model_1 = Sequential([Linear(2, 25), ReLU(), Linear(25, 25), Dropout(p=0.5), ReLU(),
                             Linear(25, 25), ReLU(), Linear(25, 2), Tanh()])
    optimizer_1 = SGD(my_model_1, lr=1e-3)
    criterion_1 = LossMSE()

    num_epochs_1 = 50
    for epoch in range(num_epochs_1):
        train_loss_sum, num_train_correct, _ = _train_one_epoch(
            my_model_1, optimizer_1, criterion_1,
            train_features, train_labels, batch_size)
        test_loss_sum, num_test_correct = _evaluate(
            my_model_1, criterion_1, test_features, test_labels)

        _print_epoch_summary(epoch + 1, num_epochs_1,
                             train_loss_sum / num_train,
                             num_train_correct / num_train,
                             test_loss_sum / num_test,
                             num_test_correct / num_test)

    _plot_classification(my_model_1, test_features, test_targets,
                         'Classification Performance of Model 1')

    # ------------------------------ Model 2 ------------------------------
    _print_banner('Model: Linear + ReLU + Linear + Dropout+ SeLU + Linear + Dropout + ReLU + Linear + Sigmoid',
                  'Loss: Cross Entropy',
                  'Optimizer: Adam')

    my_model_2 = _build_model_2()
    optimizer_2 = Adam(my_model_2, lr=1e-3)
    criterion_2 = CrossEntropy()

    num_epochs_2 = 25
    epoch = 0
    while epoch < num_epochs_2:
        train_loss_sum, num_train_correct, last_loss_grad = _train_one_epoch(
            my_model_2, optimizer_2, criterion_2,
            train_features, train_labels, batch_size)
        test_loss_sum, num_test_correct = _evaluate(
            my_model_2, criterion_2, test_features, test_labels)

        # In the (rare) case of a gradient explosion the loss gradient turns
        # into NaN: rebuild the *same* architecture from scratch and restart.
        if last_loss_grad is not None and _is_nan(last_loss_grad[0]):
            epoch = 0
            my_model_2 = _build_model_2()
            optimizer_2 = Adam(my_model_2, lr=1e-3)
            criterion_2 = CrossEntropy()
            for _ in range(5):
                print('--------------------------------------------------------------------------------')
            print('Restart training because of gradient explosion')
            continue

        _print_epoch_summary(epoch + 1, num_epochs_2,
                             train_loss_sum / num_train,
                             num_train_correct / num_train,
                             test_loss_sum / num_test,
                             num_test_correct / num_test)
        epoch += 1

    _plot_classification(my_model_2, test_features, test_targets,
                         'Classification Performance of Model 2')
Ejemplo n.º 5
0
def runTraining():
    """Train IVD_Net on the four-modality data set and keep the best model.

    For every epoch the network is optimised with the cross-entropy loss on
    the training loader, evaluated with ``inference`` on the validation
    loader, and the running statistics (mean loss, validation Dice, mean
    training Dice) are saved as ``.npy`` files under
    ``../Results/Statistics/<modelName>``. Whenever the validation Dice
    improves and exceeds 0.75 the whole network is saved to
    ``<model_dir>/Best_<modelName>.pkl`` and validation images are written
    with ``saveImages``.
    """
    print('-' * 40)
    print('~~~~~~~~  Starting the training... ~~~~~~')
    print('-' * 40)

    batch_size = 4
    batch_size_val = 1
    batch_size_val_save = 1

    lr = 0.0001
    epoch = 200
    num_classes = 2
    initial_kernels = 32

    modelName = 'IVD_Net'

    print('.' * 40)
    print(" ....Model name: {} ........".format(modelName))

    print(' - Num. classes: {}'.format(num_classes))
    print(' - Num. initial kernels: {}'.format(initial_kernels))
    print(' - Batch size: {}'.format(batch_size))
    print(' - Learning rate: {}'.format(lr))
    print(' - Num. epochs: {}'.format(epoch))

    print('.' * 40)
    root_dir = '../Data/Training_PngITK'
    model_dir = 'IVD_Net'
    # Folder receiving the per-epoch statistics (same for every epoch).
    stats_dir = '../Results/Statistics/' + modelName

    transform = transforms.Compose([transforms.ToTensor()])

    mask_transform = transforms.Compose([transforms.ToTensor()])

    train_set = medicalDataLoader.MedicalImageDataset(
        'train',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        augment=False,
        equalize=False)

    train_loader = DataLoader(train_set,
                              batch_size=batch_size,
                              num_workers=5,
                              shuffle=True)

    val_set = medicalDataLoader.MedicalImageDataset(
        'val',
        root_dir,
        transform=transform,
        mask_transform=mask_transform,
        equalize=False)

    val_loader = DataLoader(val_set,
                            batch_size=batch_size_val,
                            num_workers=5,
                            shuffle=False)

    val_loader_save_images = DataLoader(val_set,
                                        batch_size=batch_size_val_save,
                                        num_workers=5,
                                        shuffle=False)
    # Initialize
    print("~~~~~~~~~~~ Creating the model ~~~~~~~~~~")

    net = IVD_Net_asym(1, num_classes, initial_kernels)

    # Initialize the weights
    net.apply(weights_init)

    softMax = nn.Softmax()
    CE_loss = nn.CrossEntropyLoss()
    Dice_ = computeDiceOneHotBinary()

    if torch.cuda.is_available():
        net.cuda()
        softMax.cuda()
        CE_loss.cuda()
        Dice_.cuda()

    # To load a pre-trained model
    '''try:
        net = torch.load('modelName')
        print("--------model restored--------")
    except:
        print("--------model not restored--------")
        pass'''

    optimizer = Adam(net.parameters(), lr=lr, betas=(0.9, 0.99), amsgrad=False)

    # NOTE(review): this scheduler is created but never stepped, so it has no
    # effect on the learning rate. Stepping it with the validation Dice looks
    # like the intent (mode='max') -- confirm before enabling.
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer,
                                                           mode='max',
                                                           patience=4,
                                                           verbose=True,
                                                           factor=10**-0.5)

    BestDice, BestEpoch = 0, 0

    d1Train = []
    d1Val = []
    Losses = []

    print("~~~~~~~~~~~ Starting the training ~~~~~~~~~~")
    for i in range(epoch):
        net.train()
        lossTrain = []
        d1TrainTemp = []

        totalImages = len(train_loader)

        for j, data in enumerate(train_loader):

            image_f, image_i, image_o, image_w, labels, _ = data

            # Be sure your data here is between [0,1]
            image_f = image_f.type(torch.FloatTensor)
            image_i = image_i.type(torch.FloatTensor)
            image_o = image_o.type(torch.FloatTensor)
            image_w = image_w.type(torch.FloatTensor)

            # Binarise the mask: every positive label becomes foreground (1).
            labels = labels.type(torch.FloatTensor)
            labels[labels > 0.0] = 1.0

            # The optimizer holds every parameter of `net`, so this also
            # resets the gradients of the network.
            optimizer.zero_grad()
            MRI = to_var(torch.cat((image_f, image_i, image_o, image_w),
                                   dim=1))

            Segmentation = to_var(labels)

            segmentation_prediction = net(MRI)
            predClass_y = softMax(segmentation_prediction)

            Segmentation_planes = getOneHotSegmentation(Segmentation)
            segmentation_prediction_ones = predToSegmentation(predClass_y)

            # It needs the logits, not the softmax
            Segmentation_class = getTargetSegmentation(Segmentation)

            CE_loss_ = CE_loss(segmentation_prediction, Segmentation_class)

            # Compute the Dice (so far in a 2D-basis)
            DicesB, DicesF = Dice_(segmentation_prediction_ones,
                                   Segmentation_planes)
            DiceB = DicesToDice(DicesB)
            DiceF = DicesToDice(DicesF)

            loss = CE_loss_

            loss.backward()
            optimizer.step()

            # `.item()` extracts the Python number from a one-element tensor;
            # indexing `.data[0]` fails on 0-dim tensors in recent PyTorch.
            # Assumes DiceF holds a single value -- TODO confirm.
            lossTrain.append(loss.item())
            diceF_value = DiceF.item()
            d1TrainTemp.append(diceF_value)

            printProgressBar(j + 1,
                             totalImages,
                             prefix="[Training] Epoch: {} ".format(i),
                             length=15,
                             suffix=" Mean Dice: {:.4f},".format(
                                 diceF_value))

        printProgressBar(totalImages,
                         totalImages,
                         done="[Training] Epoch: {}, LossG: {:.4f}".format(
                             i, np.mean(lossTrain)))
        # Save statistics
        Losses.append(np.mean(lossTrain))
        d1 = inference(net, val_loader, batch_size, i)
        d1Val.append(d1)
        # Mean foreground Dice over the training batches of this epoch.
        d1Train.append(np.mean(d1TrainTemp))

        if not os.path.exists(stats_dir):
            os.makedirs(stats_dir)

        np.save(os.path.join(stats_dir, 'Losses.npy'), Losses)
        np.save(os.path.join(stats_dir, 'd1Val.npy'), d1Val)
        np.save(os.path.join(stats_dir, 'd1Train.npy'), d1Train)

        currentDice = d1[0].numpy()

        print("[val] DSC: {:.4f} ".format(d1[0]))

        if currentDice > BestDice:
            BestDice = currentDice

            BestEpoch = i
            if currentDice > 0.75:
                print(
                    "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Saving best model..... ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~"
                )
                if not os.path.exists(model_dir):
                    os.makedirs(model_dir)
                torch.save(
                    net, os.path.join(model_dir, "Best_" + modelName + ".pkl"))
                saveImages(net, val_loader_save_images, batch_size_val_save, i,
                           modelName)

        # NOTE(review): this writes the initial learning rate back into the
        # optimizer on every epoch that is not a multiple of BestEpoch + 10.
        # Since nothing else changes the learning rate it is currently a
        # no-op; kept as-is until the intended decay policy is confirmed.
        if i % (BestEpoch + 10):
            for param_group in optimizer.param_groups:
                param_group['lr'] = lr
Ejemplo n.º 6
0
    def train(self,
              targs,
              inps,
              mode='online',
              epochs=500,
              par=None,
              outs=None,
              Jrout=None,
              track=False):
        '''
            Train the recurrent connections of the model (and, optionally, a
            linear readout) on a set of target sequences.

            The recurrent matrix self.J is re-initialised at random on every
            call: the first N // 2 columns are non-negative, the remaining
            ones non-positive and the diagonal is zero. The sign of every
            entry is kept fixed during training. Two training modes can be
            selected: 'offline' accumulates the gradient over a whole
            sequence and applies it once per sequence, 'online' applies an
            update at every time step.

            Args:
                targs: sequence of target arrays, one per training sample,
                       each of shape (N, T), where N is the number of neurons
                       in the network and T is the time length of the
                       sequence.

                inps : sequence of arrays of shape (N, T), one per training
                       sample, collecting the input signal to the neurons.

            Keywords:
                mode  : (default: 'online') The training mode to use, either
                        'offline' or 'online'.

                epochs: (default: 500) The number of epochs of training.

                par   : (default: None) Optional different dictionary
                        collecting training parameters:
                        {dv, alpha, alpha_rout, beta_ro, offT}.
                        If not provided defaults to the parameter dictionary
                        of the model.

                outs  : (default: None) Sequence of output target
                        trajectories, one per training sample, each of shape
                        (K, T), where K is the dimension of the output.
                        Required when Jrout is None (the readout is trained
                        on it) or when track is True (it is needed to compute
                        the output error).

                Jrout : (default: None) Pre-trained readout connection
                        matrix. If not provided, a new matrix is built and
                        trained together with the recurrent connections.

                track : (default: False) Whether to track the error over
                        training epochs.

            Returns:
                (J_rout, track) if Jrout is None, otherwise track alone.
                track is None when tracking is disabled, else an array of
                shape (epochs, 2): column 0 holds the mean squared output
                error from time offT onwards, column 1 the absolute
                difference between target and generated activity divided by
                N * self.T. Both are evaluated on the last sample of the
                epoch.

            Raises:
                ValueError: if outs is missing although it is required.

            Note:
                inps, targs and outs are passed to ut.shuffle at every epoch
                (presumably shuffled in place and in unison -- confirm
                against ut.shuffle).
        '''
        # The output targets are needed both to fit a new readout and to
        # evaluate the output error: fail early with a clear message.
        if outs is None and (Jrout is None or track):
            raise ValueError("'outs' must be provided when Jrout is None "
                             "or track is True")

        par = self.par if par is None else par

        dv = par['dv']
        alpha = par['alpha']
        alpha_rout = par['alpha_rout']
        beta_ro = par['beta_ro']
        offT = par['offT']

        itau_m = self.itau_m
        itau_s = self.itau_s
        sigm = self._sigm
        N = self.N

        # Random initial weights with a fixed sign per column.
        self.J = np.random.normal(0., .1 / np.sqrt(N), size=(N, N))
        self.J[:, :N // 2] = np.abs(self.J[:, :N // 2])
        self.J[:, N // 2:] = -np.abs(self.J[:, N // 2:])
        np.fill_diagonal(self.J, 0.)

        # Masks of the positive / negative entries; they stay fixed for the
        # whole training so that no connection can change its sign.
        ndxe = self.J > 0
        ndxi = self.J < 0

        track = np.zeros((epochs, 2)) if track else None
        Tmax = np.shape(targs[0])[-1]

        opt_recE = SimpleGradient(alpha=alpha)
        opt_recI = SimpleGradient(alpha=alpha)

        if Jrout is None:
            S_rout = [ut.sfilter(targ, itau=beta_ro) for targ in targs]
            J_rout = np.random.normal(0., 0.01, size=(outs[0].shape[0], N))
            opt = Adam(alpha=alpha_rout,
                       drop=0.9,
                       drop_time=epochs // 5 *
                       Tmax if mode == 'online' else 1e10)

        else:
            J_rout = Jrout

        for epoch in trange(epochs,
                            leave=False,
                            desc='Training {}'.format(mode)):
            # Keep every per-sample list aligned: the tracking step below
            # pairs the last input/target with the last output.
            if Jrout is None:
                ut.shuffle((inps, outs, targs, S_rout))
            elif outs is not None:
                ut.shuffle((inps, outs, targs))
            else:
                ut.shuffle((inps, targs))

            if Jrout is None:
                for out, s_rout in zip(outs, S_rout):
                    # Here we train the readout
                    dJrout = (out - J_rout @ s_rout) @ s_rout.T
                    J_rout = opt.step(J_rout, dJrout)

            # Here we train the network
            for inp, targ in zip(inps, targs):
                self.S[:, 0] = targ[:, 0].copy()
                self.S_hat[:, 0] = self.S[:, 0] * itau_s

                dHe = np.zeros((N, N))
                dHi = np.zeros((N, N))

                dJe = dJi = 0.

                for t in range(Tmax - 1):
                    # Magnitudes of the positive / negative weights.
                    Jp = ndxe * self.J
                    Jm = -(ndxi * self.J)

                    # Filtered target spikes (column 0 was set above).
                    if t > 0:
                        self.S_hat[:, t] = self.S_hat[:, t - 1] * itau_s +\
                                           targ[:, t] * (1. - itau_s)
                    s_hat = self.S_hat[:, t]
                    h_t = self.H[:, t]

                    # Recurrent drive: each weight is scaled by the distance
                    # of the potential from the reversal potential Ve / Vi.
                    exc = ((self.Ve - h_t).reshape(N, 1) / abs(self.Ve) * Jp) @ s_hat
                    inh = ((self.Vi - h_t).reshape(N, 1) / abs(self.Vi) * Jm) @ s_hat

                    self.H[:, t + 1] = h_t * (1. - itau_m) +\
                                       itau_m * (exc + inh + inp[:, t] + self.h[:, t]) +\
                                       self.Jreset @ targ[:, t]

                    h_next = self.H[:, t + 1]
                    total = ((Jp + Jm) @ s_hat).reshape(N, 1)

                    dHe[...] = dHe * (1. - itau_m * (1 + total / abs(self.Ve))) +\
                               itau_m * ((self.Ve - h_next).reshape(N, 1) / abs(self.Ve)
                                         * ndxe * s_hat.reshape(1, N))

                    dHi[...] = dHi * (1. - itau_m * (1 + total / abs(self.Vi))) +\
                               itau_m * ((self.Vi - h_next).reshape(N, 1) / abs(self.Vi)
                                         * ndxi * s_hat.reshape(1, N))

                    if mode == 'online':
                        # Error between target and predicted activity.
                        err = (targ[:, t + 1] - sigm(h_next, dv=dv)).reshape(N, 1)
                        dJe = err * dHe
                        dJi = err * dHi

                        Jp = opt_recE.step(Jp, dJe * ndxe)
                        Jm = opt_recI.step(Jm, dJi * ndxi)

                        # Magnitudes must stay non-negative.
                        Jp = np.maximum(Jp, 0.)
                        Jm = np.maximum(Jm, 0.)

                        self.J = Jp - Jm

                        np.fill_diagonal(self.J, 0.)

                    elif mode == 'offline':
                        err = (targ[:, t + 1] - sigm(h_next, dv=dv)).reshape(N, 1)
                        dJe += err * dHe
                        dJi += err * dHi

                if mode == 'offline':
                    # self.J did not change during the sequence, so these are
                    # the same magnitudes that were used at every time step.
                    Jp = ndxe * self.J
                    Jm = -(ndxi * self.J)

                    Jp = opt_recE.step(Jp, dJe * ndxe)
                    Jm = opt_recI.step(Jm, dJi * ndxi)

                    Jp = np.maximum(Jp, 0.)
                    Jm = np.maximum(Jm, 0.)

                    self.J = Jp - Jm

                    np.fill_diagonal(self.J, 0.)

            # Here we track the error on the last sample of this epoch
            if track is not None:
                inp, targ, out = inps[-1], targs[-1], outs[-1]

                S_gen = self.compute(inp, init=np.zeros(N))
                track[epoch, 0] = np.mean(
                    (out -
                     J_rout @ ut.sfilter(S_gen, itau=beta_ro))[:, offT:]**2.)
                track[epoch,
                      1] = np.sum(np.abs(targ - S_gen)) / (N * self.T)

        return (J_rout, track) if Jrout is None else track