Пример #1
0
def getModel(sourcemodel):
    """Build the CIHP transfer network, load source-model weights, move it to GPU.

    Args:
        sourcemodel: anything accepted by ``torch.load`` (e.g. a checkpoint
            path); the loaded object is passed to ``load_source_model``.

    Returns:
        A ``deeplab_xception_transfer_projection_savemem`` instance built with
        ``n_classes=20``, ``hidden_layers=128``, ``source_classes=7``, after
        ``.cuda()`` has been called on it.
    """
    build = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem
    model = build(n_classes=20, hidden_layers=128, source_classes=7)
    checkpoint = torch.load(sourcemodel)
    model.load_source_model(checkpoint)
    model.cuda()
    return model
Пример #2
0
def inference_batch(sourcemodel, imgPaths):
    """Run segmentation on every image path with a single shared GPU model.

    The network is built and loaded once, then ``inferenceWrite`` is called
    per image with the image's directory and its file name minus the
    extension, so results are written next to the source image.

    Args:
        sourcemodel: anything accepted by ``torch.load``; passed to
            ``load_source_model``.
        imgPaths: sequence of image file paths.
    """
    # Launch the GPU model once and reuse it for every image.
    net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
        n_classes=20,
        hidden_layers=128,
        source_classes=7,
    )
    x = torch.load(sourcemodel)
    net.load_source_model(x)
    net.cuda()
    n = len(imgPaths)
    for i, imgpath in enumerate(imgPaths, start=1):
        print('Running Segmentation on {}: No. {}/{}'.format(imgpath, i, n))
        dirname, filename = os.path.split(imgpath)
        # splitext keeps extension-less names ("img", ".hidden") intact; the
        # previous '.'.join(name.split('.')[:-1]) reduced them to "".
        basename = os.path.splitext(filename)[0]
        inferenceWrite(net, imgpath, dirname, basename)
Пример #3
0

if __name__ == "__main__":
    # Command-line arguments.
    parser = argparse.ArgumentParser()
    parser.add_argument("--loadmodel", default="", type=str)
    parser.add_argument("--img_path", default="", type=str)
    parser.add_argument("--output_path", default="", type=str)
    parser.add_argument("--output_name", default="", type=str)
    parser.add_argument("--use_gpu", default=1, type=int)
    opts = parser.parse_args()

    net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
        n_classes=20,
        hidden_layers=128,
        source_classes=7,
    )

    # A model checkpoint is mandatory for this script.
    if not opts.loadmodel:
        print("no model load !!!!!!!!")
        raise RuntimeError("No model!!!!")

    # The network is still on the CPU here, so map the checkpoint to the CPU
    # as well; this keeps --use_gpu 0 working on hosts without CUDA even when
    # the checkpoint was saved from GPU tensors.
    x = torch.load(opts.loadmodel, map_location="cpu")
    net.load_source_model(x)
    print("load model:", opts.loadmodel)

    use_gpu = opts.use_gpu > 0
    if use_gpu:
        net.cuda()
