Example #1
0
def val(args, model, dataloader, csv_path):
    """Validation pass: mean per-pixel accuracy (dice) and mIoU over `dataloader`.

    Returns the mean per-pixel accuracy; mIoU is only printed.
    """
    print('start val!')
    # label_info = get_label_info(csv_path)
    with torch.no_grad():
        model.eval()
        accuracy_per_image = []
        # confusion matrix accumulated over the whole validation set
        confusion = np.zeros((args.num_classes, args.num_classes))
        for _, (image, target) in enumerate(dataloader):
            if torch.cuda.is_available() and args.use_gpu:
                image = image.cuda()
                target = target.cuda()

            # forward pass -> class-index map
            pred = np.array(reverse_one_hot(model(image).squeeze()))

            # ground truth -> class-index map
            target = np.array(reverse_one_hot(target.squeeze()))

            # per-pixel accuracy and confusion update for this image
            accuracy_per_image.append(compute_global_accuracy(pred, target))
            confusion += fast_hist(target.flatten(), pred.flatten(), args.num_classes)

            # there is no need to transform the one-hot array to visual RGB array
            # predict = colour_code_segmentation(np.array(predict), label_info)
            # label = colour_code_segmentation(np.array(label), label_info)
        dice = np.mean(accuracy_per_image)
        miou = np.mean(per_class_iu(confusion))
        print('precision per pixel for validation: %.3f' % dice)
        print('mIoU for validation: %.3f' % miou)
        return dice
Example #2
0
def val(args, model, val_img_path, val_label_path, csv_path):
    """Build an ADE validation loader and report mean per-pixel precision."""
    print('start val!')
    dataset_val = ADE(val_img_path, val_label_path,
                      scale=(args.crop_height, args.crop_width), mode='val')
    # batch size has to stay at 1 for validation
    dataloader_val = DataLoader(dataset_val,
                                batch_size=1,
                                shuffle=True,
                                num_workers=args.num_workers)
    label_info = get_label_info(csv_path)
    with torch.no_grad():
        model.eval()
        scores = []
        for _, (image, target) in enumerate(dataloader_val):
            if torch.cuda.is_available() and args.use_gpu:
                image = image.cuda()
                target = target.cuda()
            # model output -> class indices -> RGB colour coding
            pred_rgb = colour_code_segmentation(
                np.array(reverse_one_hot(model(image).squeeze())), label_info)
            # ground truth -> class indices -> RGB colour coding
            target_rgb = colour_code_segmentation(
                np.array(reverse_one_hot(target.squeeze())), label_info)
            # per-pixel accuracy on the colour-coded maps
            scores.append(compute_global_accuracy(pred_rgb, target_rgb))
        dice = np.mean(scores)
        print('precision per pixel for validation: %.3f' % dice)
        return dice
Example #3
0
def val(args, model, dataloader, csv_path):
    """Validate over `dataloader`; returns mean per-pixel precision."""
    print('start val!')
    label_info = get_label_info(csv_path)
    with torch.no_grad():
        model.eval()
        scores = []
        for _, (image, target) in enumerate(dataloader):
            if torch.cuda.is_available() and args.use_gpu:
                image = image.cuda()
                target = target.cuda()

            # prediction -> class indices -> RGB colour coding
            pred_rgb = colour_code_segmentation(
                np.array(reverse_one_hot(model(image).squeeze())), label_info)

            # ground truth -> class indices -> RGB colour coding
            target_rgb = colour_code_segmentation(
                np.array(reverse_one_hot(target.squeeze())), label_info)

            # per-pixel accuracy on the colour maps
            scores.append(compute_global_accuracy(pred_rgb, target_rgb))
        dice = np.mean(scores)
        print('precision per pixel for validation: %.3f' % dice)
        return dice
Example #4
0
def eval(model, dataloader, args, label_info):
    """Test-set evaluation; prints and returns mean per-pixel precision."""
    print('start test!')
    with torch.no_grad():
        model.eval()
        scores = []
        progress = tqdm.tqdm(total=len(dataloader) * args.batch_size)
        progress.set_description('test')
        for _, (image, target) in enumerate(dataloader):
            progress.update(args.batch_size)
            if torch.cuda.is_available() and args.use_gpu:
                image = image.cuda()
                target = target.cuda()
            # prediction and ground truth both become colour-coded maps
            pred_rgb = colour_code_segmentation(
                np.array(reverse_one_hot(model(image).squeeze())), label_info)
            target_rgb = colour_code_segmentation(
                np.array(reverse_one_hot(target.squeeze())), label_info)

            scores.append(compute_global_accuracy(pred_rgb, target_rgb))
        precision = np.mean(scores)
        progress.close()
        print('precision for test: %.3f' % precision)
        return precision
Example #5
0
def eval(model,dataloader, args, label_info):
    """Test-set evaluation reporting mean precision and mIoU.

    `label_info` is accepted for interface compatibility but unused here
    (the colour-coding calls are commented out).
    """
    print('start test!')
    with torch.no_grad():
        model.eval()
        scores = []
        progress = tqdm.tqdm(total=len(dataloader) * args.batch_size)
        progress.set_description('test')
        confusion = np.zeros((args.num_classes, args.num_classes))
        for _, (image, target) in enumerate(dataloader):
            progress.update(args.batch_size)
            if torch.cuda.is_available() and args.use_gpu:
                image = image.cuda()
                target = target.cuda()
            pred = np.array(reverse_one_hot(model(image).squeeze()))
            # predict = colour_code_segmentation(np.array(predict), label_info)

            target = np.array(reverse_one_hot(target.squeeze()))
            # label = colour_code_segmentation(np.array(label), label_info)

            scores.append(compute_global_accuracy(pred, target))
            confusion += fast_hist(target.flatten(), pred.flatten(), args.num_classes)
        precision = np.mean(scores)
        miou = np.mean(per_class_iu(confusion))
        progress.close()
        print('precision for test: %.3f' % precision)
        print('mIoU for validation: %.3f' % miou)
        return precision
Example #6
0
def eval(model, dataset, args, label_info):
    """Wrap `dataset` in a single-image DataLoader and report test precision."""
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=4,
    )
    print('start test!')
    with torch.no_grad():
        model.eval()
        scores = []
        progress = tqdm.tqdm(total=len(dataloader) * args.batch_size)
        progress.set_description('test')
        for i, (image, target) in enumerate(dataloader):
            progress.update(args.batch_size)
            if torch.cuda.is_available() and args.use_gpu:
                image = image.cuda()
                target = target.cuda()
            logits = model(image).squeeze()
            # (a large, never-executed debug string literal that dumped
            # activations to a GIF used to sit here; removed as dead code)
            pred = colour_code_segmentation(
                np.array(reverse_one_hot(logits).cpu()), label_info)
            #print(predict)
            #cv2.imwrite("./result/"+dataset.image_name[i]+"_R.png",predict)
            target = colour_code_segmentation(
                np.array(reverse_one_hot(target.squeeze()).cpu()), label_info)

            scores.append(compute_global_accuracy(pred, target))
        precision = np.mean(scores)
        progress.close()
        print('precision for test: %.3f' % precision)
        return precision
Example #7
0
def infer(model):
    """Run the model on one hard-coded test image and plot input,
    prediction and label side by side.

    Fix: the Python-2 `print x` statements were syntax errors under
    Python 3; they are now `print()` calls.

    Returns 0 on completion (output is shown via matplotlib).
    """
    print('start test!')
    label_np = array(Image.open('./data/test/Labels/672.png'))
    label = Variable(torch.from_numpy(label_np))
    img = Image.open('./data/test/Images/672.png')
    img_n = array(img)

    img_np = np.resize(img_n, (480, 480, 3))
    # NOTE(review): Variable(volatile=True) is pre-0.4 PyTorch;
    # torch.no_grad() is the modern equivalent — confirm the target version.
    outputs = model(
        Variable(torch.from_numpy(img_np[np.newaxis, :].transpose(0, 3, 1,
                                                                  2)).float(),
                 volatile=True).cuda())
    # print(outputs.size())
    outputs = outputs.squeeze()
    print(outputs)
    predict = reverse_one_hot(outputs)
    print(predict)
    label = label.squeeze()
    print(label)

    fig, ax = plt.subplots(1, 3)
    ax[0].imshow(img_np, cmap='gray')
    ax[1].imshow(predict)
    ax[2].imshow(label)
    plt.show()

    return 0
