Example #1
def trainValidateSegmentation(args):
    '''
    Main function for training and validation
    :param args: global arguments
    :return: None
    '''
    # check if processed data file exists or not
    if not os.path.isfile(args.cached_data_file):
        dataLoad = ld.LoadData(args.data_dir, args.classes, args.cached_data_file)
        data = dataLoad.processData()
        if data is None:
            print('Error while pickling data. Please check.')
            exit(-1)
    else:
        data = pickle.load(open(args.cached_data_file, "rb"))

    q = args.q
    p = args.p
    # load the model
    if not args.decoder:
        model = net.ESPNet_Encoder(args.classes, p=p, q=q)
        args.savedir = args.savedir + '_enc_' + str(p) + '_' + str(q) + '/'
    else:
        model = net.ESPNet(args.classes, p=p, q=q, encoderFile=args.pretrained)
        args.savedir = args.savedir + '_dec_' + str(p) + '_' + str(q) + '/'

    if args.onGPU:
        model = model.cuda()

    # create the directory if not exist
    if not os.path.exists(args.savedir):
        os.mkdir(args.savedir)

    if args.visualizeNet:
        x = Variable(torch.randn(1, 3, args.inWidth, args.inHeight))

        if args.onGPU:
            x = x.cuda()

        y = model(x)
        g = viz.make_dot(y)
        g.render(args.savedir + 'model', view=False)  # graphviz appends the format extension itself

    total_paramters = netParams(model)
    print('Total network parameters: ' + str(total_paramters))

    # define optimization criteria
    weight = torch.from_numpy(data['classWeights'])  # convert the numpy array to a torch tensor
    if args.onGPU:
        weight = weight.cuda()

    criteria = CrossEntropyLoss2d(weight)  # class-weighted 2D cross-entropy loss

    if args.onGPU:
        criteria = criteria.cuda()

    print('Data statistics')
    print(data['mean'], data['std'])
    print(data['classWeights'])

    #compose the data with transforms
    trainDataset_main = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(1024, 512),
        myTransforms.RandomCropResize(32),
        myTransforms.RandomFlip(),
        #myTransforms.RandomCrop(64).
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainDataset_scale1 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(1536, 768), # 1536, 768
        myTransforms.RandomCropResize(100),
        myTransforms.RandomFlip(),
        #myTransforms.RandomCrop(64),
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainDataset_scale2 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(1280, 720), # 1536, 768
        myTransforms.RandomCropResize(100),
        myTransforms.RandomFlip(),
        #myTransforms.RandomCrop(64),
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainDataset_scale3 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(768, 384),
        myTransforms.RandomCropResize(32),
        myTransforms.RandomFlip(),
        #myTransforms.RandomCrop(64),
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainDataset_scale4 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(512, 256),
        #myTransforms.RandomCropResize(20),
        myTransforms.RandomFlip(),
        #myTransforms.RandomCrop(64).
        myTransforms.ToTensor(args.scaleIn),
    ])


    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(1024, 512),
        myTransforms.ToTensor(args.scaleIn),
    ])

    # since we are training from scratch, we create data loaders at different scales
    # so that we can generate more augmented data and prevent the network from overfitting

    trainLoader = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'], data['trainAnnot'], transform=trainDataset_main),
        batch_size=args.batch_size + 2, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainLoader_scale1 = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'], data['trainAnnot'], transform=trainDataset_scale1),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainLoader_scale2 = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'], data['trainAnnot'], transform=trainDataset_scale2),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainLoader_scale3 = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'], data['trainAnnot'], transform=trainDataset_scale3),
        batch_size=args.batch_size + 4, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainLoader_scale4 = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'], data['trainAnnot'], transform=trainDataset_scale4),
        batch_size=args.batch_size + 4, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    valLoader = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['valIm'], data['valAnnot'], transform=valDataset),
        batch_size=args.batch_size + 4, shuffle=False, num_workers=args.num_workers, pin_memory=True)

    if args.onGPU:
        cudnn.benchmark = True

    start_epoch = 0

    if args.resume:
        if os.path.isfile(args.resumeLoc):
            print("=> loading checkpoint '{}'".format(args.resume))
            checkpoint = torch.load(args.resumeLoc)
            start_epoch = checkpoint['epoch']
            #args.lr = checkpoint['lr']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))
    

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t%s\t%s\t%s\t%s\t" % ('Epoch', 'Loss(Tr)', 'Loss(val)', 'mIOU (tr)', 'mIOU (val'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(), args.lr, (0.9, 0.999), eps=1e-08, weight_decay=5e-4)
    # halve the learning rate every step_loss epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_loss, gamma=0.5)


    for epoch in range(start_epoch, args.max_epochs):

        scheduler.step(epoch)
        lr = 0
        for param_group in optimizer.param_groups:
            lr = param_group['lr']
        print("Learning rate: " +  str(lr))

        # train for one epoch
        # one epoch here means a full pass over the training data at every scale
        train(args, trainLoader_scale1, model, criteria, optimizer, epoch)
        train(args, trainLoader_scale2, model, criteria, optimizer, epoch)
        train(args, trainLoader_scale4, model, criteria, optimizer, epoch)
        train(args, trainLoader_scale3, model, criteria, optimizer, epoch)
        lossTr, overall_acc_tr, per_class_acc_tr, per_class_iu_tr, mIOU_tr = train(args, trainLoader, model, criteria, optimizer, epoch)

        # evaluate on validation set
        lossVal, overall_acc_val, per_class_acc_val, per_class_iu_val, mIOU_val = val(args, valLoader, model, criteria)
        
            
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': str(model),
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossTr': lossTr,
            'lossVal': lossVal,
            'iouTr': mIOU_tr,
            'iouVal': mIOU_val,
            'lr': lr
        }, args.savedir + 'checkpoint.pth.tar')

        # also save the model weights for this epoch
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        torch.save(model.state_dict(), model_file_name)

        

        with open(args.savedir + 'acc_' + str(epoch) + '.txt', 'w') as log:
            log.write("\nEpoch: %d\t Overall Acc (Tr): %.4f\t Overall Acc (Val): %.4f\t mIOU (Tr): %.4f\t mIOU (Val): %.4f" % (epoch, overall_acc_tr, overall_acc_val, mIOU_tr, mIOU_val))
            log.write('\n')
            log.write('Per Class Training Acc: ' + str(per_class_acc_tr))
            log.write('\n')
            log.write('Per Class Validation Acc: ' + str(per_class_acc_val))
            log.write('\n')
            log.write('Per Class Training mIOU: ' + str(per_class_iu_tr))
            log.write('\n')
            log.write('Per Class Validation mIOU: ' + str(per_class_iu_val))

        logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" % (epoch, lossTr, lossVal, mIOU_tr, mIOU_val, lr))
        logger.flush()
        print("Epoch : " + str(epoch) + ' Details')
        print("\nEpoch No.: %d\tTrain Loss = %.4f\tVal Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f" % (epoch, lossTr, lossVal, mIOU_tr, mIOU_val))
    logger.close()
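
Example #1 calls two helpers that are defined elsewhere in the ESPNet repository: CrossEntropyLoss2d and netParams. Their exact implementations are not shown here; a minimal sketch, assuming the loss is a class-weighted NLLLoss applied to log-softmax scores and that netParams simply counts parameter elements, could look like this:

import torch
import torch.nn as nn
import torch.nn.functional as F

class CrossEntropyLoss2d(nn.Module):
    '''Sketch (assumption): class-weighted 2D cross-entropy, i.e. NLLLoss over
    log-softmax scores of shape N x C x H x W with targets of shape N x H x W.'''
    def __init__(self, weight=None):
        super(CrossEntropyLoss2d, self).__init__()
        self.loss = nn.NLLLoss(weight)

    def forward(self, outputs, targets):
        return self.loss(F.log_softmax(outputs, dim=1), targets)


def netParams(model):
    '''Sketch (assumption): total number of parameter elements in the network.'''
    return sum(p.numel() for p in model.parameters())
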
Example #2
def trainValidateSegmentation(args):
    '''
    Main function for training and validation
    :param args: global arguments
    :return: None
    '''
    print('Data file: ' + str(args.cached_data_file))
    print(args)

    # check if processed data file exists or not
    if not os.path.isfile(args.cached_data_file):
        dataLoader = ld.LoadData(args.data_dir, args.data_dir_val,
                                 args.classes, args.cached_data_file)
        data = dataLoader.processData()
        if data is None:
            print('Error while pickling data. Please check.')
            exit(-1)
    else:
        data = pickle.load(open(args.cached_data_file, "rb"))
    print('=> Loading the model')
    model = net.ESPNet(classes=args.classes, channels=args.channels)
    args.savedir = args.savedir + os.sep

    if args.onGPU:
        model = model.cuda()

    # create the directory if not exist
    if not os.path.exists(args.savedir):
        os.mkdir(args.savedir)

    if args.onGPU:
        model = model.cuda()

    if args.visualizeNet:
        import VisualizeGraph as viz
        x = Variable(
            torch.randn(1, args.channels, args.inDepth, args.inWidth,
                        args.inHeight))

        if args.onGPU:
            x = x.cuda()

        y = model(x, (128, 128, 128))
        g = viz.make_dot(y)
        g.render(args.savedir + os.sep + 'model', view=False)

    total_paramters = 0
    for parameter in model.parameters():
        total_paramters += parameter.numel()

    print('Parameters: ' + str(total_paramters))

    # define optimization criteria
    weight = torch.from_numpy(
        data['classWeights'])  # convert the numpy array to a torch tensor
    print('Class Imbalance Weights')
    print(weight)
    criteria = torch.nn.CrossEntropyLoss(weight)
    if args.onGPU:
        criteria = criteria.cuda()

    # We train at three different resolutions (144x144x144, 96x96x96 and 128x128x128)
    # and validate at one resolution (128x128x128)
    trainDatasetA = myTransforms.Compose([
        myTransforms.MinMaxNormalize(),
        myTransforms.ScaleToFixed(dimA=144, dimB=144, dimC=144),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainDatasetB = myTransforms.Compose([
        myTransforms.MinMaxNormalize(),
        myTransforms.ScaleToFixed(dimA=96, dimB=96, dimC=96),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainDatasetC = myTransforms.Compose([
        myTransforms.MinMaxNormalize(),
        myTransforms.ScaleToFixed(dimA=args.inWidth,
                                  dimB=args.inHeight,
                                  dimC=args.inDepth),
        myTransforms.RandomFlip(),
        myTransforms.ToTensor(args.scaleIn),
    ])

    valDataset = myTransforms.Compose([
        myTransforms.MinMaxNormalize(),
        myTransforms.ScaleToFixed(dimA=args.inWidth,
                                  dimB=args.inHeight,
                                  dimC=args.inDepth),
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainLoaderA = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'],
                               data['trainAnnot'],
                               transform=trainDatasetA),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=False)  #disabling pin memory because swap usage is high
    trainLoaderB = torch.utils.data.DataLoader(myDataLoader.MyDataset(
        data['trainIm'], data['trainAnnot'], transform=trainDatasetB),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=False)
    trainLoaderC = torch.utils.data.DataLoader(myDataLoader.MyDataset(
        data['trainIm'], data['trainAnnot'], transform=trainDatasetC),
                                               batch_size=args.batch_size,
                                               shuffle=True,
                                               num_workers=args.num_workers,
                                               pin_memory=False)

    valLoader = torch.utils.data.DataLoader(myDataLoader.MyDataset(
        data['valIm'], data['valAnnot'], transform=valDataset),
                                            batch_size=1,
                                            shuffle=False,
                                            num_workers=args.num_workers,
                                            pin_memory=False)

    # define the optimizer
    optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad,
                                        model.parameters()),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=2e-4)

    if args.onGPU:
        cudnn.benchmark = True

    start_epoch = 0
    stored_loss = 100000000.0
    if args.resume:
        if os.path.isfile(args.resumeLoc):
            print("=> loading checkpoint '{}'".format(args.resumeLoc))
            checkpoint = torch.load(args.resumeLoc)
            start_epoch = checkpoint['epoch']
            stored_loss = checkpoint['stored_loss']
            model.load_state_dict(checkpoint['state_dict'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resume))

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write(
            "\n%s\t%s\t%s\t%s\t%s\t" %
            ('Epoch', 'Loss(Tr)', 'Loss(val)', 'mIOU (tr)', 'mIOU (val)'))
        logger.flush()
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Arguments: %s" % (str(args)))
        logger.write("\n Parameters: %s" % (str(total_paramters)))
        logger.write(
            "\n%s\t%s\t%s\t%s\t%s\t" %
            ('Epoch', 'Loss(Tr)', 'Loss(val)', 'mIOU (tr)', 'mIOU (val)'))
        logger.flush()

    # halve the learning rate every step_loss epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=args.step_loss,
                                                gamma=0.5)
    best_val_acc = 0

    # the three training loaders at different resolutions are mapped to indices 0, 1 and 2
    loader_idxs = [0, 1, 2]
    for epoch in range(start_epoch, args.max_epochs):
        # step the learning rate
        scheduler.step(epoch)
        lr = 0
        for param_group in optimizer.param_groups:
            lr = param_group['lr']
        print('Running epoch {} with learning rate {:.5f}'.format(epoch, lr))

        if epoch > 0:
            # shuffle the loaders
            np.random.shuffle(loader_idxs)

        for l_id in loader_idxs:
            if l_id == 0:
                train(args, trainLoaderA, model, criteria, optimizer, epoch)
            elif l_id == 1:
                train(args, trainLoaderB, model, criteria, optimizer, epoch)
            else:
                lossTr, overall_acc_tr, per_class_acc_tr, per_class_iu_tr, mIOU_tr = \
                    train(args, trainLoaderC, model, criteria, optimizer, epoch)

        # evaluate on validation set
        lossVal, overall_acc_val, per_class_acc_val, per_class_iu_val, mIOU_val = val(
            args, valLoader, model, criteria)

        print('saving checkpoint')
        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lossTr': lossTr,
                'lossVal': lossVal,
                'iouTr': mIOU_tr,
                'iouVal': mIOU_val,
                'stored_loss': stored_loss,
            }, args.savedir + '/checkpoint.pth.tar')

        # save the best model so far (highest validation mIOU)
        if mIOU_val >= best_val_acc:
            best_val_acc = mIOU_val
            torch.save(model.state_dict(), args.savedir + '/best_model.pth')

        with open(args.savedir + 'acc_' + str(epoch) + '.txt', 'w') as log:
            log.write(
                "\nEpoch: %d\t Overall Acc (Tr): %.4f\t Overall Acc (Val): %.4f\t mIOU (Tr): %.4f\t mIOU (Val): %.4f"
                % (epoch, overall_acc_tr, overall_acc_val, mIOU_tr, mIOU_val))
            log.write('\n')
            log.write('Per Class Training Acc: ' + str(per_class_acc_tr))
            log.write('\n')
            log.write('Per Class Validation Acc: ' + str(per_class_acc_val))
            log.write('\n')
            log.write('Per Class Training mIOU: ' + str(per_class_iu_tr))
            log.write('\n')
            log.write('Per Class Validation mIOU: ' + str(per_class_iu_val))

        logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.6f" %
                     (epoch, lossTr, lossVal, mIOU_tr, mIOU_val, lr))
        logger.flush()
        print("Epoch : " + str(epoch) + ' Details')
        print(
            "\nEpoch No.: %d\tTrain Loss = %.4f\tVal Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f"
            % (epoch, lossTr, lossVal, mIOU_tr, mIOU_val))

    logger.close()
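
Both examples also depend on train, val and save_checkpoint routines that live in other files of the repository. As a rough sketch under stated assumptions (not the authors' exact code): save_checkpoint can be as simple as serializing the state dictionary with torch.save, and train can run one epoch over a loader while accumulating a pixel-level confusion matrix from which the logged accuracy and mIOU values are derived. The loop below assumes MyDataset yields (image, label) pairs.

import numpy as np
import torch


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    '''Sketch (assumption): persist epoch, model/optimizer state and metrics
    so that --resume can restore training later.'''
    torch.save(state, filename)


def train(args, train_loader, model, criteria, optimizer, epoch):
    '''Sketch (assumption): one training epoch, returning
    (loss, overall_acc, per_class_acc, per_class_iu, mIOU).'''
    model.train()
    losses = []
    conf = np.zeros((args.classes, args.classes), dtype=np.int64)

    for image, target in train_loader:  # assumes (image, label) pairs
        if args.onGPU:
            image, target = image.cuda(), target.cuda()
        output = model(image)
        loss = criteria(output, target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        losses.append(loss.item())

        # accumulate a pixel-level confusion matrix for accuracy/IoU statistics
        pred = output.argmax(dim=1).cpu().numpy().ravel()
        gt = target.cpu().numpy().ravel()
        conf += np.bincount(gt * args.classes + pred,
                            minlength=args.classes ** 2).reshape(args.classes, args.classes)

    overall_acc = np.diag(conf).sum() / conf.sum()
    per_class_acc = np.diag(conf) / np.maximum(conf.sum(axis=1), 1)
    per_class_iu = np.diag(conf) / np.maximum(
        conf.sum(axis=1) + conf.sum(axis=0) - np.diag(conf), 1)
    return np.mean(losses), overall_acc, per_class_acc, per_class_iu, per_class_iu.mean()

A corresponding val routine would follow the same pattern without the backward pass and optimizer step, ideally under torch.no_grad().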