def main(opts):
    """Evaluate the CIHP transfer network with multi-scale and flip testing.

    For every test image the network is run at each scale in ``scale_list``
    on the image and on its horizontal flip (stacked as a batch of two). The
    flipped prediction is mapped back with ``flip_cihp`` + ``flip`` and
    averaged with the unflipped one. Every non-first scale is resized to the
    first scale's resolution and summed. The argmax label map is written as a
    colourised PNG under ``<output_path>cihp_output_vis/`` and as a raw label
    PNG under ``<output_path>cihp_output/``. ``eval_`` is then run on the raw
    predictions.

    Args:
        opts: parsed options; reads ``classes``, ``hidden_layers``,
            ``numworker``, ``loadmodel``, ``txt_file``, ``output_path`` and
            ``gt_path``.

    Note:
        Relies on a module-level ``gpu_id`` (negative selects the CPU).
        ``output_path`` is used as a plain string prefix, so it should end
        with a path separator.
    """
    device = torch.device("cuda" if gpu_id >= 0 else "cpu")

    # Test-time graph adjacencies, each with a leading (1, 1) batch/channel pair.
    # CIHP->PASCAL transfer graph: expanded as (1, 1, 7, 20), then transposed
    # to (1, 1, 20, 7).
    transfer_adj = torch.from_numpy(graph.cihp2pascal_nlp_adj).float()
    adj2_test = (transfer_adj.unsqueeze(0).unsqueeze(0)
                 .expand(1, 1, 7, 20).to(device).transpose(2, 3))

    # PASCAL graph (7 x 7).
    pascal_adj = torch.from_numpy(
        graph.preprocess_adj(graph.pascal_graph)).float()
    adj3_test = pascal_adj.unsqueeze(0).unsqueeze(0).expand(1, 1, 7, 7).to(device)

    # CIHP graph (20 x 20).
    cihp_adj = torch.from_numpy(graph.preprocess_adj(graph.cihp_graph)).float()
    adj1_test = cihp_adj.unsqueeze(0).unsqueeze(0).expand(1, 1, 20, 20).to(device)

    num_workers = opts.numworker

    # Output names, one per test image, in loader order. Strip only the
    # newline: slicing off the last character would corrupt a final line
    # that has no trailing newline.
    with open(opts.txt_file, "r") as f:
        img_list = [line.rstrip("\n") for line in f]

    # Network definition (only the xception backbone is implemented).
    net = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
        n_classes=opts.classes,
        os=16,
        hidden_layers=opts.hidden_layers,
        source_classes=7,
    )
    net.to(device)

    # Load weights onto the same device as the network.
    if opts.loadmodel:
        x = torch.load(opts.loadmodel, map_location=device)
        net.load_source_model(x)
        print("load model:", opts.loadmodel)
    else:
        print("no model load !!!!!!!!")

    # Multi-scale test loaders; scale 1 must come first because every other
    # scale is resized to its resolution.
    scale_list = [1, 0.5, 0.75, 1.25, 1.5, 1.75]
    num_scales = len(scale_list)
    testloader_list = []
    testloader_flip_list = []
    for pv in scale_list:
        composed_transforms_ts = transforms.Compose(
            [tr.Scale_(pv),
             tr.Normalize_xception_tf(),
             tr.ToTensor_()])
        composed_transforms_ts_flip = transforms.Compose([
            tr.Scale_(pv),
            tr.HorizontalFlip(),
            tr.Normalize_xception_tf(),
            tr.ToTensor_(),
        ])
        voc_val = cihp.VOCSegmentation(split="test",
                                       transform=composed_transforms_ts)
        voc_val_f = cihp.VOCSegmentation(split="test",
                                         transform=composed_transforms_ts_flip)
        testloader_list.append(
            DataLoader(voc_val,
                       batch_size=1,
                       shuffle=False,
                       num_workers=num_workers))
        testloader_flip_list.append(
            DataLoader(voc_val_f,
                       batch_size=1,
                       shuffle=False,
                       num_workers=num_workers))

    print("Eval Network")

    vis_dir = opts.output_path + "cihp_output_vis/"
    pred_dir = opts.output_path + "cihp_output/"
    os.makedirs(vis_dir, exist_ok=True)
    os.makedirs(pred_dir, exist_ok=True)

    start_time = timeit.default_timer()
    net.eval()
    ii = -1  # stays -1 when the test set is empty
    for ii, large_sample_batched in enumerate(
            zip(*testloader_list, *testloader_flip_list)):
        print(ii)
        # First half: one batch per scale; second half: the flipped batches
        # in the same scale order.
        plain_batches = large_sample_batched[:num_scales]
        flip_batches = large_sample_batched[num_scales:]
        for iii, (sample, sample_f) in enumerate(zip(plain_batches,
                                                     flip_batches)):
            # Image and its flip go through the network as one batch of two.
            inputs = torch.cat((sample["image"], sample_f["image"]), dim=0)
            if iii == 0:
                _, _, h, w = inputs.size()

            with torch.no_grad():
                inputs = inputs.to(device)
                outputs = net.forward(inputs, adj1_test, adj3_test, adj2_test)
                # Undo the flip on the second prediction, then average.
                outputs = (outputs[0] +
                           flip(flip_cihp(outputs[1]), dim=-1)) / 2
                outputs = outputs.unsqueeze(0)

                if iii > 0:
                    outputs = F.interpolate(outputs,
                                            size=(h, w),
                                            mode="bilinear",
                                            align_corners=True)
                    outputs_final = outputs_final + outputs
                else:
                    outputs_final = outputs.clone()

        # Argmax over the class dimension gives the label map.
        predictions = torch.max(outputs_final, 1)[1]
        results = predictions.cpu().numpy()
        vis_res = decode_labels(results)

        name = img_list[ii]
        parsing_im = Image.fromarray(vis_res[0])
        parsing_im.save(vis_dir + "{}.png".format(name))
        cv2.imwrite(pred_dir + "{}.png".format(name), results[0, :, :])

    end_time = timeit.default_timer()
    print("time use for " + str(ii) + " is :" + str(end_time - start_time))

    # Eval
    eval_(
        pred_path=pred_dir,
        gt_path=opts.gt_path,
        classes=opts.classes,
        txt_file=opts.txt_file,
    )