def seg_predict(image):
    """Segment `image` with the global bise_model and return the class map.

    Also writes a grayscale->BGR visualisation to ./segmentation_image.png.
    """
    global bise_model
    try:
        with torch.no_grad():
            bise_model.eval()
            h, w, _ = image.shape
            normalise = transforms.Compose([
                transforms.ToTensor(),
                transforms.Normalize((0.485, 0.456, 0.406),
                                     (0.229, 0.224, 0.225)),
            ])

            batch = normalise(image).unsqueeze_(0).cuda()
            class_map = np.array(reverse_one_hot(bise_model(batch).squeeze()))
            # restore the original spatial size of the input
            class_map = np.resize(class_map, [h, w])
            print(np.unique(class_map))
            vis = cv2.cvtColor(np.uint8(class_map), cv2.COLOR_GRAY2BGR)
            cv2.imwrite('./segmentation_image.png', vis)

            return class_map
    except CvBridgeError as e:
        print(e)
Example #9
0
 def seg_callback(self, rgb):
     """Image callback: segment the incoming frame and publish the class map.

     rgb: an image message converted via self.bridge (bgr8). The predicted
     class-index map is saved to ./predict.npy and published on
     self.label_pub as a '32SC1' image message.
     """
     try:
         with torch.no_grad():
             self.model.eval()
             # message -> OpenCV BGR array
             rgb = self.bridge.imgmsg_to_cv2(rgb, 'bgr8')
             # NOTE(review): this transform is rebuilt on every callback;
             # it could be created once at construction time.
             self.to_tensor = transforms.Compose([
                 transforms.ToTensor(),
                 transforms.Normalize((0.485, 0.456, 0.406),
                                      (0.229, 0.224, 0.225)),
             ])
             #rgb = np.transpose(rgb, (2,0,1))
             #rgb = np.expand_dims(rgb, axis = 0)
             #print(type(rgb))
             #rgb = torch.from_numpy(rgb)
             rgb = self.to_tensor(rgb)
             rgb = rgb.unsqueeze_(0)
             rgb = rgb.cuda()
             # forward pass -> one-hot scores -> class-index map
             predict = self.model(rgb).squeeze()
             predict = reverse_one_hot(predict)
             predict = np.array(predict)
             np.save('./predict', predict)
             self.label_pub.publish(
                 self.bridge.cv2_to_imgmsg(predict, '32SC1'))
             print('ss')
     except CvBridgeError as e:
         print(e)
def predict_on_image(model, args, data, label_file, img_info):
    """Predict segmentation for one image and write the colour-coded
    prediction plus a prediction-vs-label diff image under res/.

    Fix: the output filenames concatenated the literal string 'img_info'
    instead of the `img_info` argument, so every call overwrote the same
    two files.
    """
    # read csv label path
    label_info = get_label_info(args.csv_path)

    # label: one-hot encode, then collapse back to a class-index map
    label = Image.open(label_file)
    label = np.array(label)
    label = one_hot_it_v11_dice(label, label_info).astype(np.uint8)
    label = np.transpose(label, [2, 0, 1]).astype(np.float32)
    label = label.squeeze()
    label = np.argmax(label, axis=0)

    # image: BGR->RGB, resize to crop size, tensorise, ImageNet-normalise
    image = cv2.imread(data, -1)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    resize = iaa.Scale({'height': args.crop_height, 'width': args.crop_width})
    resize_det = resize.to_deterministic()
    image = resize_det.augment_image(image)
    image = Image.fromarray(image).convert('RGB')
    image = transforms.ToTensor()(image)
    image = transforms.Normalize((0.485, 0.456, 0.406),
                                 (0.229, 0.224, 0.225))(image).unsqueeze(0)
    # predict
    model.eval()
    predict = model(image).squeeze()
    # 512 * 512
    predict = reverse_one_hot(predict)

    predict_ = colour_code_segmentation(np.array(predict), label_info)
    predict_ = cv2.resize(np.uint8(predict_), (512, 512))
    cv2.imwrite('res/pred_' + str(img_info) + '.png',
                cv2.cvtColor(np.uint8(predict_), cv2.COLOR_RGB2BGR))
    diff = plot_diff(np.array(predict), label)
    cv2.imwrite('res/diff_' + str(img_info) + '.png',
                cv2.cvtColor(np.uint8(diff), cv2.COLOR_RGB2BGR))
Example #11
0
def eval(model, dataloader, args, csv_path):
    """Test-set evaluation: per-pixel precision, per-class IoU and mIoU.

    When args.save_images_path is set, the first 40 images, labels and
    predictions are written there for inspection.
    """
    print('start test!')
    with torch.no_grad():
        model.eval()
        scores = []
        progress = tqdm.tqdm(total=len(dataloader) * args.batch_size)
        progress.set_description('test')
        confusion = np.zeros((args.num_classes, args.num_classes))
        for i, (image, target) in enumerate(dataloader):
            progress.update(args.batch_size)
            if torch.cuda.is_available() and args.use_gpu:
                image = image.cuda()
                target = target.cuda()
            pred = np.array(reverse_one_hot(model(image).squeeze()))
            # predict = colour_code_segmentation(np.array(predict), label_info)

            target = target.squeeze()
            # dice training uses one-hot labels, so decode them first
            if args.loss == 'dice':
                target = reverse_one_hot(target)
            target = np.array(target)
            # label = colour_code_segmentation(np.array(label), label_info)
            #saving some images
            if args.save_images_path is not None and i < 40:
                transforms.functional.to_pil_image(image[0]).save(
                    args.save_images_path + f"/image{i}.jpg")
                Image.fromarray(colorize_label(target)).save(
                    args.save_images_path + f"/label{i}.jpeg")
                Image.fromarray(colorize_label(pred)).save(
                    args.save_images_path + f"/prediction{i}.jpeg")

            scores.append(compute_global_accuracy(pred, target))
            confusion += fast_hist(target.flatten(), pred.flatten(),
                                   args.num_classes)
        precision = np.mean(scores)
        # the last class is dropped before averaging (presumably void) — as original
        miou_list = per_class_iu(confusion)[:-1]
        miou_dict, miou = cal_miou(miou_list, csv_path)
        print('IoU for each class:')
        for key in miou_dict:
            print('{}:{},'.format(key, miou_dict[key]))
        progress.close()
        print('precision for test: %.3f' % precision)
        print('mIoU for validation: %.3f' % miou)
        return precision
Example #12
0
def val(args, model, dataloader):
    """Validation pass returning (per-pixel precision, mIoU)."""
    print('start val!')
    # label_info = get_label_info(csv_path)
    with torch.no_grad():
        model.eval()
        scores = []
        confusion = np.zeros((args.num_classes, args.num_classes))
        for _, (image, target) in enumerate(dataloader):
            if torch.cuda.is_available() and args.use_gpu:
                image = image.cuda()
                target = target.cuda()

            # prediction -> class-index map, moved to the CPU
            pred = np.array(reverse_one_hot(model(image).squeeze()).cpu())

            target = target.squeeze()
            if args.loss == 'dice':
                # dice training uses one-hot labels; decode them first
                target = reverse_one_hot(target)
            target = np.array(target.cpu())

            # per-pixel accuracy plus confusion-matrix update
            scores.append(compute_global_accuracy(pred, target))
            confusion += fast_hist(target.flatten(), pred.flatten(),
                                   args.num_classes)

            # there is no need to transform the one-hot array to visual RGB array
            # predict = colour_code_segmentation(np.array(predict), label_info)
            # label = colour_code_segmentation(np.array(label), label_info)
        precision = np.mean(scores)
        # miou = np.mean(per_class_iu(hist))
        miou_list = per_class_iu(confusion)[:-1]
        # miou_dict, miou = cal_miou(miou_list, csv_path)
        miou = np.mean(miou_list)
        print('precision per pixel for test: %.3f' % precision)
        print('mIoU for validation: %.3f' % miou)
        return precision, miou
