Code example #1
import os
import pickle

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

# Project-local modules assumed by this example: ld (data loading and
# MyDataset), mobileunet (model), myTransforms (augmentations), viz (graph
# rendering), CrossEntropyLoss2d, and the train/val epoch helpers.


def trainValSegmentation(args):
    if not os.path.isfile(args.cached_data_file):
        dataLoader = ld.LoadData(args.data_dir, args.classes, args.attrClasses,
                                 args.cached_data_file)
        # the constructor never returns None; check the processed data instead
        data = dataLoader.processData()
        if data is None:
            print("Error while caching the data.")
            exit(-1)
    else:
        print("Loading cached data.")
        data = pickle.load(open(args.cached_data_file, 'rb'))
    # only unet for segmentation now.
    # model= unet.UNet(args.classes)
    # model = r18unet.ResNetUNet(args.classes)
    model = mobileunet.MobileUNet(args.classes)
    print("UNet done...")
    if args.onGPU:
        model = model.cuda()
    if args.visNet:
        x = Variable(torch.randn(1, 3, args.inwidth, args.inheight))
        if args.onGPU:
            x = x.cuda()
        print("before forward...")
        y = model(x)
        print("after forward...")
        g = viz.make_dot(y)
        g.render(args.save_dir + '/model', view=False)
    model = torch.nn.DataParallel(model)
    n_param = sum([np.prod(param.size()) for param in model.parameters()])
    print('network parameters: ' + str(n_param))

    #define optimization criteria
    weight = torch.from_numpy(data['classWeights'])
    print(weight)
    if args.onGPU:
        weight = weight.cuda()
    criteria = CrossEntropyLoss2d(weight)
    if args.onGPU:
        criteria = criteria.cuda()

    trainDatasetNoZoom = myTransforms.Compose([
        myTransforms.RandomCropResize(args.inwidth, args.inheight),
        # myTransforms.RandomHorizontalFlip(),
        myTransforms.ToTensor(args.scaleIn)
    ])
    trainDatasetWithZoom = myTransforms.Compose([
        # myTransforms.Zoom(512,512),
        myTransforms.RandomCropResize(args.inwidth, args.inheight),
        myTransforms.RandomHorizontalFlip(),
        myTransforms.ToTensor(args.scaleIn)
    ])
    valDataset = myTransforms.Compose([
        myTransforms.RandomCropResize(args.inwidth, args.inheight),
        myTransforms.ToTensor(args.scaleIn)
    ])
    trainLoaderNoZoom = torch.utils.data.DataLoader(
        ld.MyDataset(data['trainIm'],
                     data['trainAnnot'],
                     transform=trainDatasetNoZoom),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)
    trainLoaderWithZoom = torch.utils.data.DataLoader(
        ld.MyDataset(data['trainIm'],
                     data['trainAnnot'],
                     transform=trainDatasetWithZoom),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)
    valLoader = torch.utils.data.DataLoader(ld.MyDataset(data['valIm'],
                                                         data['valAnnot'],
                                                         transform=valDataset),
                                            batch_size=args.batch_size_val,
                                            shuffle=True,
                                            num_workers=args.num_workers,
                                            pin_memory=True)

    #define the optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 args.lr, (0.9, 0.999),
                                 eps=1e-08,
                                 weight_decay=2e-4)
    # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr, momentum=0.99, weight_decay=5e-4)
    # optimizer = torch.optim.SGD([
    #        {'params': [param for name, param in model.named_parameters() if name[-4:] == 'bias'],
    #         'lr': 2 * args.lr},
    #        {'params': [param for name, param in model.named_parameters() if name[-4:] != 'bias'],
    #         'lr': args.lr, 'weight_decay': 5e-4}
    #    ], momentum=0.99)

    if args.onGPU:
        cudnn.benchmark = True
    start_epoch = 0
    if args.resume:
        if os.path.isfile(args.resumeLoc):
            print("=> loading checkpoint '{}'".format(args.resumeLoc))
            checkpoint = torch.load(args.resumeLoc)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            # also restore the optimizer state that the checkpoint stores
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resumeLoc, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resumeLoc))

    logfileLoc = args.save_dir + os.sep + args.logFile
    print(logfileLoc)
    # append to an existing log, otherwise start a new one
    logger = open(logfileLoc, 'a' if os.path.isfile(logfileLoc) else 'w')
    logger.write("Parameters: %s" % (str(n_param)))
    logger.write("\n%s\t%s\t%s\t%s\t%s\t%s\t%s\t%s" %
                 ('Epoch', 'Loss(Tr)', 'Loss(val)', 'Overall acc(Tr)',
                  'Overall acc(val)', 'mIOU (tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    #lr scheduler
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                     milestones=[30, 60, 90],
                                                     gamma=0.1)
    best_model_acc = 0
    for epoch in range(start_epoch, args.max_epochs):
        # legacy (pre-1.1) scheduler API: the epoch index is passed explicitly
        scheduler.step(epoch)
        lr = optimizer.param_groups[0]['lr']
        # train(args,trainLoaderWithZoom,model,criteria,optimizer,epoch)
        lossTr, overall_acc_tr, per_class_acc_tr, per_class_iu_tr, mIOU_tr = train(
            args, trainLoaderNoZoom, model, criteria, optimizer, epoch)
        # print(per_class_acc_tr,per_class_iu_tr)
        lossVal, overall_acc_val, per_class_acc_val, per_class_iu_val, mIOU_val = val(
            args, valLoader, model, criteria)

        # save a full checkpoint (weights, optimizer state, metrics)
        torch.save(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lossTr': lossTr,
                'lossVal': lossVal,
                'iouTr': mIOU_tr,
                'iouVal': mIOU_val,
            }, args.save_dir + '/checkpoint.pth.tar')

        # save the bare model weights as well
        # if overall_acc_val > best_model_acc:
        # 	best_model_acc = overall_acc_val
        model_file_name = args.save_dir + '/model_' + str(epoch + 1) + '.pth'
        torch.save(model.state_dict(), model_file_name)
        with open('../acc/acc_' + str(epoch) + '.txt', 'w') as log:  # assumes ../acc exists
            log.write(
                "\nEpoch: %d\t Overall Acc (Tr): %.4f\t Overall Acc (Val): %.4f\t mIOU (Tr): %.4f\t mIOU (Val): %.4f"
                % (epoch, overall_acc_tr, overall_acc_val, mIOU_tr, mIOU_val))
            log.write('\n')
            log.write('Per Class Training Acc: ' + str(per_class_acc_tr))
            log.write('\n')
            log.write('Per Class Validation Acc: ' + str(per_class_acc_val))
            log.write('\n')
            log.write('Per Class Training mIOU: ' + str(per_class_iu_tr))
            log.write('\n')
            log.write('Per Class Validation mIOU: ' + str(per_class_iu_val))

        logger.write(
            "\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.6f" %
            (epoch, lossTr, lossVal, overall_acc_tr, overall_acc_val, mIOU_tr,
             mIOU_val, lr))
        logger.flush()
        print("Epoch : " + str(epoch) + ' Details')
        print(
            "\nEpoch No.: %d\tTrain Loss = %.4f\tVal Loss = %.4f\t Train acc = %.4f\t Val acc = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f"
            % (epoch, lossTr, lossVal, overall_acc_tr, overall_acc_val,
               mIOU_tr, mIOU_val))

    logger.close()
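All three examples construct a CrossEntropyLoss2d criterion that is defined elsewhere in the repository. For reference, here is a minimal sketch of the usual way such a wrapper is written, assuming the model emits raw per-pixel logits of shape (N, C, H, W) and the labels are (N, H, W) class-index maps; the project's actual version may differ in detail:

import torch.nn as nn
import torch.nn.functional as F


class CrossEntropyLoss2d(nn.Module):
    """Class-weighted pixel-wise cross-entropy (a sketch)."""

    def __init__(self, weight=None):
        super(CrossEntropyLoss2d, self).__init__()
        # NLLLoss over log-probabilities is exactly cross-entropy; `weight`
        # rebalances each class's contribution (cf. data['classWeights'])
        self.loss = nn.NLLLoss(weight)

    def forward(self, outputs, targets):
        # log-softmax over the channel (class) dimension
        return self.loss(F.log_softmax(outputs, dim=1), targets)

With weight=torch.from_numpy(data['classWeights']) this matches how every example above builds its criterion.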
Code example #2
import os
import pickle

import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

# Project-local modules assumed by this example: ld, net (ESPNet),
# myTransforms, myDataLoader, viz, CrossEntropyLoss2d, netParams,
# save_checkpoint, and the train/val epoch helpers.


def trainValidateSegmentation(args):
    '''
    Main function for training and validation
    :param args: global arguments
    :return: None
    '''
    # check if processed data file exists or not
    if not os.path.isfile(args.cached_data_file):
        dataLoad = ld.LoadData(args.data_dir, args.classes, args.cached_data_file)
        data = dataLoad.processData()
        if data is None:
            print('Error while pickling data. Please check.')
            exit(-1)
    else:
        data = pickle.load(open(args.cached_data_file, "rb"))

    q = args.q
    p = args.p
    # load the model
    if not args.decoder:
        model = net.ESPNet_Encoder(args.classes, p=p, q=q)
        args.savedir = args.savedir + '_enc_' + str(p) + '_' + str(q) + '/'
    else:
        model = net.ESPNet(args.classes, p=p, q=q, encoderFile=args.pretrained)
        args.savedir = args.savedir + '_dec_' + str(p) + '_' + str(q) + '/'

    if args.onGPU:
        model = model.cuda()

    # create the directory if not exist
    if not os.path.exists(args.savedir):
        os.mkdir(args.savedir)

    if args.visualizeNet:
        x = Variable(torch.randn(1, 3, args.inWidth, args.inHeight))

        if args.onGPU:
            x = x.cuda()

        y = model(x)
        g = viz.make_dot(y)
        # graphviz appends the format extension itself, so pass a bare stem
        g.render(args.savedir + 'model', view=False)

    total_parameters = netParams(model)
    print('Total network parameters: ' + str(total_parameters))

    # define optimization criteria
    weight = torch.from_numpy(data['classWeights']) # convert the numpy array to torch
    if args.onGPU:
        weight = weight.cuda()

    criteria = CrossEntropyLoss2d(weight) #weight

    if args.onGPU:
        criteria = criteria.cuda()

    print('Data statistics')
    print(data['mean'], data['std'])
    print(data['classWeights'])

    #compose the data with transforms
    trainDataset_main = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(1024, 512),
        myTransforms.RandomCropResize(32),
        myTransforms.RandomFlip(),
        #myTransforms.RandomCrop(64).
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainDataset_scale1 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(1536, 768), # 1536, 768
        myTransforms.RandomCropResize(100),
        myTransforms.RandomFlip(),
        #myTransforms.RandomCrop(64),
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainDataset_scale2 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(1280, 720), # 1536, 768
        myTransforms.RandomCropResize(100),
        myTransforms.RandomFlip(),
        #myTransforms.RandomCrop(64),
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainDataset_scale3 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(768, 384),
        myTransforms.RandomCropResize(32),
        myTransforms.RandomFlip(),
        #myTransforms.RandomCrop(64),
        myTransforms.ToTensor(args.scaleIn),
    ])

    trainDataset_scale4 = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(512, 256),
        #myTransforms.RandomCropResize(20),
        myTransforms.RandomFlip(),
        #myTransforms.RandomCrop(64).
        myTransforms.ToTensor(args.scaleIn),
    ])


    valDataset = myTransforms.Compose([
        myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Scale(1024, 512),
        myTransforms.ToTensor(args.scaleIn),
    ])

    # since we are training from scratch, we create data loaders at several
    # scales, so that we generate more augmented data and reduce overfitting

    trainLoader = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'], data['trainAnnot'], transform=trainDataset_main),
        batch_size=args.batch_size + 2, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainLoader_scale1 = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'], data['trainAnnot'], transform=trainDataset_scale1),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainLoader_scale2 = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'], data['trainAnnot'], transform=trainDataset_scale2),
        batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainLoader_scale3 = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'], data['trainAnnot'], transform=trainDataset_scale3),
        batch_size=args.batch_size + 4, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    trainLoader_scale4 = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'], data['trainAnnot'], transform=trainDataset_scale4),
        batch_size=args.batch_size + 4, shuffle=True, num_workers=args.num_workers, pin_memory=True)

    valLoader = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['valIm'], data['valAnnot'], transform=valDataset),
        batch_size=args.batch_size + 4, shuffle=False, num_workers=args.num_workers, pin_memory=True)

    if args.onGPU:
        cudnn.benchmark = True

    start_epoch = 0

    if args.resume:
        if os.path.isfile(args.resumeLoc):
            print("=> loading checkpoint '{}'".format(args.resumeLoc))
            checkpoint = torch.load(args.resumeLoc)
            start_epoch = checkpoint['epoch']
            #args.lr = checkpoint['lr']
            model.load_state_dict(checkpoint['state_dict'])
            print("=> loaded checkpoint '{}' (epoch {})"
                .format(args.resumeLoc, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resumeLoc))

    logFileLoc = args.savedir + args.logFile
    if os.path.isfile(logFileLoc):
        logger = open(logFileLoc, 'a')
    else:
        logger = open(logFileLoc, 'w')
        logger.write("Parameters: %s" % (str(total_paramters)))
        logger.write("\n%s\t%s\t%s\t%s\t%s\t" % ('Epoch', 'Loss(Tr)', 'Loss(val)', 'mIOU (tr)', 'mIOU (val'))
    logger.flush()

    optimizer = torch.optim.Adam(model.parameters(), args.lr, (0.9, 0.999), eps=1e-08, weight_decay=5e-4)
    # the learning rate is halved every `step_loss` epochs
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=args.step_loss, gamma=0.5)


    for epoch in range(start_epoch, args.max_epochs):

        # legacy (pre-1.1) scheduler API: the epoch index is passed explicitly
        scheduler.step(epoch)
        lr = optimizer.param_groups[0]['lr']
        print("Learning rate: " + str(lr))

        # train for one epoch
        # We consider 1 epoch with all the training data (at different scales)
        train(args, trainLoader_scale1, model, criteria, optimizer, epoch)
        train(args, trainLoader_scale2, model, criteria, optimizer, epoch)
        train(args, trainLoader_scale4, model, criteria, optimizer, epoch)
        train(args, trainLoader_scale3, model, criteria, optimizer, epoch)
        lossTr, overall_acc_tr, per_class_acc_tr, per_class_iu_tr, mIOU_tr = train(args, trainLoader, model, criteria, optimizer, epoch)

        # evaluate on validation set
        lossVal, overall_acc_val, per_class_acc_val, per_class_iu_val, mIOU_val = val(args, valLoader, model, criteria)
        save_checkpoint({
            'epoch': epoch + 1,
            'arch': str(model),
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'lossTr': lossTr,
            'lossVal': lossVal,
            'iouTr': mIOU_tr,
            'iouVal': mIOU_val,
            'lr': lr
        }, args.savedir + 'checkpoint.pth.tar')

        # save the bare model weights as well
        model_file_name = args.savedir + 'model_' + str(epoch + 1) + '.pth'
        torch.save(model.state_dict(), model_file_name)
        with open(args.savedir + 'acc_' + str(epoch) + '.txt', 'w') as log:
            log.write("\nEpoch: %d\t Overall Acc (Tr): %.4f\t Overall Acc (Val): %.4f\t mIOU (Tr): %.4f\t mIOU (Val): %.4f" % (epoch, overall_acc_tr, overall_acc_val, mIOU_tr, mIOU_val))
            log.write('\n')
            log.write('Per Class Training Acc: ' + str(per_class_acc_tr))
            log.write('\n')
            log.write('Per Class Validation Acc: ' + str(per_class_acc_val))
            log.write('\n')
            log.write('Per Class Training mIOU: ' + str(per_class_iu_tr))
            log.write('\n')
            log.write('Per Class Validation mIOU: ' + str(per_class_iu_val))

        logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.7f" % (epoch, lossTr, lossVal, mIOU_tr, mIOU_val, lr))
        logger.flush()
        print("Epoch : " + str(epoch) + ' Details')
        print("\nEpoch No.: %d\tTrain Loss = %.4f\tVal Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f" % (epoch, lossTr, lossVal, mIOU_tr, mIOU_val))
    logger.close()
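Example #2 also relies on two helpers that the excerpt does not define: netParams and save_checkpoint. Both are typically one-liners in scripts of this shape; a minimal sketch, assuming the checkpoint directory already exists (the filename default is illustrative only):

import numpy as np
import torch


def netParams(model):
    # total number of scalar parameters in the model; the same quantity that
    # example #1 computes inline with np.prod over each parameter's size
    return sum(np.prod(p.size()) for p in model.parameters())


def save_checkpoint(state, filename='checkpoint.pth.tar'):
    # `state` is a plain dict (epoch, arch, state_dict, optimizer, metrics);
    # torch.save pickles it so torch.load can restore it when resuming
    torch.save(state, filename)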
Code example #3
import os
import pickle

import numpy as np
import torch
import torch.backends.cudnn as cudnn
from torch.autograd import Variable

# Project-local modules assumed by this example: ld, net (ResNetC1/D1),
# myTransforms, myDataLoader, viz, CrossEntropyLoss2d, save_checkpoint,
# and the train/val epoch helpers.


def trainValidateSegmentation(args):
    # check whether the processed (cached) data file already exists
    if not os.path.isfile(args.cached_data_file):
        dataLoader = ld.LoadData(args.data_dir, args.classes,
                                 args.cached_data_file)
        # the constructor never returns None; check the processed data instead
        data = dataLoader.processData()
        if data is None:
            print('Error while processing the data. Please check.')
            exit(-1)
    else:
        data = pickle.load(open(args.cached_data_file, "rb"))

    if args.modelType == 'C1':
        model = net.ResNetC1(args.classes)
    elif args.modelType == 'D1':
        model = net.ResNetD1(args.classes)
    else:
        print('Please select the correct model. Exiting!!')
        exit(-1)

    args.savedir = args.savedir + args.modelType + '/'

    if args.onGPU:
        model = model.cuda()

    # create the directory if it does not exist
    if not os.path.exists(args.savedir):
        os.mkdir(args.savedir)

    if args.visualizeNet:
        x = Variable(torch.randn(1, 3, args.inWidth, args.inHeight))

        if args.onGPU:
            x = x.cuda()

        y = model(x)
        g = viz.make_dot(y)
        g.render(args.savedir + 'model', view=False)

    n_param = sum([np.prod(param.size()) for param in model.parameters()])
    print('Network parameters: ' + str(n_param))

    # define optimization criteria
    print('Weights to handle class-imbalance')
    weight = torch.from_numpy(
        data['classWeights'])  # convert the numpy array to torch
    print(weight)
    if args.onGPU:
        weight = weight.cuda()

    criteria = CrossEntropyLoss2d(weight)  # class-weighted 2D cross-entropy

    if args.onGPU:
        criteria = criteria.cuda()

    trainDatasetNoZoom = myTransforms.Compose([
        # myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.RandomCropResize(20),
        myTransforms.RandomHorizontalFlip(),
        myTransforms.ToTensor(args.scaleIn)
    ])

    trainDatasetWithZoom = myTransforms.Compose([
        # myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.Zoom(512, 512),
        myTransforms.RandomCropResize(20),
        myTransforms.RandomHorizontalFlip(),
        myTransforms.ToTensor(args.scaleIn)
    ])

    valDataset = myTransforms.Compose([
        # myTransforms.Normalize(mean=data['mean'], std=data['std']),
        myTransforms.ToTensor(args.scaleIn)
    ])

    trainLoaderNoZoom = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'],
                               data['trainAnnot'],
                               transform=trainDatasetNoZoom),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    trainLoaderWithZoom = torch.utils.data.DataLoader(
        myDataLoader.MyDataset(data['trainIm'],
                               data['trainAnnot'],
                               transform=trainDatasetWithZoom),
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=args.num_workers,
        pin_memory=True)

    valLoader = torch.utils.data.DataLoader(myDataLoader.MyDataset(
        data['valIm'], data['valAnnot'], transform=valDataset),
                                            batch_size=args.batch_size,
                                            shuffle=False,
                                            num_workers=args.num_workers,
                                            pin_memory=True)

    # define the optimizer
    # optimizer = torch.optim.Adam(model.parameters(), args.lr, (0.9, 0.999), eps=1e-08, weight_decay=2e-4)
    optimizer = torch.optim.SGD(model.parameters(),
                                lr=args.lr,
                                momentum=0.9,
                                weight_decay=5e-4)

    if args.onGPU:
        cudnn.benchmark = True

    start_epoch = 0

    if args.resume:
        if os.path.isfile(args.resumeLoc):
            print("=> loading checkpoint '{}'".format(args.resumeLoc))
            checkpoint = torch.load(args.resumeLoc)
            start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['state_dict'])
            # also restore the optimizer state stored in the checkpoint
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resumeLoc, checkpoint['epoch']))
        else:
            print("=> no checkpoint found at '{}'".format(args.resumeLoc))

    logFileLoc = args.savedir + os.sep + args.logFile
    # append to an existing log, otherwise start a new one
    logger = open(logFileLoc, 'a' if os.path.isfile(logFileLoc) else 'w')
    logger.write("Parameters: %s" % (str(n_param)))
    logger.write(
        "\n%s\t%s\t%s\t%s\t%s\t%s" %
        ('Epoch', 'Loss(Tr)', 'Loss(val)', 'mIOU (tr)', 'mIOU (val)', 'lr'))
    logger.flush()

    #lr scheduler
    scheduler = torch.optim.lr_scheduler.StepLR(optimizer,
                                                step_size=args.step_loss,
                                                gamma=0.1)

    for epoch in range(start_epoch, args.max_epochs):
        # legacy (pre-1.1) scheduler API: the epoch index is passed explicitly
        scheduler.step(epoch)
        lr = optimizer.param_groups[0]['lr']

        # train on the zoomed images first
        train(args, trainLoaderWithZoom, model, criteria, optimizer, epoch)
        lossTr, overall_acc_tr, per_class_acc_tr, per_class_iu_tr, mIOU_tr = train(
            args, trainLoaderNoZoom, model, criteria, optimizer, epoch)
        # evaluate on validation set
        lossVal, overall_acc_val, per_class_acc_val, per_class_iu_val, mIOU_val = val(
            args, valLoader, model, criteria)

        save_checkpoint(
            {
                'epoch': epoch + 1,
                'arch': str(model),
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'lossTr': lossTr,
                'lossVal': lossVal,
                'iouTr': mIOU_tr,
                'iouVal': mIOU_val,
            }, args.savedir + '/checkpoint.pth.tar')

        # save the model also
        model_file_name = args.savedir + '/model_' + str(epoch + 1) + '.pth'
        torch.save(model.state_dict(), model_file_name)

        with open(args.savedir + 'acc_' + str(epoch) + '.txt', 'w') as log:
            log.write(
                "\nEpoch: %d\t Overall Acc (Tr): %.4f\t Overall Acc (Val): %.4f\t mIOU (Tr): %.4f\t mIOU (Val): %.4f"
                % (epoch, overall_acc_tr, overall_acc_val, mIOU_tr, mIOU_val))
            log.write('\n')
            log.write('Per Class Training Acc: ' + str(per_class_acc_tr))
            log.write('\n')
            log.write('Per Class Validation Acc: ' + str(per_class_acc_val))
            log.write('\n')
            log.write('Per Class Training mIOU: ' + str(per_class_iu_tr))
            log.write('\n')
            log.write('Per Class Validation mIOU: ' + str(per_class_iu_val))

        logger.write("\n%d\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f\t\t%.4f" %
                     (epoch, lossTr, lossVal, mIOU_tr, mIOU_val, lr))
        logger.flush()
        print("Epoch : " + str(epoch) + ' Details')
        print(
            "\nEpoch No.: %d\tTrain Loss = %.4f\tVal Loss = %.4f\t mIOU(tr) = %.4f\t mIOU(val) = %.4f"
            % (epoch, lossTr, lossVal, mIOU_tr, mIOU_val))

    logger.close()
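None of the examples shows how args is built. For a quick smoke test of example #3 one can pass a plain Namespace instead of parsing the command line; the field names below are exactly those the function reads, while every value is illustrative only:

from argparse import Namespace

args = Namespace(
    data_dir='./data', classes=20, cached_data_file='./data/cached.p',
    modelType='C1', savedir='./results_', onGPU=True, visualizeNet=False,
    inWidth=1024, inHeight=512, scaleIn=1, batch_size=6, num_workers=4,
    lr=0.01, step_loss=100, max_epochs=300, resume=False, resumeLoc='',
    logFile='trainValLog.txt')

trainValidateSegmentation(args)

In a real run these fields would come from argparse, and savedir gets the model type appended inside the function, so './results_' becomes './results_C1/'.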