Example #1
def test(filenameSave, model, dataloader_test, args):
    for step, (images, labels, filename, filenameGt) in enumerate(dataloader_test):
        if (args.cuda):
            images = images.cuda()
            # labels = labels.cuda()

        inputs = Variable(images)
        # targets = Variable(labels)
        with torch.no_grad():
            outputs = model(inputs)

        label = outputs[0].max(0)[1].byte().cpu().data
        # label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
        label_color = Colorize()(label.unsqueeze(0))


        # image_transform(label.byte()).save(filenameSave)

        # label_save = ToPILImage()(label_color)
        label_save = label_color.numpy()
        label_save = label_save.transpose(1, 2, 0)
        # label_save.save(filenameSave)
        # Assumes batch_size == 1: drop the batch dim, restore HWC layout, rescale to uint8
        images = images.cpu().numpy().squeeze(axis=0).transpose(1, 2, 0)
        images = (images * 255).astype(np.uint8)

        for i in range(len(filename)):
            fileSave = '../eval/' + args.savedir + "/" + filename[i].split("leftImg8bit/")[1]
            os.makedirs(os.path.dirname(fileSave), exist_ok=True)
            # Blend the input frame with the colorized prediction
            output = cv2.addWeighted(images, 0.4, label_save, 0.6, 0)
            cv2.imwrite(fileSave, output)
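
Note: every example on this page feeds a 1 x H x W tensor of class ids to Colorize() and gets back a 3 x H x W uint8 color image. The class itself comes from each project's transform module and is not shown here; a minimal sketch, with an illustrative palette rather than the real Cityscapes one:

import torch

PALETTE = [(128, 64, 128), (244, 35, 232), (70, 70, 70), (102, 102, 156)]  # illustrative

class Colorize:
    def __init__(self, n=len(PALETTE)):
        self.cmap = torch.tensor(PALETTE[:n], dtype=torch.uint8)

    def __call__(self, gray_image):
        # gray_image: 1 x H x W tensor of class ids
        _, h, w = gray_image.size()
        color = torch.zeros(3, h, w, dtype=torch.uint8)
        for label in range(self.cmap.size(0)):
            mask = gray_image[0] == label
            for c in range(3):
                color[c][mask] = self.cmap[label][c]
        return color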
Example #2
def inference(model, args):
    image_folder = "/media/pandongwei/ExtremeSSD/work_relative/extract_img/2020.10.16_1/"
    video_save_path = "/home/pandongwei/work_repository/erfnet_pytorch/eval/"

    # parameters about saving video
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(video_save_path + 'output_new.avi', fourcc, 10.0, (640, 480))

    cuda = True
    model.eval()

    paths = []
    for root, dirs, files in os.walk(image_folder, topdown=True):
        for file in files:
            # join with root (not image_folder) so files in subfolders resolve
            image_path = os.path.join(root, file)
            paths.append(image_path)
    paths.sort()
    font = cv2.FONT_HERSHEY_SIMPLEX

    angle_pre = 0
    for i, path in enumerate(paths):
        start_time = time.time()
        image = cv2.imread(path)
        image = (image / 255.).astype(np.float32)

        image = ToTensor()(image).unsqueeze(0)
        if (cuda):
            image = image.cuda()

        input = Variable(image)

        with torch.no_grad():
            output = model(input)

        label = output[0].max(0)[1].byte().cpu().data
        # label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
        label_color = Colorize()(label.unsqueeze(0))

        label_save = label_color.numpy()
        label_save = label_save.transpose(1, 2, 0)
        # add path planning on top of the segmentation output
        label_save, angle = perception_to_angle(label_save, angle_pre)
        # TODO: add a filter to prevent sudden jumps in the angle
        if abs(angle - angle_pre) > 10:
            angle = angle_pre

        angle_pre = angle
        # label_save.save(filenameSave)
        image = image.cpu().numpy().squeeze(axis=0).transpose(1, 2, 0)
        image = (image * 255).astype(np.uint8)
        output = cv2.addWeighted(image, 0.5, label_save, 0.5, 0)
        cv2.putText(output, str(round(angle, 3)), (50, 50), font, 2, (0, 0, 0), 2)
        #output = np.hstack([label_save, image])
        out.write(output)

        print(i, "  time: %.2f s" % (time.time() - start_time))
    out.release()
Example #3
def inference(model, image, angle_pre):
    # pre-process
    image = (image / 255.).astype(np.float32)
    image = ToTensor()(image).unsqueeze(0)
    image = image.cuda()
    input = Variable(image)
    # inference
    with torch.no_grad():
        output = model(input)
    # post-process
    label = output[0].max(0)[1].byte().cpu().data
    label_color = Colorize()(label.unsqueeze(0))
    label_save = label_color.numpy()
    label_save = label_save.transpose(1, 2, 0)
    # add path planning on top of the segmentation output
    label_save, angle = perception_to_angle(label_save, angle_pre)

    return angle
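
A hypothetical driver for the single-frame inference() variant above, assuming cv2 is imported and the model is already loaded on the GPU:

cap = cv2.VideoCapture("input.mp4")  # illustrative source
angle = 0.0
while True:
    ret, frame = cap.read()
    if not ret:
        break
    # feed the previous angle back so perception_to_angle can smooth it
    angle = inference(model, frame, angle)
    print("steering angle: %.3f" % angle)
cap.release()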
Example #4
    def callback(self, oimg):

        try:
            #if you want to save images and labels, please uncomment the following lines (No.1 to No.4).
            #NO.1 
            write_image_name = "image_" + str(self.count) + ".jpg"
            
            #No.2 
            write_label_name = "label_" + str(self.count) + ".jpg"

            oimg_b = bytes(oimg.data)
            np_arr = np.frombuffer(oimg_b, np.uint8)  # np.fromstring is deprecated
            img = cv2.imdecode(np_arr, cv2.IMREAD_COLOR)

            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

            #No.3 
            #cv2.imwrite("/home/amsl/images/output/seg_pub_2/image/" + write_image_name, img)
            
            image = PIL_Image.fromarray(img)
            #image = image.crop((0, 120, 640, 480))
            image = image.crop((0, 10, 640, 330))
            #image = image.resize((1024,512),PIL_Image.NEAREST)
            image = image.resize((1024,512),PIL_Image.NEAREST)

            #img_size = image.shape

            image = ToTensor()(image)
            image = image.unsqueeze(0)  # add the batch dimension

            image = image.cuda()
            
            input_image = Variable(image)
            
            with torch.no_grad():
                output_image = self.model(input_image)
            
            label = output_image[0].max(0)[1].byte().cpu().data
            label_color = Colorize()(label.unsqueeze(0))
            label_pub = ToPILImage()(label_color)
            #label_pub = label_pub.resize((1024, 512),PIL_Image.NEAREST)
            #label_pub = label_pub.resize((1024, 512),PIL_Image.LANCZOS)
            label_pub = np.asarray(label_pub)
            
            #show label.
            #plt.imshow(label_pub)
            #plt.pause(0.001)
            
            #No.4 
            #cv2.imwrite("/home/amsl/images/output/seg_pub_2/label/" + write_label_name, label_pub)

            self.pub_seg.publish(self.bridge.cv2_to_imgmsg(label_pub, "bgr8"))  # note: label_pub is RGB-ordered at this point
            print("published") 
            self.count += 1
        
        except CvBridgeError as e:
            print(e)
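
The callback above decodes a sensor_msgs CompressedImage and publishes the colorized labels; a sketch of the node wiring around it, where SegPublisher (the class holding the callback) and the topic names are assumptions:

import rospy
from cv_bridge import CvBridge
from sensor_msgs.msg import CompressedImage, Image

rospy.init_node('seg_node')
node = SegPublisher(model)  # hypothetical class that owns the callback above
node.bridge = CvBridge()
node.pub_seg = rospy.Publisher('/seg/label', Image, queue_size=1)
rospy.Subscriber('/camera/image/compressed', CompressedImage, node.callback)
rospy.spin()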
Example #5
def vis_using_Colorize(indir_list, outdir):
    indir = indir_list[0]
    # outdir = os.path.join(os.path.split(indir)[0], "vis_labels")
    mkdir_if_not_exist(outdir)

    for one_file in tqdm(os.listdir(indir)):
        fullpath = os.path.join(indir, one_file)
        hard_to_see_img = m.imread(fullpath)
        # outputs = outputs[0, :19].data.max(0)[1]
        # outputs = outputs.view(1, outputs.size()[0], outputs.size()[1])
        outputs = hard_to_see_img  # TODO this should be fixed
        output = Colorize()(outputs)
        output = np.transpose(output.cpu().numpy(), (1, 2, 0))
        img = Image.fromarray(output, "RGB")
        # PIL resize takes (width, height), so swap the numpy (H, W) shape
        img = img.resize((hard_to_see_img.shape[1], hard_to_see_img.shape[0]), Image.NEAREST)

        outfn = os.path.join(outdir, one_file)
        img.save(outfn)
Example #6
def ensemble_predict(prob_fns, outfile='sample.png',
                     out_npy_file='sample.npy', out_vis_file='sample_vis.png',
                     method='averaging', out_shape=(2048, 1024)):
    """Output predict file from two npy files by the given method."""
    probs = [np.load(prob_fn) for prob_fn in prob_fns]

    if method == 'averaging':
        prob = sum(probs) / len(probs)
    elif method == 'nms':
        prob = np.max(probs, 0)

    # -- output npy --
    np.save(out_npy_file, prob)

    # -- output label-predict --
    pred = prob[:N_CLASS].argmax(0)
    img = Image.fromarray(np.uint8(pred))
    img = img.resize(out_shape, Image.NEAREST)
    img.save(outfile)

    # -- output vis-predict --

    # ToTensor function: ndarray -> Tensor
    #   * H, W, C -> C, H, W
    #   * uint8 0..255 -> float 0..1 (float input is passed through unscaled)
    # prob.shape == (20, 512, 1024)
    prob_for_tensor = np.transpose(prob[:N_CLASS], (1, 2, 0))
    prob_tensor = ToTensor()(prob_for_tensor)

    pred = prob_tensor.max(0)[1]
    pred = pred.view(1, pred.size()[0], pred.size()[1])

    # Colorize function: Tensor -> Tensor
    vis_tensor = Colorize()(pred)

    # Tensor -> ndarray -> image(save)
    vis = np.transpose(vis_tensor.numpy(), (1, 2, 0))
    vis_img = Image.fromarray(vis, 'RGB')
    vis_img = vis_img.resize(out_shape, Image.NEAREST)
    vis_img.save(out_vis_file)
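
An illustrative call that averages two saved probability volumes (the .npy filenames are hypothetical):

ensemble_predict(['probs_a.npy', 'probs_b.npy'],
                 outfile='fused.png',
                 out_npy_file='fused.npy',
                 out_vis_file='fused_vis.png',
                 method='averaging')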
Example #7
def talker(args):
    NUM_CLASSES = 4
    color_transform = Colorize(NUM_CLASSES)
    # Load Model
    savedir = f'../save/{args.savedir}'
    if not os.path.exists(savedir):
        os.makedirs(savedir)
    model = ERFNet(NUM_CLASSES)
    if args.cuda:
        model = torch.nn.DataParallel(model).cuda()
    model_dir = args.model_dir

    def load_my_state_dict(model, state_dict):
        # state_dict = state_dict["state_dict"]
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(model_dir))
    model.eval()
    # parameters about saving video
    video_save_path = args.video_save_path
    fourcc = cv2.VideoWriter_fourcc(*'XVID')
    out = cv2.VideoWriter(video_save_path + 'output_new.avi', fourcc, 10.0,
                          (640, 480))

    ic = image_converter(model, out)
    rospy.init_node('talker', anonymous=True)
    try:
        rospy.spin()
    except KeyboardInterrupt:
        print("Shutting down")
    out.release()
Example #8
def main(args):

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(model, state_dict):
        own_state = model.state_dict()

        for a in own_state.keys():
            print(a)
        for a in state_dict.keys():
            print(a)
        print('-----------')

        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)

        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir,
                                   input_transform_cityscapes,
                                   subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    time_all = []
    with torch.no_grad():
        for step, (images, filename) in enumerate(loader):

            images = images.cuda()
            start_time = time.time()
            outputs = model(images)
            fwt = time.time() - start_time
            time_all.append(fwt)

            label = outputs[0].cpu().max(0)[1].data.byte()
            label_color = Colorize()(label.unsqueeze(0))

            filenameSave = "./save_color/" + filename[0].split(
                "leftImg8bit/")[1]
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            label_save = ToPILImage()(label_color)

            label_save.save(filenameSave)

            print(step, filenameSave)
            print("FPS (Mean: %.4f)" % (1.0000 /
                                        (sum(time_all) / len(time_all))))
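
input_transform_cityscapes is used here and in several later examples but never defined on this page; in ERFNet-style eval scripts it is a torchvision Compose along these lines (a sketch; the resolution is an assumption):

from torchvision.transforms import Compose, Resize, ToTensor
from PIL import Image

input_transform_cityscapes = Compose([
    Resize((512, 1024), Image.BILINEAR),  # network input resolution (assumed)
    ToTensor(),                           # HWC uint8 -> CHW float in [0, 1]
])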
Example #9
from torchvision.transforms import Compose, CenterCrop, Normalize, Resize, Pad
from torchvision.transforms import ToTensor, ToPILImage

from dataset import VOC12, cityscapes
from transform import Relabel, ToLabel, Colorize
from visualize import Dashboard

import importlib
from iouEval import iouEval, getColorEntry

from shutil import copyfile

NUM_CHANNELS = 3
NUM_CLASSES = 34 #pascal=22, cityscapes=20

color_transform = Colorize(NUM_CLASSES)
image_transform = ToPILImage()

#Augmentations - different function implemented to perform random augments on both image and target
class MyCoTransform(object):
    def __init__(self, enc, augment=True, height=512):
        self.enc=enc
        self.augment = augment
        self.height = height
        pass
    def __call__(self, input, target):
        # do something to both images
        input =  Resize(self.height, Image.BILINEAR)(input)
        target = Resize(self.height, Image.NEAREST)(target)

        if(self.augment):

Example #10
from dataset import CityScapes, CityScapes_validation
from network import LinkNet34
from criterion import CrossEntropyLoss2d
from transform import Relabel, ToLabel, Colorize
#import deeplab_resnet
import torch.nn.functional as F
#from accuracy_metrics import pixel_accuracy,mean_accuracy,mean_IU
from accuracy_metrics import pixel_accuracy

NUM_CHANNELS = 3
NUM_CLASSES = 35  #6 for brats


color_transform = Colorize()
image_transform = ToPILImage()
input_transform = Compose([
	#CenterCrop(256),
	#Scale(240),
	Resize((512,1024),Image.NEAREST),
	ToTensor(),
	Normalize([73.158359/255.0, 82.908917/255.0, 72.392398/255.0], [11.847663/255.0, 10.710858/255.0, 10.358618/255.0]),
])

input_transform1 = Compose([
	#CenterCrop(256),
	ToTensor(),
])

target_transform = Compose([
Example #11
optimizer_feat = torch.optim.Adam(res101.parameters(), lr=1e-4)

for t in range(10):
    for i, (img, label) in enumerate(loader):
        img = img.cuda()
        label = label[0].cuda()
        label = Variable(label)
        input = Variable(img)

        feats = res101(input)
        output = seg(feats)

        seg.zero_grad()
        res101.zero_grad()
        loss = criterion(output, label)
        loss.backward()
        optimizer_feat.step()
        optimizer_seg.step()

        ## see
        input = make_image_grid(img, mean, std)
        label = make_label_grid(label.data)
        label = Colorize()(label).type(torch.FloatTensor)
        output = make_label_grid(torch.max(output, dim=1)[1].data)
        output = Colorize()(output).type(torch.FloatTensor)
        writer.add_image('image', input, i)
        writer.add_image('label', label, i)
        writer.add_image('pred', output, i)
        writer.add_scalar('loss', loss.item(), i)

        print "epoch %d step %d, loss=%.4f" % (t, i, loss.data.cpu()[0])
Example #12
def main(args):

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    #Import ERFNet model from the folder
    #Net = importlib.import_module(modelpath.replace("/", "."), "ERFNet")
    model = ERFNet(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    #model.load_state_dict(torch.load(args.state))
    #model.load_state_dict(torch.load(weightspath)) #not working if missing key

    fourcc = cv2.VideoWriter_fourcc(*'MP4V')  # Save as video
    out = cv2.VideoWriter('output.mp4', fourcc, 20.0, (640, 352))

    def load_my_state_dict(
            model, state_dict
    ):  #custom function to load model when not all dict elements
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(
        model, torch.load(weightspath, map_location=torch.device('cpu')))
    print("Model and weights LOADED successfully")

    model.eval()

    # loader = DataLoader(cityscapes(args.datadir, input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
    #     num_workers=args.num_workers, batch_size=args.batch_size, shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    # if (args.visualize):
    #     vis = visdom.Visdom()

    # for step, (images, labels, filename, filenameGt) in enumerate(loader):
    # cap = cv2.VideoCapture(0)
    cap = cv2.VideoCapture('project_video_trimmed.mp4')

    while True:
        # Capture frame-by-frame; stop when the stream ends
        ret, images = cap.read()
        if not ret:
            break

        images = trans(images)
        images = images.float()
        images = images.view((1, 3, 352, 640))  # video frame -> NCHW

        # Our operations on the frame come here
        if (not args.cpu):
            images = images.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        label = outputs[0].max(0)[1].byte().cpu().data
        #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
        label_color = Colorize()(label.unsqueeze(0))
        frame = label_color.numpy().transpose(1, 2, 0)

        # label_save = ToPILImage()(label_color)
        # label_save.save("result_1.png")

        # Display the resulting frame
        cv2.imshow('Segmented Image', frame)
        out.write(frame)  # To save video file

        if cv2.waitKey(1) & 0xFF == ord('q'):
            break

    cap.release()
    out.release()
    cv2.destroyAllWindows()

Example #13
def main(args, get_dataset):

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    NUM_CLASSES = get_dataset.num_labels
    model = Net(NUM_CLASSES, args.em_dim, args.resnet)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(
            model, state_dict
    ):  #custom function to load model when not all dict elements
        own_state = model.state_dict()

        for a in own_state.keys():
            print(a)
        for a in state_dict.keys():
            print(a)
        print('-----------')

        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)

        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir,
                                   input_transform_cityscapes,
                                   subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    with torch.set_grad_enabled(False):
        for step, (images, filename) in enumerate(loader):

            images = images.cuda()

            outputs = model(images, enc=False)
            outputs = outputs['MAP']

            label = outputs[0].cpu().max(0)[1].data.byte()
            label_color = Colorize()(label.unsqueeze(0))
            filenameSave = "./save_color/" + filename[0].split(
                "leftImg8bit/")[1]
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            label_save = ToPILImage()(label_color)
            label_save.save(filenameSave)

            # second sample in the batch (assumes batch_size >= 2)
            label = outputs[1].cpu().max(0)[1].data.byte()
            label_color = Colorize()(label.unsqueeze(0))
            filenameSave = "./save_color/" + filename[1].split(
                "leftImg8bit/")[1]
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            label_save = ToPILImage()(label_color)
            label_save.save(filenameSave)

            print(step, filenameSave)
Example #14
def main(args):
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    # model.load_state_dict(torch.load(args.state))
    # model.load_state_dict(torch.load(weightspath)) #not working if missing key

    def load_my_state_dict(
        model, state_dict
    ):  # custom function to load model when not all dict elements
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    # loader = DataLoader(
    #     cityscapes('/home/liqi/PycharmProjects/LEDNet/4.png', input_transform_cityscapes, target_transform_cityscapes, subset=args.subset),
    #     num_workers=args.num_workers, batch_size=1 ,shuffle=False)
    input_transform_cityscapes = Compose([
        Resize((512, 1024), Image.BILINEAR),
        ToTensor(),
        # Normalize([.485, .456, .406], [.229, .224, .225]),
    ])
    name = "4.png"
    with open(image_path_city('/home/gongyiqun/images', name), 'rb') as f:
        images = load_image(f).convert('RGB')

        images = input_transform_cityscapes(images)
    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    if (not args.cpu):
        images = images.cuda()
        # labels = labels.cuda()
    a = torch.unsqueeze(images, 0)
    inputs = Variable(a)
    # targets = Variable(labels)
    with torch.no_grad():
        outputs = model(inputs)

    label = outputs[0].max(0)[1].byte().cpu().data
    # label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
    label_color = Colorize()(label.unsqueeze(0))

    filenameSave = "./save_color/" + "Others/" + name
    os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
    # image_transform(label.byte()).save(filenameSave)

    label_save = ToPILImage()(label_color)
    label_save = label_save.resize((1241, 376), Image.BILINEAR)
    # label_save = cv2.resize(label_save, (376, 1224),interpolation=cv2.INTER_AREA)
    label_save.save(filenameSave)

    if (args.visualize):
        vis.image(label_color.numpy())
Example #15
def main(args, get_dataset):

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    NUM_CLASSES = get_dataset.num_labels
    model = Net(NUM_CLASSES, args.em_dim, args.resnet)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(
            model, state_dict
    ):  #custom function to load model when not all dict elements
        own_state = model.state_dict()

        for a in own_state.keys():
            print(a)
        for a in state_dict.keys():
            print(a)
        print('-----------')

        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)

        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir,
                                   input_transform_cityscapes,
                                   subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    log_dir = "./save_eval_iou/"
    os.makedirs(log_dir, exist_ok=True)
    logger = create_logger(log_dir)
    fused_iou_eval = FusedIoU(NUM_CLASSES['MAP'], WILDPASS_IN_MAP_ID,
                              REMAP_MAP_IDD, WILDPASS_REMAP_ID, logger)

    with torch.set_grad_enabled(False):
        step = 0
        for step, (images, filename) in enumerate(loader):

            images = images.cuda()

            outputs = model(images, enc=False)

            pred_MAP = outputs['MAP'].data.cpu().numpy()
            pred_IDD = outputs['IDD20K'].data.cpu().numpy()

            # --- evaluate
            label = np.asarray(
                Image.open(filename[0].replace('/leftImg8bit/', '/gtFine/')))
            label = np.expand_dims(label, axis=0)

            if args.is_fuse:
                outputs = fused_iou_eval.fuse(label, pred_MAP, pred_IDD,
                                              filename)
            else:
                fused_iou_eval.add_batch(np.argmax(pred_MAP, axis=1),
                                         label)  # only MAP
                #outputs = outputs['MAP']

            outputs = torch.from_numpy(outputs[None, ...])
            label_color = Colorize()(outputs)

            filenameSave = "./save_color/" + filename[0].split(
                "leftImg8bit/")[1]
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            label_save = ToPILImage()(label_color)
            label_save.save(filenameSave)

            print(step, filenameSave)

            print('processed step {}, file {}.'.format(
                step, filename[0].split("leftImg8bit/val/cs/")[1]))

        logger.info('========== evaluator using onehot ==========')
        miou, iou, macc, acc = fused_iou_eval.get_iou()
        logger.info('-----------Acc of each classes-----------')
        for i, c_acc in enumerate(acc):
            name = WILDPASS_IN_MAP_NAME[i] if len(acc) == len(
                WILDPASS_IN_MAP_NAME) else str(i)
            logger.info('ID= {:>10s}: {:.2f}'.format(name, c_acc * 100.0))
        logger.info("Acc of {} images :{:.2f}".format(str(step + 1),
                                                      macc * 100.0))
        logger.info('-----------IoU of each classes-----------')
        for i, c_iou in enumerate(iou):
            name = WILDPASS_IN_MAP_NAME[i] if len(iou) == len(
                WILDPASS_IN_MAP_NAME) else str(i)
            logger.info('ID= {:>10s}: {:.2f}'.format(name, c_iou * 100.0))
        logger.info("mIoU of {} images :{:.2f}".format(str(step + 1),
                                                       miou * 100.0))
Example #16
def train(args, rmodel, model, enc=False):
    best_acc = 0
    weight = classWeights(NUM_CLASSES)
    assert os.path.exists(
        args.datadir), "Error: datadir (dataset directory) could not be loaded"

    co_transform = MyCoTransform(augment=True, height=args.height)
    co_transform_val = MyCoTransform(augment=False, height=args.height)
    dataset_train = cityscapes(args.datadir, co_transform, 'train')
    dataset_val = cityscapes(args.datadir, co_transform_val, 'val')

    loader = DataLoader(dataset_train,
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=True)
    loader_val = DataLoader(dataset_val,
                            num_workers=args.num_workers,
                            batch_size=args.batch_size,
                            shuffle=False)

    if args.cuda:
        weight = weight.cuda()
    rcriterion = torch.nn.L1Loss()

    savedir = '/home/shyam.nandan/NewExp/F_erfnet_pytorch_ours_w_gt_v2_multiply/save/' + args.savedir  #change path

    if (enc):
        automated_log_path = savedir + "/automated_log_encoder.txt"
        modeltxtpath = savedir + "/model_encoder.txt"
    else:
        automated_log_path = savedir + "/automated_log.txt"
        modeltxtpath = savedir + "/model.txt"

    if (not os.path.exists(automated_log_path)):
        with open(automated_log_path, "a") as myfile:
            myfile.write(
                "Epoch\t\tTrain-loss\t\tTest-loss\t\tTrain-IoU\t\tTest-IoU\t\tlearningRate"
            )

    with open(modeltxtpath, "w") as myfile:
        myfile.write(str(model))

    optimizer = Adam(model.parameters(),
                     5e-4, (0.9, 0.999),
                     eps=1e-08,
                     weight_decay=2e-4)  ##
    roptimizer = Adam(rmodel.parameters(), 2e-4,
                      (0.9, 0.999))  ## restoration scheduler

    start_epoch = 1
    scheduler = lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.99)
    rscheduler = lr_scheduler.StepLR(roptimizer, step_size=30,
                                     gamma=0.5)  ## Restoration schedular

    for epoch in range(start_epoch, args.num_epochs + 1):
        print("----- TRAINING - EPOCH", epoch, "-----")

        scheduler.step()  ## scheduler 2
        rscheduler.step()

        epoch_loss = []
        time_train = []

        doIouTrain = args.iouTrain
        doIouVal = args.iouVal

        if (doIouTrain):
            iouEvalTrain = iouEval(NUM_CLASSES)

        usedLr = 0
        rusedLr = 0
        for param_group in optimizer.param_groups:
            print("Segmentation LEARNING RATE: ", param_group['lr'])
            usedLr = float(param_group['lr'])
        for param_group in roptimizer.param_groups:
            print("Restoration LEARNING RATE: ", param_group['lr'])
            rusedLr = float(param_group['lr'])

        model.eval()
        epoch_loss_val = []
        time_val = []

        if (doIouVal):
            iouEvalVal = iouEval(NUM_CLASSES)

        for step, (timages, images, labels, filename) in enumerate(loader_val):
            start_time = time.time()
            if args.cuda:
                images = images.cuda()
                labels = labels.cuda()
                timages = timages.cuda()

            inputs = Variable(
                timages, volatile=True
            )  # volatile (pre-0.4 PyTorch) skips building the graph during eval
            itargets = Variable(images, volatile=True)
            targets = Variable(labels, volatile=True)

            ss_inputs = rmodel(inputs, flag=0, r_fb1=0, r_fb2=0)

            outs = model(ss_inputs, only_encode=enc)

            tminus_outs = outs.detach()
            tplus_outs = outs.detach()

            for num_feedback in range(3):

                optimizer.zero_grad()
                roptimizer.zero_grad()

                ss_inputs = rmodel(inputs,
                                   flag=1,
                                   r_fb1=(tplus_outs - tminus_outs),
                                   r_fb2=ss_inputs.detach())

                loss = rcriterion(ss_inputs, itargets)

                outs = model(ss_inputs.detach(), only_encode=enc)

                tminus_outs = tplus_outs
                tplus_outs = outs.detach()

            outputs = outs
            del outs, tminus_outs, tplus_outs
            gc.collect()
            Gamma = [0, 0, 0]
            Alpha = [1, 1, 1]
            loss = CB_iFl(outputs,
                          targets[:, 0],
                          weight,
                          gamma=Gamma[0],
                          alpha=Alpha[0])
            epoch_loss_val.append(loss.data[0])
            time_val.append(time.time() - start_time)

            if (doIouVal):
                #start_time_iou = time.time()
                iouEvalVal_img = iouEval(NUM_CLASSES)
                iouEvalVal_img.addBatch(
                    outputs.max(1)[1].unsqueeze(1).data, targets.data)

                iouEvalVal.addBatch(
                    outputs.max(1)[1].unsqueeze(1).data, targets.data)

                #print ("Time to add confusion matrix: ", time.time() - start_time_iou)
                label_color = Colorize()(
                    outputs[0].max(0)[1].byte().cpu().data.unsqueeze(0))
                label_save = ToPILImage()(label_color)

                filenameSave = '../save_color_restored_joint_afl_CBFL/' + filename[
                    0].split('/')[-2]

                im_iou, _ = iouEvalVal_img.getIoU()

                if not os.path.exists(filenameSave):
                    os.makedirs(filenameSave)
            #Uncomment to save output
            #label_save.save(filenameSave+ '/' + str(" %6.4f " %im_iou[0].data.numpy()) + '_' + filename[0].split('/')[-1])

            if args.steps_loss > 0 and step % args.steps_loss == 0:
                average = sum(epoch_loss_val) / len(epoch_loss_val)
                print('Val loss:  ', average, 'Epoch:  ', epoch, 'Step:  ',
                      step)

        average_epoch_loss_val = sum(epoch_loss_val) / len(epoch_loss_val)

        iouVal = 0
        if (doIouVal):
            iouVal, iou_classes = iouEvalVal.getIoU()
            iouStr = getColorEntry(iouVal) + '{:0.2f}'.format(
                iouVal * 100) + '\033[0m'
            print(iouVal, iou_classes, iouStr)

    return (model)
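
Variable(..., volatile=True) is pre-0.4 PyTorch; on current versions the equivalent of the validation pass above is a torch.no_grad() block (a sketch reusing the names from the function above):

model.eval()
with torch.no_grad():
    for step, (timages, images, labels, filename) in enumerate(loader_val):
        if args.cuda:
            timages, images, labels = timages.cuda(), images.cuda(), labels.cuda()
        ss_inputs = rmodel(timages, flag=0, r_fb1=0, r_fb2=0)
        outputs = model(ss_inputs, only_encode=enc)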
Example #17
def main(args):
    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")
    else:
        print("Loading model: " + modelpath)
        print("Loading weights: " + weightspath)

    # Import ERFNET
    model = ERFNet(NUM_CLASSES)
    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    def load_my_state_dict(
            model, state_dict
    ):  #custom function to load model when not all dict elements
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    # Set model to Evaluation mode
    model.eval()

    # Setup the dataset loader
    ### RELLIS-3D Dataloader
    enc = False
    loader_test = custom_datasets.setup_loaders(args, enc)
    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    time_train = []  # per-image forward-time log used below
    for step, (images, labels, img_name, _) in enumerate(loader_test):
        start_time = time.time()
        if (not args.cpu):
            images = images.cuda()
            #labels = labels.cuda()

        inputs = Variable(images)
        with torch.no_grad():
            outputs = model(inputs)

        label = outputs[0].max(0)[1].byte().cpu().data
        label_color = Colorize()(label.unsqueeze(0))

        eval_save_path = "./save_colour_rellis/"
        if not os.path.exists(eval_save_path):
            os.makedirs(eval_save_path)

        _, file_name = os.path.split(img_name[0])
        file_name = file_name + ".png"

        #image_transform(label.byte()).save(filenameSave)
        label_save = ToPILImage()(label_color)
        label_save.save(os.path.join(eval_save_path, file_name))

        if (args.visualize):
            vis.image(label_color.numpy())
        if step != 0:  #first run always takes some time for setup
            fwt = time.time() - start_time
            time_train.append(fwt)
            print("Forward time per img (b=%d): %.3f (Mean: %.3f)" %
                  (args.batch_size, fwt / args.batch_size,
                   sum(time_train) / len(time_train) / args.batch_size))

        print(step, os.path.join(eval_save_path, file_name))
Example #18
def main(args):

    modelpath = args.loadDir + args.loadModel
    #weightspath = args.loadDir + args.loadWeights #TODO
    weightspath = "../save/feriburgForest_3/model_best.pth"
    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    #Import ERFNet model from the folder
    #Net = importlib.import_module(modelpath.replace("/", "."), "ERFNet")
    model = ERFNet(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    #model.load_state_dict(torch.load(args.state))
    #model.load_state_dict(torch.load(weightspath)) #not working if missing key

    def load_my_state_dict(
            model, state_dict
    ):  #custom function to load model when not all dict elements
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(freiburgForest(args.datadir,
                                       input_transform_cityscapes,
                                       target_transform_cityscapes,
                                       subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()
            #labels = labels.cuda()

        inputs = Variable(images)
        #targets = Variable(labels)
        with torch.no_grad():
            outputs = model(inputs)
        print(outputs.shape)
        label = outputs[0].max(0)[1].byte().cpu().data
        #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
        label_color = Colorize()(label.unsqueeze(0))

        filenameSave = "./freiburgforest_1/" + filename[0].split(
            "freiburg_forest_annotated/")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        #image_transform(label.byte()).save(filenameSave)
        print(label_color.shape)
        # label_save = ToPILImage()(label_color)
        label_save = label_color.numpy()
        label_save = label_save.transpose(1, 2, 0)
        print(label_save.shape)
        # label_save.save(filenameSave)
        images = images.cpu().numpy().squeeze(axis=0)
        images = images.transpose(1, 2, 0)

        # print(images.shape)
        # print(label_save.shape)
        plt.figure(figsize=(10.24, 5.12), dpi=100)
        plt.imshow(images)
        plt.imshow(label_save, alpha=0.6)
        plt.axis('off')
        # plt.show()
        plt.savefig(filenameSave, dpi=100)
        plt.close()

        if (args.visualize):
            vis.image(label_color.numpy())
        print(step, filenameSave)
Example #19
NUM_CHANNELS = 3
# NUM_CLASSES = 20 #pascal=22, cityscapes=20
NUM_HISTOGRAMS = 5
NUM_IMG_PER_EPOCH = 5
# Optimizer params.
LEARNING_RATE=5e-4
BETAS=(0.9, 0.999)
OPT_EPS=1e-08
WEIGHT_DECAY=1e-6

DISCOUNT_RATE_START=0.1
DISCOUNT_RATE=0.01
MAX_CONSISTENCY_EPOCH=30
DISCOUNT_RATE_START_EPOCH=50

color_transform_target = Colorize(1.0, 2.0, remove_negative=True, extend=True, white_val=1.0)  # min_val, max_val, remove negative
color_transform_output = Colorize(1.0, 2.0, remove_negative=False, extend=True, white_val=1.0)  # Automatic color based on tensor min/max val
# color_transform_output = ColorizeMinMax()  # Automatic color based on tensor min/max val
image_transform = ToPILImage()



#Augmentations - different function implemented to perform random augments on both image and target
class MyCoTransform(object):
    def __init__(self, enc, augment=True, height=512):
        self.enc=enc
        self.augment = augment
        self.height = height

        self.rotation_angle = 5.0
        self.affine_angle = 5.0
Example #20
def main(args):

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    resnet = resnet18(pretrained=True, efficient=False, use_bn=True)
    model = Net(resnet, size=(512, 1024 * 4), num_classes=NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    #model.load_state_dict(torch.load(args.state))
    #model.load_state_dict(torch.load(weightspath)) #not working if missing key

    def load_my_state_dict(
            model, state_dict
    ):  #custom function to load model when not all dict elements
        own_state = model.state_dict()

        for a, b in zip(own_state.keys(), state_dict.keys()):
            print(a, '      ', b)
        print('-----------')

        for name, param in state_dict.items():
            # print('#####', name)
            name = name[7:]  # strip the "module." prefix added by DataParallel
            if name not in own_state:
                print('{} not in own_state'.format(name))
                continue
            #if name not in except_list:
            own_state[name].copy_(param)

        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir,
                                   input_transform_cityscapes,
                                   subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    time_all = []
    with torch.no_grad():
        for step, (images, filename) in enumerate(loader):

            images = images.cuda()
            start_time = time.time()
            outputs = model(images)
            fwt = time.time() - start_time
            time_all.append(fwt)

            label = outputs[0].cpu().max(0)[1].data.byte()
            label_color = Colorize()(label.unsqueeze(0))

            filenameSave = "./save_color/" + filename[0].split(
                "leftImg8bit/")[1]
            os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
            label_save = ToPILImage()(label_color)

            label_save.save(filenameSave)

            print(step, filenameSave)
            print("Forward time per img (Mean: %.4f)" %
                  (1 / (sum(time_all) / len(time_all))))
Example #21
model.load_state_dict(torch.load("./pth/fcn-deconv-40.pth"))
model.eval()


# 10 13 48 86 101
img = Image.open("./data/VOC2012test/JPEGImages/2008_000101.jpg").convert("RGB")
original_size = img.size
img.save("original.png")
img = img.resize((256, 256), Image.BILINEAR)
img = ToTensor()(img)
img = Variable(img).unsqueeze(0)
outputs = model(img)
# 22 256 256
for i, output in enumerate(outputs):
    output = output[0].data.max(0)[1]
    output = Colorize()(output)
    output = np.transpose(output.numpy(), (1, 2, 0))
    img = Image.fromarray(output, "RGB")
    if i == 0:
        img = img.resize(original_size, Image.NEAREST)
    img.save("test-%d.png" % i)

'''

for index, (imgs, name, size) in tqdm(enumerate(testloader)):
    imgs = Variable(imgs.cuda())
    outputs = model(imgs)

    output = outputs[0][0].data.max(0)[1]
    output = Colorize()(output)
    print(output)
Example #22
with torch.no_grad():

    for sample in tqdm(dataset_it):

        im = sample['image']
        instances = sample['instance'].squeeze()
        
        output = model(im)
        instance_map, predictions = cluster.cluster(output[0], threshold=0.9)

        visualizer.display(im, 'image')
        visualizer.display(instance_map.cpu(), 'pred')
        label1 = instance_map.cpu()
        label_color = Colorize()(label1.unsqueeze(0))

        sigma = output[0][2].cpu()
        sigma = (sigma - sigma.min())/(sigma.max() - sigma.min())
        sigma[instances == 0] = 0
        visualizer.display(sigma, 'sigma')
        label2=sigma
        label2.squeeze_()

        seed = torch.sigmoid(output[0][3]).cpu()
        visualizer.display(seed, 'seed')
        label3 = seed.cpu()
        label3.squeeze_()


Example #23
    comp = model.get_computations(True)
    print(comp)
    print(sum(comp))

    if cuda:
        model = model.cuda()

    model.eval() # Set in evaluation mode

    Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


    print ('\nPerforming object detection:')
    bar = progressbar.ProgressBar(0, len(valloader), redirect_stdout=False)
    for batch_i, (imgs, targets) in enumerate(valloader):
        # Configure input
        input_imgs = imgs.type(Tensor)

        # Get detections
        with torch.no_grad():
            pred = model(input_imgs)
            _, predClass = torch.max(pred, 1)
            mask = Colorize(predClass.cpu().squeeze()).permute(1,2,0).numpy()

        cv2.imwrite('output/%d.png' % (batch_i),mask)

        # Log progress
        bar.update(batch_i)

    bar.finish()
Example #24
    imgs = Variable(imgs.cuda())
    start_time = time.time()
    outputs_4, outputs_2, outputs = model(imgs)
    end_time = time.time()
    if j > 9:
        avr_time += (end_time - start_time)
    if j == 109:
        print(avr_time / 100.0)
        exit()
    # 22 256 256
    filename = list(names)[0]
    for i, (output_4, output_2, output) in enumerate(zip(outputs_4, outputs_2, outputs)):
        output_4 = output_4.data.max(0)[1]

        output_4 = Colorize()(output_4)
        output_4 = np.transpose(output_4.numpy(), (1, 2, 0))
        img_4 = Image.fromarray(output_4, "RGB")
 
        output_2 = output_2.data.max(0)[1]

        output_2 = Colorize()(output_2)
        output_2 = np.transpose(output_2.numpy(), (1, 2, 0))
        img_2 = Image.fromarray(output_2, "RGB")

        output = output.data.max(0)[1]

        output = Colorize()(output)
        output = np.transpose(output.numpy(), (1, 2, 0))
        img = Image.fromarray(output, "RGB")
        if i == 0:
Example #25
                outputs[cnt] = lab[0]
                outputs[cnt + 1] = lab[1]
                cnt += 2

            beg = time.clock()
            pred = model(inputs)
            t += time.clock() - beg
            _, predClass = torch.max(pred, 1)

        bSize = inputs.data.size()[0]

        running_acc += torch.sum(predClass == outputs).item() * outSize * 100

        for j in range(bSize):
            img = Image.fromarray(
                Colorize(predClass.data[j]).permute(1, 2,
                                                    0).numpy().astype('uint8'))
            img.save(outDir + "%d.png" % (imgCnt + j))
        imgCnt += bSize

        maskPred = torch.zeros(numClass, bSize, int(labSize[0]),
                               int(labSize[1])).long()
        maskLabel = torch.zeros(numClass, bSize, int(labSize[0]),
                                int(labSize[1])).long()
        for currClass in range(numClass):
            maskPred[currClass] = predClass == currClass
            maskLabel[currClass] = outputs == currClass

        for labIdx in range(numClass):
            labCnts[labIdx] += torch.sum(maskLabel[labIdx]).item()
            for predIdx in range(numClass):
                inter = torch.sum(maskPred[predIdx] & maskLabel[labIdx]).item()
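
The fragment above tallies per-class intersection and label counts; the per-class IoU it builds toward is inter / (pred + label - inter). A sketch of that final step, with the tally names (inter, predCnts, labCnts) assumed:

# inter[p][l]: pixels predicted p with ground truth l; predCnts/labCnts: per-class totals
ious = []
for c in range(numClass):
    union = predCnts[c] + labCnts[c] - inter[c][c]
    ious.append(inter[c][c] / union if union > 0 else 0.0)
print("mIoU: %.4f" % (sum(ious) / len(ious)))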
Example #26
from torchvision.transforms import Compose, CenterCrop, Normalize, Resize, Pad
from torchvision.transforms import ToTensor, ToPILImage
import matplotlib.pyplot as plt
from dataset import geoMat
from transform import Colorize
from visualize import Dashboard
from tensorboardX import SummaryWriter

import importlib
from iouEval import iouEval, getColorEntry

from shutil import copyfile

mean_and_var = 2

color_transform = Colorize(mean_and_var)


#Augmentations - different function implemented to perform random augments on both image and target
class MyCoTransform(object):
    def __init__(self, augment=True, rescale=True, size=104):

        self.augment = augment
        self.size = (size, size)
        self.rescale = rescale

        pass

    def __call__(self, input, target):
        # do something to both images
        # input = Resize(self.size, Image.BILINEAR)(input)
Example #27
def main(args):

    modelpath = args.loadDir + args.loadModel
    weightspath = args.loadDir + args.loadWeights

    print("Loading model: " + modelpath)
    print("Loading weights: " + weightspath)

    model = Net(NUM_CLASSES)

    model = torch.nn.DataParallel(model)
    if (not args.cpu):
        model = model.cuda()

    #model.load_state_dict(torch.load(args.state))
    #model.load_state_dict(torch.load(weightspath)) #not working if missing key

    def load_my_state_dict(
            model, state_dict
    ):  #custom function to load model when not all dict elements
        own_state = model.state_dict()
        for name, param in state_dict.items():
            if name not in own_state:
                continue
            own_state[name].copy_(param)
        return model

    model = load_my_state_dict(model, torch.load(weightspath))
    print("Model and weights LOADED successfully")

    model.eval()

    if (not os.path.exists(args.datadir)):
        print("Error: datadir could not be loaded")

    loader = DataLoader(cityscapes(args.datadir,
                                   input_transform_cityscapes,
                                   target_transform_cityscapes,
                                   subset=args.subset),
                        num_workers=args.num_workers,
                        batch_size=args.batch_size,
                        shuffle=False)

    # For visualizer:
    # must launch in other window "python3.6 -m visdom.server -port 8097"
    # and access localhost:8097 to see it
    if (args.visualize):
        vis = visdom.Visdom()

    for step, (images, labels, filename, filenameGt) in enumerate(loader):
        if (not args.cpu):
            images = images.cuda()
            #labels = labels.cuda()

        inputs = Variable(images)
        #targets = Variable(labels)
        with torch.no_grad():
            outputs = model(inputs)

        label = outputs[0].max(0)[1].byte().cpu().data
        #label_cityscapes = cityscapes_trainIds2labelIds(label.unsqueeze(0))
        label_color = Colorize()(label.unsqueeze(0))

        filenameSave = "./save_color/" + filename[0].split("leftImg8bit/")[1]
        os.makedirs(os.path.dirname(filenameSave), exist_ok=True)
        #image_transform(label.byte()).save(filenameSave)
        label_save = ToPILImage()(label_color)
        label_save.save(filenameSave)

        if (args.visualize):
            vis.image(label_color.numpy())
        print(step, filenameSave)
Example #28
        loss.backward()
        optimizer_seg.step()
        optimizer_feat.step()
        print "epoch is:[{}|{}],index is:[{}|{}],loss:{}".\
                    format(epoch,epoch_num,i,len(dataloader),loss)

    win = visutils.visualize_loss(epoch, loss.cpu().detach(), env, win)

    if epoch % 40 == 0:
        #save model
        torch.save(vgg16.state_dict(),
                   '%s/vgg16_%03d.pkl' % (result_directory, epoch))
        torch.save(Seg.state_dict(),
                   '%s/Seg_%03d.pkl' % (result_directory, epoch))
        #save result

        input = make_image_grid(img, mean, std)
        label = make_label_grid(labels.data)
        label = Colorize()(label).type(torch.FloatTensor)
        output = make_label_grid(torch.max(logits, dim=1)[1].data)
        output = Colorize()(output).type(torch.FloatTensor)

        vutils.save_image(
            label.cpu().detach(),
            '%s/labels_epoch_%03d.png' % (result_directory, epoch))
        vutils.save_image(
            output.cpu().detach(),
            '%s/output_epoch_%03d.png' % (result_directory, epoch))
        vutils.save_image(input.cpu().detach(),
                          '%s/img_epoch_%03d.png' % (result_directory, epoch))