Example #13
0
def predict_on_RGB(image):  # nd convenient both for img and video
    """Segment an RGB image; returns a uint8 colour-coded prediction."""
    # pre-processing: resize, convert to a (1, C, H, W) float tensor
    batch = transforms.ToTensor()(resize_img(image)).float().unsqueeze(0)

    # inference and colour coding
    model.eval()
    class_map = reverse_one_hot(model(batch).squeeze())
    coloured = colour_code_segmentation(np.array(class_map), label_info)  # RGB
    return coloured.astype(np.uint8)
Example #14
0
def eval(model, dataloader, args):
    """Evaluate on the test set; returns (mean precision, mean mIoU).

    Fix: the per-batch timing line was a Python-2 `print` statement
    (a syntax error under Python 3); it is now a `print()` call with
    identical output.
    """
    print('start test!')
    with torch.no_grad():
        model.eval()
        precision_record = []
        total_miou = []
        for i, (data, label) in enumerate(dataloader):
            #tq.update(args.batch_size)
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda()

            t1 = time.time()
            predict = model(data)
            print('time:', time.time() - t1, '\n')
            predict = predict.squeeze()
            predict = reverse_one_hot(predict).unsqueeze(0)
            # predict = colour_code_segmentation(np.array(predict), label_info)
            label = label.squeeze().unsqueeze(0)
            # label = reverse_one_hot(label)
            # label = colour_code_segmentation(np.array(label), label_info)

            # the IoU metric is rebuilt per batch; mIoU is averaged over batches
            metric = IoU(num_classes=2, ignore_index=None)
            metric.reset()
            metric.add(predict.data, label.data)
            iou, miou = metric.value()

            precision = compute_global_accuracy(predict, label)
            print('precision: %.3f' % precision, 'mIOU: %.3f' % miou)
            total_miou.append(miou)
            precision_record.append(precision)
        precision = np.mean(precision_record)
        total_miou = np.mean(total_miou)
        #tq.close()
        print('precision for test: %.3f' % precision)
        print('total mIOU for test: %.3f' % total_miou)
        return precision, total_miou
Example #15
0
    def predict_on_image(self, image):
        """Segment one frame; returns (colour prediction, bboxes, num_people)."""
        # reverse channel order (presumably BGR -> RGB — confirm against
        # callers) and move to the model device
        tensor = self.transform(image[:, :, ::-1]).to(self.device)

        # raw prediction map
        logits = self.model(tensor.unsqueeze(0)).squeeze()

        # collapse scores to a class-index map on the CPU
        class_map = reverse_one_hot(logits).cpu().numpy()

        # colour-code the class map
        coloured = colour_code_segmentation(class_map,
                                            self.label_info).astype(np.uint8)

        # derive bounding boxes and the people count from the coloured map
        predict, bboxes, num_people = self.bbox_output(coloured)
        return predict, bboxes, num_people
def predict_on_image(model, args):
    """Segment args.data and write the 960x720 colour-coded result to args.save_path."""
    # pre-processing: load, BGR->RGB, resize to crop size, tensorise
    image = cv2.imread(args.data, -1)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    scaler = iaa.Scale({'height': args.crop_height, 'width': args.crop_width})
    image = scaler.to_deterministic().augment_image(image)
    batch = transforms.ToTensor()(
        Image.fromarray(image).convert('RGB')).unsqueeze(0)

    # class-name -> colour mapping from the CSV
    label_info = get_label_info(args.csv_path)
    # inference, colour coding, resize and save as BGR
    model.eval()
    class_map = reverse_one_hot(model(batch).squeeze())
    coloured = colour_code_segmentation(np.array(class_map), label_info)
    coloured = cv2.resize(np.uint8(coloured), (960, 720))
    cv2.imwrite(args.save_path,
                cv2.cvtColor(np.uint8(coloured), cv2.COLOR_RGB2BGR))
Example #17
0
def seg_predict(image):
    """Segment `image` with the global bise_model and print the unique class ids.

    NOTE(review): this variant has no `return` — callers receive None.
    Confirm whether a `return predict` was intended.
    """
    global bise_model
    try:
        with torch.no_grad():
            bise_model.eval()
            h, w, _ = image.shape
            normalise = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
                ])

            batch = normalise(image).unsqueeze_(0).cuda()
            class_map = np.array(reverse_one_hot(bise_model(batch).squeeze()))
            # restore the original spatial size of the input
            class_map = np.resize(class_map, [h, w])
            print(np.unique(class_map))
    except CvBridgeError as e:
        print(e)
Example #18
0
def predict_on_RGBD(image, depth):  # nd convenient both for img and video
    """Segment an RGB + depth pair; returns a BGR uint8 colour-coded image."""
    # RGB branch: resize and convert to a (1, 3, H, W) float tensor
    rgb = resize_img(image).convert('RGB')
    rgb = transforms.ToTensor()(rgb).float().unsqueeze(0)

    # depth branch: resize, add a channel axis, scale to [0, 1]
    d = np.array(resize_depth(depth))[:, :, np.newaxis] / 255
    d = transforms.ToTensor()(d).float().unsqueeze(0)

    # concatenate along the channel dimension -> 4-channel input
    rgbd = torch.cat((rgb, d), 1)

    # inference and colour coding
    model.eval()
    class_map = reverse_one_hot(model(rgbd).squeeze())
    coloured = np.uint8(colour_code_segmentation(np.array(class_map), label_info))

    return cv2.cvtColor(np.uint8(coloured), cv2.COLOR_RGB2BGR)
Example #19
0
def predict_on_image(model, args, image):
    """Run inference on a BGR image; returns the colour-coded prediction
    resized to 960x720 (RGB channel order).
    """
    # pre-processing: BGR->RGB, resize, tensorise, ImageNet-normalise
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    scaler = iaa.Scale({'height': args.crop_height, 'width': args.crop_width})
    image = scaler.to_deterministic().augment_image(image)
    image = Image.fromarray(image).convert('RGB')
    batch = transforms.Normalize(
        (0.485, 0.456, 0.406),
        (0.229, 0.224, 0.225))(transforms.ToTensor()(image)).unsqueeze(0)
    # class-name -> colour mapping from the CSV
    label_info = get_label_info(args.csv_path)
    # inference and colour coding
    model.eval()
    class_map = reverse_one_hot(model(batch).squeeze())
    # predict = colour_code_segmentation(np.array(predict), label_info)
    coloured = colour_code_segmentation(np.array(class_map.cpu()), label_info)
    coloured = cv2.resize(np.uint8(coloured), (960, 720))
    # cv2.imwrite(args.save_path, cv2.cvtColor(np.uint8(predict), cv2.COLOR_RGB2BGR))
    return coloured