def main(opts):
    """Train the CIHP transfer network on a selectable device with validation.

    Builds the xception ``deeplab_xception_transfer_projection_savemem``
    network and optionally loads pretrained and source-model weights. It then
    trains with SGD on the CIHP train split under ``../../data/datasets/CIHP_4w``.
    The learning rate is re-derived with ``util.lr_poly`` every
    ``opts.step`` epochs.

    Epoch loss, sample images and predicted / ground-truth label maps are
    logged to TensorBoard. A checkpoint is saved every ``snapshot`` epochs,
    and ``val_cihp`` runs every ``opts.testInterval`` epochs.

    Args:
        opts: parsed options; reads ``batch``, ``testInterval``, ``lr``,
            ``step``, ``numworker``, ``pretrainedModel``, ``epochs``,
            ``device`` (``"gpu"`` requests CUDA), ``classes``,
            ``hidden_layers``, ``loadmodel``, ``freezeBN`` and
            ``image_size``.

    Note:
        Relies on a module-level ``resume_epoch`` for the first epoch.
    """
    p = OrderedDict()  # Parameters to include in report
    p['trainBatch'] = opts.batch  # Training batch size
    testBatch = 1  # Testing batch size
    useTest = True  # See evolution of the test set when training
    nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
    snapshot = 1  # Store a model every snapshot epochs
    p['nAveGrad'] = 1  # Average the gradient of several iterations
    p['lr'] = opts.lr  # Learning rate
    p['lrFtr'] = 1e-5
    p['lraspp'] = 1e-5
    p['lrpro'] = 1e-5
    p['lrdecoder'] = 1e-5
    p['lrother'] = 1e-5
    p['wd'] = 5e-4  # Weight decay
    p['momentum'] = 0.9  # Momentum
    p['epoch_size'] = opts.step  # How many epochs to change learning rate
    p['num_workers'] = opts.numworker
    model_path = opts.pretrainedModel
    backbone = 'xception'  # Use xception or resnet as feature extractor
    nEpochs = opts.epochs

    # Next free run directory: run_cihp/run_<max existing id + 1>.
    max_id = 0
    save_dir_root = os.path.dirname(os.path.abspath(__file__))
    for r in glob.glob(os.path.join(save_dir_root, 'run_cihp', 'run_*')):
        run_id = int(r.split('_')[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    save_dir = os.path.join(save_dir_root, 'run_cihp', 'run_' + str(max_id))

    # Device: CUDA only when requested and available, otherwise CPU.
    # ("実行デバイス" = "execution device", "GPU名" = "GPU name".)
    if opts.device == "gpu" and torch.cuda.is_available():
        device = torch.device("cuda")
        print("実行デバイス :", device)
        print("GPU名 :", torch.cuda.get_device_name(device))
        print("torch.cuda.current_device() =", torch.cuda.current_device())
    else:
        if opts.device == "gpu":
            print("can't using gpu.")
        device = torch.device("cpu")
        print("実行デバイス :", device)

    # Network definition
    if backbone == 'xception':
        net_ = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
            n_classes=opts.classes,
            os=16,
            hidden_layers=opts.hidden_layers,
            source_classes=7,
        )
    else:
        # Only the xception backbone is implemented.
        raise NotImplementedError

    modelName = 'deeplabv3plus-' + backbone + '-voc' + datetime.now().strftime(
        '%b%d_%H-%M-%S')
    criterion = util.cross_entropy2d

    # Always place the network on the selected device: the graphs from
    # get_graphs(opts, device) and the inputs below live there too.
    net_.to(device)

    # Load weights, mapping checkpoint tensors onto the network's device so
    # GPU-saved checkpoints also load on CPU-only hosts.
    if model_path:
        x = torch.load(model_path, map_location=device)
        net_.load_state_dict_new(x)
        print('load pretrainedModel:', model_path)
    else:
        print('no pretrainedModel.')
    if opts.loadmodel:
        x = torch.load(opts.loadmodel, map_location=device)
        net_.load_source_model(x)
        print('load model:', opts.loadmodel)
    else:
        print('no model load !!!!!!!!')

    log_dir = os.path.join(
        save_dir, 'models',
        datetime.now().strftime('%b%d_%H-%M-%S') + '_' + socket.gethostname())
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_text('load model', opts.loadmodel, 1)
    writer.add_text('setting', sys.argv[0], 1)

    if opts.freezeBN:
        net_.freeze_bn()

    print(net_)

    optimizer = optim.SGD(net_.parameters(),
                          lr=p['lr'],
                          momentum=p['momentum'],
                          weight_decay=p['wd'])

    composed_transforms_tr = transforms.Compose([
        tr.RandomSized_new(opts.image_size),
        tr.Normalize_xception_tf(),
        tr.ToTensor_()
    ])
    composed_transforms_ts = transforms.Compose(
        [tr.Normalize_xception_tf(),
         tr.ToTensor_()])
    composed_transforms_ts_flip = transforms.Compose(
        [tr.HorizontalFlip(),
         tr.Normalize_xception_tf(),
         tr.ToTensor_()])

    voc_train = cihp.VOCSegmentation(base_dir="../../data/datasets/CIHP_4w",
                                     split='train',
                                     transform=composed_transforms_tr,
                                     flip=True)
    voc_val = cihp.VOCSegmentation(base_dir="../../data/datasets/CIHP_4w",
                                   split='val',
                                   transform=composed_transforms_ts)
    voc_val_flip = cihp.VOCSegmentation(base_dir="../../data/datasets/CIHP_4w",
                                        split='val',
                                        transform=composed_transforms_ts_flip)

    trainloader = DataLoader(voc_train,
                             batch_size=p['trainBatch'],
                             shuffle=True,
                             num_workers=p['num_workers'],
                             drop_last=True)
    testloader = DataLoader(voc_val,
                            batch_size=testBatch,
                            shuffle=False,
                            num_workers=p['num_workers'])
    testloader_flip = DataLoader(voc_val_flip,
                                 batch_size=testBatch,
                                 shuffle=False,
                                 num_workers=p['num_workers'])

    num_img_tr = len(trainloader)
    # Log image grids about 10 times per epoch. Clamp the stride to 1 so
    # epochs with fewer than 10 batches do not divide by zero.
    vis_interval = max(num_img_tr // 10, 1)
    running_loss_tr = 0.0
    aveGrad = 0
    global_step = 0
    print("Training Network")

    print("num_img_tr : ", num_img_tr)

    net = torch.nn.DataParallel(net_)
    train_graph, test_graph = get_graphs(opts, device)
    # Earlier debug output in this script showed adj1 (B, 1, 20, 20),
    # adj2 (B, 1, 20, 7), adj3 (B, 1, 7, 7) -- TODO confirm in get_graphs.
    adj1, adj2, adj3 = train_graph

    # Main Training and Testing Loop
    for epoch in range(resume_epoch, nEpochs):
        start_time = timeit.default_timer()

        if epoch % p['epoch_size'] == p['epoch_size'] - 1:
            # Poly LR decay. The optimizer is rebuilt here, which also resets
            # its momentum buffers.
            lr_ = util.lr_poly(p['lr'], epoch, nEpochs, 0.9)
            optimizer = optim.SGD(net_.parameters(),
                                  lr=lr_,
                                  momentum=p['momentum'],
                                  weight_decay=p['wd'])
            writer.add_scalar('data/lr_', lr_, epoch)
            print('(poly lr policy) learning rate: ', lr_)

        net.train()
        for ii, sample_batched in enumerate(trainloader):
            # Inputs need no gradient of their own; only the weights do.
            inputs = sample_batched['image'].to(device)
            labels = sample_batched['label'].to(device)
            global_step += inputs.shape[0]

            outputs = net.forward(inputs, adj1, adj3, adj2)

            loss = criterion(outputs, labels, batch_average=True)
            running_loss_tr += loss.item()

            # End-of-epoch summary.
            if ii % num_img_tr == (num_img_tr - 1):
                running_loss_tr = running_loss_tr / num_img_tr
                writer.add_scalar('data/total_loss_epoch', running_loss_tr,
                                  epoch)
                print('[Epoch: %d, numImages: %5d]' %
                      (epoch, ii * p['trainBatch'] + inputs.shape[0]))
                print('Loss: %f' % running_loss_tr)
                running_loss_tr = 0
                stop_time = timeit.default_timer()
                print("Execution time: " + str(stop_time - start_time) + "\n")

            # Backward the averaged gradient
            loss /= p['nAveGrad']
            loss.backward()
            aveGrad += 1

            # Update the weights once in p['nAveGrad'] forward passes
            if aveGrad % p['nAveGrad'] == 0:
                writer.add_scalar('data/total_loss_iter', loss.item(),
                                  ii + num_img_tr * epoch)
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

            # Log up to 3 images with predictions and ground truth.
            # normalize=False makes make_grid ignore any value range, so none
            # is passed (the old `range=` keyword is gone in newer torchvision).
            if ii % vis_interval == 0:
                grid_image = make_grid(inputs[:3].detach().cpu(),
                                       3,
                                       normalize=True)
                writer.add_image('Image', grid_image, global_step)
                grid_image = make_grid(util.decode_seg_map_sequence(
                    torch.max(outputs[:3], 1)[1].detach().cpu().numpy()),
                                       3,
                                       normalize=False)
                writer.add_image('Predicted label', grid_image, global_step)
                grid_image = make_grid(util.decode_seg_map_sequence(
                    torch.squeeze(labels[:3], 1).detach().cpu().numpy()),
                                       3,
                                       normalize=False)
                writer.add_image('Groundtruth label', grid_image, global_step)
            print('loss is ', loss.item(), flush=True)

        # Save the model
        if (epoch % snapshot) == snapshot - 1:
            ckpt_path = os.path.join(
                save_dir, 'models',
                modelName + '_epoch-' + str(epoch) + '.pth')
            torch.save(net_.state_dict(), ckpt_path)
            print("Save model at {}\n".format(ckpt_path))

        torch.cuda.empty_cache()

        # One testing epoch
        if useTest and epoch % nTestInterval == (nTestInterval - 1):
            val_cihp(net_,
                     testloader=testloader,
                     testloader_flip=testloader_flip,
                     test_graph=test_graph,
                     epoch=epoch,
                     writer=writer,
                     criterion=criterion,
                     classes=opts.classes,
                     device=device)
        torch.cuda.empty_cache()
Пример #6
0
def main(opts):
    """Train the CIHP transfer network on CUDA with periodic validation.

    Builds the xception ``deeplab_xception_transfer_projection_savemem``
    network and optionally loads pretrained and source-model weights. It then
    trains with SGD on the CIHP train split. The learning rate is re-derived
    with ``util.lr_poly`` every ``opts.step`` epochs.

    Epoch loss, sample images and predicted / ground-truth label maps are
    logged to TensorBoard. A checkpoint is saved every ``snapshot`` epochs,
    and ``val_cihp`` runs every ``opts.testInterval`` epochs.

    Args:
        opts: parsed options; reads ``batch``, ``testInterval``, ``lr``,
            ``step``, ``numworker``, ``pretrainedModel``, ``epochs``,
            ``classes``, ``hidden_layers``, ``loadmodel`` and ``freezeBN``.
            An optional ``image_size`` (default 512) sets the random
            training crop size.

    Note:
        Relies on module-level ``gpu_id`` (>= 0 moves the net and batches to
        CUDA) and ``resume_epoch`` (first epoch).
    """
    p = OrderedDict()  # Parameters to include in report
    p["trainBatch"] = opts.batch  # Training batch size
    testBatch = 1  # Testing batch size
    useTest = True  # See evolution of the test set when training
    nTestInterval = opts.testInterval  # Run on test set every nTestInterval epochs
    snapshot = 1  # Store a model every snapshot epochs
    p["nAveGrad"] = 1  # Average the gradient of several iterations
    p["lr"] = opts.lr  # Learning rate
    p["lrFtr"] = 1e-5
    p["lraspp"] = 1e-5
    p["lrpro"] = 1e-5
    p["lrdecoder"] = 1e-5
    p["lrother"] = 1e-5
    p["wd"] = 5e-4  # Weight decay
    p["momentum"] = 0.9  # Momentum
    p["epoch_size"] = opts.step  # How many epochs to change learning rate
    p["num_workers"] = opts.numworker
    model_path = opts.pretrainedModel
    backbone = "xception"  # Use xception or resnet as feature extractor
    nEpochs = opts.epochs
    # Same option as the device-aware trainer; 512 keeps the old behaviour.
    image_size = getattr(opts, "image_size", 512)

    # Next free run directory: run_cihp/run_<max existing id + 1>.
    max_id = 0
    save_dir_root = os.path.dirname(os.path.abspath(__file__))
    for r in glob.glob(os.path.join(save_dir_root, "run_cihp", "run_*")):
        run_id = int(r.split("_")[-1])
        if run_id >= max_id:
            max_id = run_id + 1
    save_dir = os.path.join(save_dir_root, "run_cihp", "run_" + str(max_id))

    # Network definition
    if backbone == "xception":
        net_ = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(
            n_classes=opts.classes,
            os=16,
            hidden_layers=opts.hidden_layers,
            source_classes=7,
        )
    else:
        # Only the xception backbone is implemented.
        raise NotImplementedError

    modelName = ("deeplabv3plus-" + backbone + "-voc" +
                 datetime.now().strftime("%b%d_%H-%M-%S"))
    criterion = util.cross_entropy2d

    if gpu_id >= 0:
        net_.cuda()

    # net load weights
    if not model_path == "":
        x = torch.load(model_path)
        net_.load_state_dict_new(x)
        print("load pretrainedModel:", model_path)
    else:
        print("no pretrainedModel.")
    if not opts.loadmodel == "":
        x = torch.load(opts.loadmodel)
        net_.load_source_model(x)
        print("load model:", opts.loadmodel)
    else:
        print("no model load !!!!!!!!")

    log_dir = os.path.join(
        save_dir,
        "models",
        datetime.now().strftime("%b%d_%H-%M-%S") + "_" + socket.gethostname(),
    )
    writer = SummaryWriter(log_dir=log_dir)
    writer.add_text("load model", opts.loadmodel, 1)
    writer.add_text("setting", sys.argv[0], 1)

    if opts.freezeBN:
        net_.freeze_bn()

    optimizer = optim.SGD(net_.parameters(),
                          lr=p["lr"],
                          momentum=p["momentum"],
                          weight_decay=p["wd"])

    composed_transforms_tr = transforms.Compose(
        [tr.RandomSized_new(image_size),
         tr.Normalize_xception_tf(),
         tr.ToTensor_()])
    composed_transforms_ts = transforms.Compose(
        [tr.Normalize_xception_tf(),
         tr.ToTensor_()])
    composed_transforms_ts_flip = transforms.Compose(
        [tr.HorizontalFlip(),
         tr.Normalize_xception_tf(),
         tr.ToTensor_()])

    voc_train = cihp.VOCSegmentation(split="train",
                                     transform=composed_transforms_tr,
                                     flip=True)
    voc_val = cihp.VOCSegmentation(split="val",
                                   transform=composed_transforms_ts)
    voc_val_flip = cihp.VOCSegmentation(split="val",
                                        transform=composed_transforms_ts_flip)

    trainloader = DataLoader(
        voc_train,
        batch_size=p["trainBatch"],
        shuffle=True,
        num_workers=p["num_workers"],
        drop_last=True,
    )
    testloader = DataLoader(voc_val,
                            batch_size=testBatch,
                            shuffle=False,
                            num_workers=p["num_workers"])
    testloader_flip = DataLoader(voc_val_flip,
                                 batch_size=testBatch,
                                 shuffle=False,
                                 num_workers=p["num_workers"])

    num_img_tr = len(trainloader)
    # Log image grids about 10 times per epoch. Clamp the stride to 1 so
    # epochs with fewer than 10 batches do not divide by zero.
    vis_interval = max(num_img_tr // 10, 1)
    running_loss_tr = 0.0
    aveGrad = 0
    global_step = 0
    print("Training Network")

    net = torch.nn.DataParallel(net_)
    train_graph, test_graph = get_graphs(opts)
    adj1, adj2, adj3 = train_graph

    # Main Training and Testing Loop
    for epoch in range(resume_epoch, nEpochs):
        start_time = timeit.default_timer()

        if epoch % p["epoch_size"] == p["epoch_size"] - 1:
            # Poly LR decay. The optimizer is rebuilt here, which also resets
            # its momentum buffers.
            lr_ = util.lr_poly(p["lr"], epoch, nEpochs, 0.9)
            optimizer = optim.SGD(net_.parameters(),
                                  lr=lr_,
                                  momentum=p["momentum"],
                                  weight_decay=p["wd"])
            writer.add_scalar("data/lr_", lr_, epoch)
            print("(poly lr policy) learning rate: ", lr_)

        net.train()
        for ii, sample_batched in enumerate(trainloader):
            # Inputs need no gradient of their own; only the weights do.
            inputs, labels = sample_batched["image"], sample_batched["label"]
            global_step += inputs.shape[0]

            if gpu_id >= 0:
                inputs, labels = inputs.cuda(), labels.cuda()

            outputs = net.forward(inputs, adj1, adj3, adj2)

            loss = criterion(outputs, labels, batch_average=True)
            running_loss_tr += loss.item()

            # End-of-epoch summary.
            if ii % num_img_tr == (num_img_tr - 1):
                running_loss_tr = running_loss_tr / num_img_tr
                writer.add_scalar("data/total_loss_epoch", running_loss_tr,
                                  epoch)
                print("[Epoch: %d, numImages: %5d]" %
                      (epoch, ii * p["trainBatch"] + inputs.shape[0]))
                print("Loss: %f" % running_loss_tr)
                running_loss_tr = 0
                stop_time = timeit.default_timer()
                print("Execution time: " + str(stop_time - start_time) + "\n")

            # Backward the averaged gradient
            loss /= p["nAveGrad"]
            loss.backward()
            aveGrad += 1

            # Update the weights once in p['nAveGrad'] forward passes
            if aveGrad % p["nAveGrad"] == 0:
                writer.add_scalar("data/total_loss_iter", loss.item(),
                                  ii + num_img_tr * epoch)
                optimizer.step()
                optimizer.zero_grad()
                aveGrad = 0

            # Log up to 3 images with predictions and ground truth.
            # normalize=False makes make_grid ignore any value range, so none
            # is passed (the old `range=` keyword is gone in newer torchvision).
            if ii % vis_interval == 0:
                grid_image = make_grid(inputs[:3].detach().cpu(),
                                       3,
                                       normalize=True)
                writer.add_image("Image", grid_image, global_step)
                grid_image = make_grid(
                    util.decode_seg_map_sequence(
                        torch.max(outputs[:3], 1)[1].detach().cpu().numpy()),
                    3,
                    normalize=False,
                )
                writer.add_image("Predicted label", grid_image, global_step)
                grid_image = make_grid(
                    util.decode_seg_map_sequence(
                        torch.squeeze(labels[:3], 1).detach().cpu().numpy()),
                    3,
                    normalize=False,
                )
                writer.add_image("Groundtruth label", grid_image, global_step)
            print("loss is ", loss.item(), flush=True)

        # Save the model
        if (epoch % snapshot) == snapshot - 1:
            ckpt_path = os.path.join(
                save_dir, "models",
                modelName + "_epoch-" + str(epoch) + ".pth")
            torch.save(net_.state_dict(), ckpt_path)
            print("Save model at {}\n".format(ckpt_path))

        torch.cuda.empty_cache()

        # One testing epoch
        if useTest and epoch % nTestInterval == (nTestInterval - 1):
            val_cihp(
                net_,
                testloader=testloader,
                testloader_flip=testloader_flip,
                test_graph=test_graph,
                epoch=epoch,
                writer=writer,
                criterion=criterion,
                classes=opts.classes,
            )
        torch.cuda.empty_cache()