Example #20
0
def main(params):
    """Load a BiSeNet checkpoint, run one sample inference, summarise and
    profile the network, then export it to ONNX.

    params: argv-style list of strings handed to argparse.
    """

    parser = argparse.ArgumentParser()

    parser.add_argument('--save_model_path',
                        type=str,
                        default=None,
                        help='path to save model')
    parser.add_argument('--num_classes',
                        type=int,
                        default=32,
                        help='num of object classes (with void)')
    parser.add_argument(
        '--context_path',
        type=str,
        default="resnet18",
        help='The context path model you are using, resnet18, resnet101.')
    args = parser.parse_args(params)

    # rebuild the network and load the best-dice checkpoint from disk
    model = BiSeNet(args.num_classes, args.context_path)
    model.load_state_dict(
        torch.load(os.path.join(args.save_model_path, 'best_dice_loss.pth')))
    model.eval()

    # one hard-coded CamVid test frame, resized and ImageNet-normalised
    img = Image.open('./CamVid/test/Seq05VD_f00660.png')
    transform = transforms.Compose([
        transforms.Resize([720, 960]),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
    ])
    img = transform(img).unsqueeze(dim=0)

    # paint each predicted class index (0-11) with a fixed colour
    imresult = np.zeros([img.shape[2], img.shape[3], 3], dtype=np.uint8)
    with torch.no_grad():
        predict = model(img).squeeze()
        predict = reverse_one_hot(predict)
        predict = np.array(predict)
        imresult[:, :][predict == 0] = [255, 51, 255]
        imresult[:, :][predict == 1] = [255, 0, 0]
        imresult[:, :][predict == 2] = [0, 255, 0]
        imresult[:, :][predict == 3] = [0, 0, 255]
        imresult[:, :][predict == 4] = [255, 255, 0]
        imresult[:, :][predict == 5] = [255, 0, 255]
        imresult[:, :][predict == 6] = [0, 255, 255]
        imresult[:, :][predict == 7] = [10, 200, 128]
        imresult[:, :][predict == 8] = [125, 18, 78]
        imresult[:, :][predict == 9] = [205, 128, 8]
        imresult[:, :][predict == 10] = [144, 208, 18]
        imresult[:, :][predict == 11] = [5, 88, 198]
    cv2.imwrite('result.png', imresult)
    print('inference done')

    # layer-by-layer summary on CPU
    summary(model, (3, 720, 960), 1, "cpu")

    # MACs / parameter count (profile — presumably thop; confirm import)
    macs, params = profile(model, inputs=(img, ))
    print('macs', macs)
    print('params', params)

    # export to ONNX; any exception is written to log.txt
    # NOTE(review): `log` is never closed — consider a `with` block
    log = open('log.txt', 'w')
    EXPORTONNXNAME = 'nit-bisenet.onnx'
    try:
        torch.onnx.export(
            model,
            img,
            EXPORTONNXNAME,
            export_params=True,
            do_constant_folding=True,
            input_names=['data'],
            # output_names = ['output']
            output_names=['output'])
    except Exception:
        traceback.print_exc(file=log)

    print('export done')
    def val_out(self,
                sess,
                val_init,
                threshold=0.5,
                output_dir="val",
                epoch=0):
        """Run one TensorFlow validation pass and dump per-image artefacts.

        For each validation sample this writes the original image, the
        colour-coded annotation, the raw prediction, a thresholded
        prediction and an image-with-alpha matte to `output_dir/<epoch>/`,
        and appends per-image scores to val_scores.csv. Iteration ends when
        the dataset iterator raises tf.errors.OutOfRangeError, after which
        epoch-average metrics are printed.

        threshold: fraction of 255 above which a predicted pixel is kept
        in the filtered output.
        """
        print("validation starts.")
        save_dir = output_dir + "/%d" % (epoch)
        if not os.path.exists(save_dir):
            os.mkdir(save_dir)
        # CSV header; one extra column per entry in self.name_string
        target = open(save_dir + "/val_scores.csv", 'w')
        target.write(
            "val_name, avg_accuracy, precision, recall, f1 score, mean iou, %s\n"
            % (self.name_string))

        sess.run(val_init)
        #for ind in range(self.num_val):
        scores_list = []
        class_scores_list = []
        precision_list = []
        recall_list = []
        f1_list = []
        iou_list = []
        try:
            while True:
                # fetch one sample: input image, one-hot mask, network logits
                img, ann, output_image = sess.run(
                    [self.img, self.mask, self.logits])
                img = img[0, :, :, :] * 255

                # annotation: one-hot -> class-index map
                ann = np.array(ann[0, :, :, :])
                ann = reverse_one_hot(ann)

                path, size = sess.run([self.path, self.size])
                size = (size[0][0], size[0][1])

                # prediction: one-hot -> class indices -> colour coding
                output_single_image = np.array(output_image)
                output_single_image = np.array(output_single_image[0, :, :, :])
                output_image = reverse_one_hot(output_single_image)
                out_vis_image = colour_code_segmentation(
                    output_image, self.label_values)

                accuracy, class_accuracies, prec, rec, f1, iou = evaluate_segmentation(
                    pred=output_image, label=ann, num_classes=self.num_classes)

                # NOTE(review): `dir` shadows the builtin of the same name
                dir = path[0].decode('ascii')
                file_name = filepath_to_name(dir)

                # one CSV row per validation image
                target.write("%s, %f, %f, %f, %f, %f" %
                             (file_name, accuracy, prec, rec, f1, iou))
                for item in class_accuracies:
                    target.write(", %f" % (item))
                target.write("\n")

                mask = colour_code_segmentation(ann, self.label_values)

                mask = cv2.cvtColor(np.uint8(mask), cv2.COLOR_RGB2BGR)
                out_vis_image = cv2.cvtColor(np.uint8(out_vis_image),
                                             cv2.COLOR_RGB2BGR)
                mask, ori_out_vis = resizeImage(mask, out_vis_image, size)

                # binarise one channel of the prediction at `threshold`
                out_vis_image = cv2.resize(ori_out_vis[:, :, 1],
                                           size,
                                           interpolation=cv2.INTER_NEAREST)
                out_vis_image[out_vis_image < threshold * 255] = 0
                out_vis_image[out_vis_image >= threshold * 255] = 255

                save_ori_img = cv2.cvtColor(np.uint8(img), cv2.COLOR_RGB2BGR)
                save_ori_img = cv2.resize(save_ori_img,
                                          size,
                                          interpolation=cv2.INTER_NEAREST)
                # 3-channel image + thresholded prediction as a 4th channel
                transparent_image = np.append(np.array(save_ori_img)[:, :,
                                                                     0:3],
                                              out_vis_image[:, :, None],
                                              axis=-1)
                # transparent_image = Image.fromarray(transparent_image)

                cv2.imwrite(save_dir + "/%s_img.jpg" % (file_name),
                            save_ori_img)
                cv2.imwrite(save_dir + "/%s_ann.png" % (file_name), mask)
                cv2.imwrite(save_dir + "/%s_ori_pred.png" % (file_name),
                            ori_out_vis)
                cv2.imwrite(save_dir + "/%s_filter_pred.png" % (file_name),
                            out_vis_image)
                cv2.imwrite(save_dir + "/%s_mat.png" % (file_name),
                            transparent_image)

                scores_list.append(accuracy)
                class_scores_list.append(class_accuracies)
                precision_list.append(prec)
                recall_list.append(rec)
                f1_list.append(f1)
                iou_list.append(iou)
        except tf.errors.OutOfRangeError:
            # dataset iterator exhausted: compute and print epoch averages
            avg_score = np.mean(scores_list)
            class_avg_scores = np.mean(class_scores_list, axis=0)
            avg_precision = np.mean(precision_list)
            avg_recall = np.mean(recall_list)
            avg_f1 = np.mean(f1_list)
            avg_iou = np.mean(iou_list)

            print("\nAverage validation accuracy for epoch # %04d = %f" %
                  (epoch, avg_score))
            print("Average per class validation accuracies for epoch # %04d:" %
                  (epoch))
            for index, item in enumerate(class_avg_scores):
                print("%s = %f" % (self.name_list[index], item))
            print("Validation precision = ", avg_precision)
            print("Validation recall = ", avg_recall)
            print("Validation F1 score = ", avg_f1)
            print("Validation IoU score = ", avg_iou)
Example #22
0
def val(args, model, dataloader, data_name):
    """Validate a segmentation model and report OA, mIoU, kappa and macro-F1.

    Args:
        args: namespace providing ``num_classes``, ``use_gpu`` and ``loss``.
        model: segmentation network, called as ``model(data)``.
        dataloader: iterable of ``(data, label)`` batches.
        data_name: dataset tag used only in the printed messages.

    Returns:
        Tuple ``(precision, miou, cm, total_cks, total_f1)``. ``cm`` is the
        confusion matrix of the LAST batch only (kept for backward
        compatibility with existing callers); ``total_cks`` and ``total_f1``
        are averages over the whole run.
    """
    print('start val!')
    total_cks, total_f1 = 0.0, 0.0
    # Running buffers for the chunked kappa score; index 0 is a dummy seed
    # that is stripped with ``[1:]`` before every computation.
    total_pred = np.array([0])
    total_label = np.array([0])
    length = len(dataloader)
    cks_chunks = 0  # number of chunks that contributed a kappa score
    with torch.no_grad():
        model.eval()
        precision_record = []
        hist = np.zeros((args.num_classes, args.num_classes))
        for i, (data, label) in enumerate(dataloader):
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda()

            # Class-index map from the network output.
            predict = model(data).squeeze()
            predict = reverse_one_hot(predict)
            predict = np.array(predict)

            # Class-index map from the ground truth; dice labels arrive
            # one-hot encoded and need the same conversion.
            label = label.squeeze()
            if args.loss == 'dice':
                label = reverse_one_hot(label)
            label = np.array(label)

            total_pred = np.append(total_pred, predict.flatten())
            total_label = np.append(total_label, label.flatten())
            if (i + 1) % 8 == 0:
                # Kappa over the last 8 batches; ``[1:]`` drops the seed.
                cks = cohen_kappa_score(total_label[1:], total_pred[1:])
                total_label = np.array([0])
                total_pred = np.array([0])
                total_cks += cks
                cks_chunks += 1
            f1 = f1_score(label.flatten(), predict.flatten(), average='macro')
            total_f1 += f1

            precision = compute_global_accuracy(predict, label)
            hist += fast_hist(label.flatten(), predict.flatten(),
                              args.num_classes)
            precision_record.append(precision)

        # Flush the leftover batches (when ``length`` is not a multiple of
        # 8); previously these predictions were silently dropped from kappa.
        if total_pred.size > 1:
            total_cks += cohen_kappa_score(total_label[1:], total_pred[1:])
            cks_chunks += 1

        precision = np.mean(precision_record)
        # Drop the last (ignore/background) class from the mIoU average.
        miou_list = per_class_iu(hist)[:-1]
        miou = np.mean(miou_list)
        print('oa for %s: %.3f' % (data_name, precision))
        print('mIoU for %s: %.3f' % (data_name, miou))
        # cm/cks/cr are computed on the LAST batch only (legacy behaviour,
        # preserved because the confusion matrix is part of the return value).
        cm, cks, cr = compute_cm_cks_cr(predict, label)
        total_f1 /= length
        # Guard against ZeroDivisionError: the original divided by
        # ``length // 8`` which is 0 for dataloaders shorter than 8 batches.
        total_cks = total_cks / cks_chunks if cks_chunks else 0.0
        print('kappa for %s: %.4f' % (data_name, total_cks))
        print('f1 for {}:\n'.format(data_name), total_f1)
        return precision, miou, cm, total_cks, total_f1
def eval(model, dataloader, args, csv_path):
    """Run the test pass: OA, per-class IoU, mIoU, kappa, macro-F1 and FPS.

    Args:
        model: segmentation network, called as ``model(data)``.
        dataloader: iterable of ``(data, label)`` batches.
        args: namespace providing ``batch_size``, ``num_classes``,
            ``use_gpu`` and ``loss``.
        csv_path: class-definition CSV passed to ``cal_miou`` for the
            per-class IoU report.

    Returns:
        Tuple ``(precision, cm, total_cks, cr)``. ``cm`` and ``cr`` come
        from the LAST batch only (legacy behaviour); ``total_cks`` is the
        average kappa over all batches.
    """
    print('start test!')
    with torch.no_grad():
        total_pred = np.array([0])
        total_label = np.array([0])
        total_cm = np.zeros((6, 6))
        model.eval()
        precision_record = []
        tq = tqdm.tqdm(total=len(dataloader) * args.batch_size)
        tq.set_description('test')
        hist = np.zeros((args.num_classes, args.num_classes))
        total_time = 0
        total_cks, total_f1 = 0.0, 0.0
        length = len(dataloader)
        print('length: %d' % length)
        for i, (data, label) in enumerate(dataloader):
            tq.update(args.batch_size)
            if torch.cuda.is_available() and args.use_gpu:
                data = data.cuda()
                label = label.cuda()
            # time.clock() was removed in Python 3.8; perf_counter() is the
            # recommended monotonic replacement for timing the forward pass.
            start = time.perf_counter()
            predict = model(data).squeeze()
            end = time.perf_counter()
            # Convert one-hot network output to a class-index map.
            predict = reverse_one_hot(predict)
            predict = np.array(predict)
            # Accumulate inference time for the FPS report.
            total_time += (end - start)
            label = label.squeeze()
            # Dice labels arrive one-hot encoded; convert to class indices.
            if args.loss == 'dice':
                label = reverse_one_hot(label)
            label = np.array(label)

            # Per-batch kappa and macro-F1, averaged after the loop.
            cks = cohen_kappa_score(label.flatten(), predict.flatten())
            total_cks += cks
            f1 = f1_score(label.flatten(), predict.flatten(), average='macro')
            total_f1 += f1
            # Overall accuracy for this batch.
            precision = compute_global_accuracy(predict, label)
            hist += fast_hist(label.flatten(), predict.flatten(),
                              args.num_classes)
            precision_record.append(precision)

        precision = np.mean(precision_record)
        miou_list = per_class_iu(hist)
        miou_dict, miou = cal_miou(miou_list, csv_path)
        print('IoU for each class:')
        for key in miou_dict:
            print('{}:{},'.format(key, miou_dict[key]))
        tq.close()
        print('oa for test: %.3f' % precision)
        print('mIoU for test: %.3f' % miou)

        # cm/cks/cr computed on the LAST batch only (legacy behaviour,
        # preserved because cm and cr are part of the return value).
        cm, cks, cr = compute_cm_cks_cr(predict, label)
        total_cks /= length
        print('kappa for test: %.4f' % total_cks)
        total_f1 /= length
        print('f1 for test: %.4f' % total_f1)
        # NOTE(review): this is batches per second, not frames, since each
        # timed forward pass covers one batch — rename would break callers'
        # expectations of the printed log, so only flagging here.
        fps = length / total_time
        print('fps: %.2f' % fps)
        return precision, cm, total_cks, cr
Example #24
0
def main():

    sess = setting.sess
    config = setting.config

    if config["phase"] == "prepare_data":
        tfRecGen = TFRecordGenerator(
            config["data"]["directory"] + "/train",
            config["data"]["directory"] + "/train_labels")
        data_file = os.path.join(config["data"]["tfrecord"]["directory"],
                                 config["data"]["tfrecord"]["train_filename"])

        if (not os.path.isfile(data_file)) or (os.stat(data_file).st_size
                                               == 0):
            tfRecGen.write_TFRecord(
                config["data"]["tfrecord"]["train_filename"], flag="training")
        else:
            utils.print_log("training tfrecord already exists", "INFO")
            tfRecGen.read_TFRecord(
                config["data"]["tfrecord"]["train_filename"])

        tfRecGen = TFRecordGenerator(
            config["data"]["directory"] + "/val",
            config["data"]["directory"] + "/val_labels")
        data_file = os.path.join(config["data"]["tfrecord"]["directory"],
                                 config["data"]["tfrecord"]["val_filename"])

        if (not os.path.isfile(data_file)) or (os.stat(data_file).st_size
                                               == 0):
            tfRecGen.write_TFRecord(config["data"]["tfrecord"]["val_filename"],
                                    flag="validation")
        else:
            utils.print_log("validation tfrecord already exists", "INFO")
            tfRecGen.read_TFRecord(config["data"]["tfrecord"]["val_filename"])

    elif config["phase"] == "train":
        ckpt_dir = strftime("run_%Y_%m_%d_%H:%M:%S", localtime())
        suffix_name = config["data"]["tfrecord"]["train_filename"]
        suffix_name = suffix_name[suffix_name.find("_") +
                                  1:suffix_name.find(".")] + "_original_split"
        ckpt_dir = ckpt_dir + "_" + suffix_name

        if not os.path.isdir(
                config["training_setting"]["checkpoints"]["save_directory"] +
                "/" + config["model"] + "/" + ckpt_dir):
            os.makedirs(
                config["training_setting"]["checkpoints"]["save_directory"] +
                "/" + config["model"] + "/" + ckpt_dir)

        if not os.path.isdir(
                config["training_setting"]["checkpoints"]["save_directory"] +
                "/" + config["model"] + "/" + ckpt_dir + "/best_checkpoint"):
            os.makedirs(
                config["training_setting"]["checkpoints"]["save_directory"] +
                "/" + config["model"] + "/" + ckpt_dir + "/best_checkpoint")

        summary_events_file = config["training_setting"]["checkpoints"][
            "save_directory"] + "/" + config["model"] + "/" + ckpt_dir
        log_file = config["training_setting"]["checkpoints"][
            "save_directory"] + "/" + config[
                "model"] + "/" + ckpt_dir + "/training.log"

        utils.print_log(
            (config["training_setting"]["logging"]["training_note"]).upper(),
            "BOLD", log_file)

        epochs = int(config["training_setting"]["epochs"])

        model = model_hub.get_model(name=config["model"])

        next_data = namedtuple("next_data", "input label")
        with tf.name_scope("input") as scope:

            # is_training = tf.placeholder(dtype=tf.bool, shape=[], name="is_training")
            phase = tf.placeholder(tf.string, shape=[], name="phase")

            # train_dataset, val_dataset, _ = dataset_Feature_Raw()
            train_dataset, val_dataset = dataset_Feature_Raw()

            batch_size = int(config["training_setting"]["batch_size"])

            if batch_size == 1:
                utils.print_log(
                    "detected batch size: {}\noriginal image resolution will be used for training\n"
                    .upper().format(batch_size), "BOLD", log_file)
                train_dataset = train_dataset.map(parse_map)
                val_dataset = val_dataset.map(parse_map)
            else:
                utils.print_log(
                    "found batch size: {}\noptimal image resolution calculated for each batch is: {}\n"
                    .upper().format(batch_size,
                                    setting.data_size), "BOLD", log_file)
                train_dataset = train_dataset.map(parse_map_with_resize)
                val_dataset = val_dataset.map(parse_map_with_resize)

            train_dataset = train_dataset.batch(batch_size)
            train_dataset = train_dataset.shuffle(buffer_size=10000)
            train_dataset = train_dataset.repeat()
            val_dataset = val_dataset.batch(batch_size)
            handle = tf.placeholder(tf.string, shape=[], name="handle")
            iter = tf.data.Iterator.from_string_handle(handle, train_dataset.output_types,\
                       train_dataset.output_shapes)

            next = iter.get_next()
            current_data = next_data(next[0], next[1])

            is_training = tf.case(
                {
                    tf.equal(phase, tf.constant("train")): lambda: True,
                    tf.equal(phase, tf.constant("val")): lambda: False,
                    tf.equal(phase, tf.constant("test")): lambda: False
                },
                exclusive=True)
            train_iterator = train_dataset.make_one_shot_iterator()
            val_iterator = val_dataset.make_initializable_iterator()

        # with tf.name_scope("output") as scope:
        # logits = model(current_data.input, is_training)
        logits = model._setup(current_data.input, is_training)

        with tf.name_scope("performance_metrics") as scope:
            m_labels = tf.argmax(current_data.label, axis=-1)
            m_logits = tf.nn.softmax(logits)
            m_logits = tf.argmax(m_logits, axis=-1)

            false_positive, false_positive_update = tf.metrics.false_positives(
                current_data.label, logits, name="false_positive")
            false_negative, false_negative_update = tf.metrics.false_negatives(
                current_data.label, logits, name="false_negative")
            true_positive, true_positive_update = tf.metrics.true_positives(
                current_data.label, logits, name="true_positive")
            true_negative, true_negative_update = tf.metrics.true_negatives(
                current_data.label, logits, name="true_negative")
            recall, recall_update = tf.metrics.recall(labels=m_labels,
                                                      predictions=m_logits,
                                                      name="recall")
            mean_iou, mean_iou_update = tf.metrics.mean_iou(
                labels=m_labels,
                predictions=m_logits,
                num_classes=setting.num_classes,
                name="mean_iou")
            accuracy, accuracy_update = tf.metrics.accuracy(
                labels=m_labels, predictions=m_logits, name="accuracy")

            running_vars_false_positive = tf.get_collection(
                tf.GraphKeys.LOCAL_VARIABLES, scope="false_positive")
            running_vars_false_negative = tf.get_collection(
                tf.GraphKeys.LOCAL_VARIABLES, scope="false_negative")
            running_vars_true_positive = tf.get_collection(
                tf.GraphKeys.LOCAL_VARIABLES, scope="true_positive")
            running_vars_true_negative = tf.get_collection(
                tf.GraphKeys.LOCAL_VARIABLES, scope="true_negative")
            running_vars_recall = tf.get_collection(
                tf.GraphKeys.LOCAL_VARIABLES, scope="recall")
            running_vars_mean_iou = tf.get_collection(
                tf.GraphKeys.LOCAL_VARIABLES, scope="mean_iou")
            running_vars_accuracy = tf.get_collection(
                tf.GraphKeys.LOCAL_VARIABLES, scope="accuracy")

            running_vars_initializer_false_positive = tf.variables_initializer(
                var_list=running_vars_false_positive)
            running_vars_initializer_false_negative = tf.variables_initializer(
                var_list=running_vars_false_negative)
            running_vars_initializer_true_positive = tf.variables_initializer(
                var_list=running_vars_true_positive)
            running_vars_initializer_true_negative = tf.variables_initializer(
                var_list=running_vars_true_negative)
            running_vars_initializer_recall = tf.variables_initializer(
                var_list=running_vars_recall)
            running_vars_initializer_mean_iou = tf.variables_initializer(
                var_list=running_vars_mean_iou)
            running_vars_initializer_accuracy = tf.variables_initializer(
                var_list=running_vars_accuracy)

            false_positive_rate = tf.div(false_positive,
                                         tf.add(false_positive, true_negative))
            false_negative_rate = tf.div(false_negative,
                                         tf.add(true_positive, false_negative))
            precision = tf.div(true_positive,
                               tf.add(true_positive, false_positive))
            f1_score = tf.div(tf.multiply(2.0, tf.multiply(precision, recall)),
                              tf.add(precision, recall))

            tf.summary.scalar("mean iou", mean_iou)
            tf.summary.scalar("accuracy", accuracy)
            tf.summary.scalar("false positive rate", false_positive_rate)
            tf.summary.scalar("false negative rate", false_negative_rate)
            tf.summary.scalar("precision", precision)
            tf.summary.scalar("f1 score", f1_score)

        with tf.name_scope("softmax_operation") as scope:
            soft_inp = tf.placeholder(name="softmax_input",
                                      shape=[None, None, None],
                                      dtype=tf.float32)
            soft_out = tf.nn.softmax(soft_inp)

        with tf.name_scope("loss") as scope:
            if config["training_setting"]["loss"] == "cross_entropy":
                no_gradient_label = tf.stop_gradient(current_data.label)
                if config["training_setting"]["class_balancing"] == "none":
                    utils.print_log("No class balancing selected", "BOLD",
                                    log_file)
                    losses = tf.nn.softmax_cross_entropy_with_logits_v2(
                        logits=logits, labels=no_gradient_label)
                else:
                    utils.print_log("Class balancing selected", "BOLD",
                                    log_file)
                    weights = current_data.label * class_weights
                    weights = tf.reduce_sum(weights, -1)
                    losses = tf.losses.softmax_cross_entropy(
                        onehot_labels=current_data.label,
                        logits=logits,
                        weights=weights)
            elif config["training_setting"]["loss"] == "lovasz":
                if not config["training_setting"]["class_balancing"] == "none":
                    utils.print_log("No class balancing selected", "BOLD",
                                    log_file)
                    losses = helpers.lovasz_softmax(probas=logits,
                                                    labels=labels)
                else:
                    utils.print_log("Class balancing selected", "BOLD",
                                    log_file)
                    weights = current_data.label * class_weights
                    weights = tf.reduce_sum(weights, -1)
                    losses = helpers.lovasz_softmax(probas=logits,
                                                    labels=labels)

            loss = tf.reduce_mean(losses)
            tf.summary.scalar("loss", loss)

        with tf.name_scope("optimization") as scope:
            if config["training_setting"]["analyse_lr"]:
                lr = tf.placeholder(name="analyse_lr",
                                    dtype=tf.float32,
                                    shape=[])
                lr_mult = float(config["training_setting"]["lr_mult"])
                lr_mult_bias = float(
                    config["training_setting"]["lr_mult_bias"])
                optimizer = tf.train.AdamOptimizer(lr)
                # optimizer_except_bias = tf.train.AdamOptimizer(lr_mult*lr)
                # optimizer_bias = tf.train.AdamOptimizer(lr_mult_bias*lr)
            else:
                lr = float(config["training_setting"]["learning_rate"])
                lr_mult = float(config["training_setting"]["lr_mult"])
                lr_mult_bias = float(
                    config["training_setting"]["lr_mult_bias"])

                optimizer = tf.train.AdamOptimizer(lr)
                # optimizer_except_bias = tf.train.AdamOptimizer(lr_mult*lr)
                # optimizer_bias = tf.train.AdamOptimizer(lr_mult_bias*lr)

            all_var_list = [var for var in tf.trainable_variables()]
            bias_var_list = [
                var for var in tf.trainable_variables() if "bias" in var.name
            ]
            except_bias_list = list(
                set(all_var_list).difference(bias_var_list))

            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                train_op = optimizer.minimize(
                    loss, var_list=[var for var in tf.trainable_variables()])
                # optimize_except_bias_train_op = optimizer_except_bias.minimize(loss, var_list=except_bias_list)
                # optimize_bias_train_op = optimizer_bias.minimize(loss, var_list=bias_var_list)
                # train_op = tf.group(optimize_except_bias_train_op, optimize_bias_train_op)

            # if config["training_setting"]["clip_norm"]:
            # 	print("clip gradients will be used")
            # 	print(float(config["training_setting"]["clip_norm"]))
            # 	clip_norm = float(config["training_setting"]["clip_norm"])
            # 	optimizer = tf_utils.OptimizerCustomOperation(optimizer)
            # 	with tf.control_dependencies(update_ops):
            # 		train_op = optimizer.minimize(loss, var_list=[var for var in tf.trainable_variables()])
            # else:
            # 	print("clip gradients will not be used")
            # 	with tf.control_dependencies(update_ops):
            # 		train_op = optimizer.minimize(loss, var_list=[var for var in tf.trainable_variables()])

        with tf.name_scope("aggregate_summaries") as scope:
            merge_summary = tf.summary.merge_all()

        if not config["training_setting"]["analyse_lr"]:
            summary_writer = tf.summary.FileWriter(
                summary_events_file, graph=tf.get_default_graph())

        train_handle = sess.run(train_iterator.string_handle())
        val_handle = sess.run(val_iterator.string_handle())

        saver = tf.train.Saver(max_to_keep=1)
        if not config["training_setting"]["analyse_lr"]:
            meta_graph_def = tf.train.export_meta_graph(
                filename="./model_graphs/adapnet.meta")

        ckpt_prefix = config["training_setting"]["checkpoints"][
            "save_directory"] + "/" + config[
                "model"] + "/" + ckpt_dir + "/" + config["training_setting"][
                    "checkpoints"]["prefix"]
        best_ckpt_dir = config["training_setting"]["checkpoints"][
            "save_directory"] + "/" + config[
                "model"] + "/" + ckpt_dir + "/best_checkpoint/" + config[
                    "training_setting"]["checkpoints"]["prefix"]

        utils.print_log("Training in progress . . . . .", "INFO", log_file)
        utils.print_log("Training start time:", "INFO", log_file)
        utils.print_log(strftime("%Y_%m_%d_%H:%M:%S", localtime()), "INFO",
                        log_file)

        global_counter = 0
        cnt = 0
        best_miou = 0.0
        early_stopping_counter = 0
        total_steps = epochs * int(train_count / batch_size)

        utils.print_log(
            "Total steps for training: {}".format(str(total_steps)), "INFO",
            log_file)

        fig, ax = plt.subplots()

        # lr_range = [0.01, 0.02, 0.03, 0.04, 0.05, 0.06]
        # lr_range = [ 0.001, 0.003, 0.01, 0.02, 0.03, 0.04, 0.05, 0.06]
        lr_range = [0.06, 0.05, 0.04, 0.03, 0.02, 0.01, 0.003, 0.001]
        loss_value = []

        if config["training_setting"]["analyse_lr"]:
            print("tracking loss versus learning rate curve....")
            for current_lr in lr_range:
                print("running for learning rate: {}".format(current_lr))
                print("")
                sess.run(tf.global_variables_initializer())
                i = 0
                for index in range(epochs):
                    try:
                        for train_index in range(int(train_count /
                                                     batch_size)):
                            _, current_loss = sess.run(
                                [train_op, loss],
                                feed_dict={
                                    handle: train_handle,
                                    phase: "train",
                                    lr: current_lr
                                })
                            print("step: {} loss: {}".format(i, current_loss))
                            # time.sleep(1)

                            i += 1
                            # print(i)

                    except tf.errors.OutOfRangeError as e:
                        pass

                loss_value.append(current_loss)

            ax.plot(lr_range, loss_value)
            ax.set(xlabel='learning rate',
                   ylabel='loss',
                   title='Track best learning rate')
            ax.grid()
            fig.savefig("learning_rate_vs_loss_curve.png")
            with open("lr_vs_loss.txt", "w+") as lr_loss:
                lr_loss.write("learning rate: " + str(lr_range) + "\n")
                lr_loss.write("loss: " + str(loss_value))

        else:

            sess.run(tf.global_variables_initializer())

            for index in range(epochs):
                st = time.time()
                epoch_nbr = index
                # print(index)
                try:
                    for train_index in range(int(train_count / batch_size)):

                        if (cnt == (total_steps - 1) * batch_size) or (
                            (cnt - batch_size) %
                                int(config["training_setting"]["logging"]
                                    ["display_iteration"]) == 0):
                            _, current_loss, cur_data = sess.run(
                                [train_op, loss, current_data],
                                feed_dict={
                                    handle: train_handle,
                                    phase: "train"
                                })
                            message = "Epoch = %d Processed examples = %d Current_Loss = %.4f Time = %.2f" % (
                                epoch_nbr, (cnt + batch_size), current_loss,
                                time.time() - st)
                            utils.print_log("\n" + message, "TRAINING_STATUS",
                                            log_file)
                            utils.print_log(
                                "Total examples for training remaining: {}".
                                format(
                                    str(total_steps * batch_size -
                                        (cnt + batch_size))), "INFO", log_file)
                            st = time.time()

                        else:
                            sess.run(train_op,
                                     feed_dict={
                                         handle: train_handle,
                                         phase: "train"
                                     })

                        if cnt > 0 and cnt % int(config["training_setting"][
                                "logging"]["evaluation_iteration"]) == 0:

                            if not os.path.isdir(
                                    config["training_setting"]["checkpoints"]
                                ["save_directory"] + "/" + config["model"] +
                                    "/" + ckpt_dir + "/overlayed_images"):
                                os.makedirs(config["training_setting"]
                                            ["checkpoints"]["save_directory"] +
                                            "/" + config["model"] + "/" +
                                            ckpt_dir + "/overlayed_images")

                            if not os.path.isdir(
                                    config["training_setting"]["checkpoints"]
                                ["save_directory"] + "/" + config["model"] +
                                    "/" + ckpt_dir + "/segmentation_images"):
                                os.makedirs(config["training_setting"]
                                            ["checkpoints"]["save_directory"] +
                                            "/" + config["model"] + "/" +
                                            ckpt_dir + "/segmentation_images")

                            if not os.path.isdir(
                                    config["training_setting"]["checkpoints"]
                                ["save_directory"] + "/" + config["model"] +
                                    "/" + ckpt_dir +
                                    "/confidence_segmentation"):
                                os.makedirs(config["training_setting"]
                                            ["checkpoints"]["save_directory"] +
                                            "/" + config["model"] + "/" +
                                            ckpt_dir +
                                            "/confidence_segmentation")

                            sess.run(val_iterator.initializer)
                            # sess.run(initializers for individual metrics)

                            #for initializing the local variables pertaining to each metric
                            sess.run(tf.local_variables_initializer())
                            for val_index in range(int(val_count /
                                                       batch_size)):
                                # output_logits, val_cur_data, merged_summaries, _, _, _, _, _, _, _= sess.run([logits, current_data, merge_summary, recall_update, mean_iou_update, accuracy_update, false_positive_update, false_negative_update, true_positive_update, true_negative_update], feed_dict={handle:val_handle, is_training:False})
                                output_logits, val_cur_data, merged_summaries, _, _, _, _, _, _, _ = sess.run(
                                    [
                                        logits, current_data, merge_summary,
                                        recall_update, mean_iou_update,
                                        accuracy_update, false_positive_update,
                                        false_negative_update,
                                        true_positive_update,
                                        true_negative_update
                                    ],
                                    feed_dict={
                                        handle: val_handle,
                                        phase: "val"
                                    })
                                # output_logits, val_cur_data, _, _, _, _, _, _, _= sess.run([logits, current_data, recall_update, mean_iou_update, accuracy_update, false_positive_update, false_negative_update, true_positive_update, true_negative_update], feed_dict={handle:val_handle, is_training:False})
                                val_index += 1
                            # NOTE(review): `output_logits` and `val_cur_data` are used below
                            # (L-equivalent of out_logits_sample / val_image) but the sess.run
                            # that would assign them here is commented out — they must be
                            # assigned earlier in this loop, above this chunk. Verify.
                            # Fetch the accumulated streaming-metric values (the *_update ops
                            # were run per validation batch above); this read does not feed
                            # any data, it only evaluates the metric result tensors.
                            # calculate the average metrics and display
                            rec, m_iou, acc, fpr, fnr, prec, f1 = sess.run([
                                recall, mean_iou, accuracy,
                                false_positive_rate, false_negative_rate,
                                precision, f1_score
                            ])
                            # Log all validation metrics for this checkpoint step.
                            utils.print_log("\nValidation metrics:",
                                            "VALIDATION_STATUS", log_file)
                            utils.print_log("\t\trecall: {}".format(str(rec)),
                                            "VALIDATION_STATUS", log_file)
                            utils.print_log(
                                "\t\tmean IoU: {}".format(str(m_iou)),
                                "VALIDATION_STATUS", log_file)
                            utils.print_log("\t\tacc: {}".format(str(acc)),
                                            "VALIDATION_STATUS", log_file)
                            utils.print_log(
                                "\t\tfalse positive rate: {}".format(str(fpr)),
                                "VALIDATION_STATUS", log_file)
                            utils.print_log(
                                "\t\tfalse negative rate: {}".format(str(fnr)),
                                "VALIDATION_STATUS", log_file)
                            utils.print_log(
                                "\t\tprecision: {}".format(str(prec)),
                                "VALIDATION_STATUS", log_file)
                            utils.print_log("\t\tf1 score: {}".format(str(f1)),
                                            "VALIDATION_STATUS", log_file)

                            #pick the last image from each batch for visualization
                            # Plot the original prediction as segmentation image
                            # One-hot logits (H, W, C) -> class-index map -> RGB color map.
                            out_logits_sample = output_logits[-1, :, :, :]
                            out_image = utils.reverse_one_hot(
                                out_logits_sample)
                            out_vis_image = utils.color_code_segmentation(
                                out_image, setting.label_values)

                            if config["training_setting"][
                                    "save_visualization_image"]:
                                # cv2.imwrite expects BGR channel order, hence the
                                # RGB2BGR conversion of the color-coded image.
                                cv2.imwrite(
                                    config["training_setting"]["checkpoints"]
                                    ["save_directory"] + "/" +
                                    config["model"] + "/" + ckpt_dir +
                                    "/segmentation_images/step_" + str(cnt) +
                                    ".png",
                                    cv2.cvtColor(np.uint8(out_vis_image),
                                                 cv2.COLOR_RGB2BGR))

                            #use overlay visualization techniques for better visualization
                            # Plot confidences as red-blue overlay
                            # Run the softmax graph separately on the sampled logits to
                            # obtain per-pixel confidences for the overlay.
                            val_image = val_cur_data.input[-1, :, :, :]
                            out_logits_sample_confidence = sess.run(
                                soft_out,
                                feed_dict={soft_inp: out_logits_sample})
                            confidence_out_vis_image = utils.make_confidence_overlay(
                                val_image, out_logits_sample_confidence)

                            if config["training_setting"][
                                    "save_visualization_image"]:
                                # NOTE(review): scipy.misc.imsave was removed in
                                # SciPy >= 1.2 — confirm the pinned SciPy version, or
                                # migrate to imageio.imwrite.
                                scipy.misc.imsave(
                                    config["training_setting"]["checkpoints"]
                                    ["save_directory"] + "/" +
                                    config["model"] + "/" + ckpt_dir +
                                    "/confidence_segmentation/step_" +
                                    str(cnt) + ".png",
                                    confidence_out_vis_image)

                            # Plot the original prediction as segmentation overlay
                            # NOTE(review): the segmentation image is converted to BGR
                            # here but later saved with scipy.misc.imsave (RGB order) —
                            # the overlay colors may be channel-swapped. Verify intent.
                            overlayed_im = utils.make_overlay(
                                np.uint8(val_image),
                                cv2.cvtColor(np.uint8(out_vis_image),
                                             cv2.COLOR_RGB2BGR))

                            if config["training_setting"][
                                    "save_visualization_image"]:
                                scipy.misc.imsave(
                                    config["training_setting"]["checkpoints"]
                                    ["save_directory"] + "/" +
                                    config["model"] + "/" + ckpt_dir +
                                    "/overlayed_images/step_" + str(cnt) +
                                    ".png", overlayed_im)

                            # Track the best validation mIoU; save a "best" checkpoint
                            # whose filename encodes every metric at this step.
                            # A fresh Saver is created per save, so these best
                            # checkpoints are never rotated/deleted by max_to_keep.
                            if best_miou < m_iou:
                                best_miou = m_iou
                                early_stopping_counter = 0
                                # shutil.rmtree(config["training_setting"]["checkpoints"]["save_directory"]+"/"+config["model"]+"/"+ckpt_dir+"/best_checkpoint")
                                tf.train.Saver().save(
                                    sess,
                                    "{0}_miou_{1:6.4f}_recall_{2:6.4f}_accuracy_{3:6.4f}_fpr_{4:6.4f}_fnr_{5:6.4f}_prec_{6:6.4f}_f1_{7:6.4f}_loss_{8:6.4f}_step_{9}"
                                    .format(best_ckpt_dir, m_iou, rec, acc,
                                            fpr, fnr, prec, f1, current_loss,
                                            cnt),
                                    global_step=global_counter,
                                    write_meta_graph=False)
                            # else:
                            # 	early_stopping_counter+=1

                            # Record the summaries evaluated for this step.
                            summary_writer.add_summary(merged_summaries, cnt)

                        # Periodic (non-best) checkpoint every `save_step` counted in
                        # processed samples (`cnt` advances by batch_size per step).
                        # The active config.yaml is copied alongside for reproducibility.
                        if cnt > 0 and cnt % int(config["training_setting"][
                                "checkpoints"]["save_step"]) == 0:
                            saver.save(sess,
                                       ckpt_prefix,
                                       global_step=global_counter,
                                       write_meta_graph=True)
                            shutil.copyfile(
                                "./config.yaml",
                                config["training_setting"]["checkpoints"]
                                ["save_directory"] + "/" + config["model"] +
                                "/" + ckpt_dir + "/config.yaml")

                        # cnt counts processed samples; global_counter counts steps.
                        cnt = cnt + batch_size
                        global_counter += 1

                        # if early_stopping_counter==10:
                        # 	print("Early stopping as no improvement observed for last 10 validation steps")
                        # 	if not os.path.isfile(config["training_setting"]["checkpoints"]["save_directory"]+"/"+config["model"]+"/"+ckpt_dir+"/notes.txt"):
                        # 		with open(config["training_setting"]["checkpoints"]["save_directory"]+"/"+config["model"]+"/"+ckpt_dir+"/notes.txt", "w+") as f:
                        # 			f.write(config["training_setting"]["logging"]["training_note"])

                        # 	utils.print_log("Training over......Hope I served your purpose...My Lord", "ERROR", log_file)
                        # 	utils.print_log("Training end time:","INFO",log_file)
                        # 	utils.print_log(strftime("%Y_%m_%d_%H:%M:%S", localtime()),"INFO",log_file)
                        # 	return

                        # break
                    # break
                # Dataset iterator exhausted: treated as the normal end of training,
                # so the exception is deliberately swallowed (the bound `e` is unused).
                except tf.errors.OutOfRangeError as e:
                    pass

            summary_writer.close()

            # Persist the free-form training note once per run directory
            # (only if it wasn't already written, e.g. by early stopping).
            if not os.path.isfile(config["training_setting"]["checkpoints"]
                                  ["save_directory"] + "/" + config["model"] +
                                  "/" + ckpt_dir + "/notes.txt"):
                with open(
                        config["training_setting"]["checkpoints"]
                    ["save_directory"] + "/" + config["model"] + "/" +
                        ckpt_dir + "/notes.txt", "w+") as f:
                    f.write(
                        config["training_setting"]["logging"]["training_note"])

            # Final log lines: completion banner and wall-clock end time.
            utils.print_log(
                "Training over......Hope I served your purpose...My Lord",
                "ERROR", log_file)
            utils.print_log("Training end time:", "INFO", log_file)
            utils.print_log(strftime("%Y_%m_%d_%H:%M:%S", localtime()), "INFO",
                            log_file)