    mylog.write('********' + '\n')
    mylog.write('epoch:' + str(epoch) + '    time:' + str(int(time() - tic)) +
                '\n')
    mylog.write('train_loss:' + str(train_epoch_loss) + '\n')
    mylog.write('SHAPE:' + str(SHAPE) + '\n')
    print('********')
    print('epoch:', epoch, '    time:', int(time() - tic))
    print('train_loss:', train_epoch_loss)
    print('SHAPE:', SHAPE)

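    # Plateau handling in this snippet: no_optim counts epochs without a new best
    # training loss; after more than 3 stale epochs the best checkpoint is reloaded
    # and the lr is reduced via update_lr(5.0, factor=True), and after more than 6
    # stale epochs (or once the lr falls below 5e-7) training stops early.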
    if train_epoch_loss >= train_epoch_best_loss:
        no_optim += 1
    else:
        no_optim = 0
        train_epoch_best_loss = train_epoch_loss
        solver.save('weights/' + NAME + '.th')
    if no_optim > 6:
        mylog.write('early stop at %d epoch\n' % epoch)
        print('early stop at %d epoch' % epoch)
        break
    if no_optim > 3:
        if solver.old_lr < 5e-7:
            break
        solver.load('weights/' + NAME + '.th')
        solver.update_lr(5.0, factor=True, mylog=mylog)
    mylog.flush()

mylog.write('Finish!')
print('Finish!')
mylog.close()
Example #2
def train_operation(train_paras):
    sat_dir = train_paras["image_dir"]
    lab_dir = train_paras["gt_dir"]
    train_id = train_paras["train_id"]
    logfile_dir = train_paras["logfile_dir"]
    model_dir = train_paras["model_dir"]
    model_name = train_paras["model_name"]
    learning_rate = train_paras["learning_rate"]

    imagelist = os.listdir(sat_dir)

    trainlist = list(map(lambda x: x[:-8], imagelist))
    # trainlist = trainlist[:1000]
    BATCHSIZE_PER_CARD = 2
    solver = MyFrame(DUNet, learning_rate, model_name)
    # solver = MyFrame(Unet, dice_bce_loss, 2e-4)
    batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

    dataset = ImageFolder(trainlist, sat_dir, lab_dir)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=0)

    mylog = open(logfile_dir + model_name + '.log', 'w')
    print("**************" + model_name + "******************", file=mylog)
    print("**************" + model_name + "******************")
    print("current train id:{}".format(train_id), file=mylog)
    print("current train id:{}".format(train_id))
    print("batch size:{}".format(batchsize), file=mylog)
    print("total images: {}".format(len(trainlist)))
    print("total images: {}".format(len(trainlist)), file=mylog)

    tic = time()
    no_optim = 0
    total_epoch = train_paras["total_epoch"]
    train_epoch_best_loss = 100.

    # solver.load('weights/dlinknet_new_lr_decoder.th')
    # print('* load existing model *')

    epoch_iter = 0
    print("learning rate is {}".format(learning_rate), file=mylog)
    print("Precompute weight for 5 epochs", file=mylog)
    print("Precompute weight for 5 epochs")
    save_tensorboard_iter = 5
    pre_compute_flag = 1
    # solver.load(model_dir + model_name + '.th')
    # pretrain W
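    # The loop below runs pre_compute_W over the training batches before normal
    # optimization starts; judging from the batch counter t and the pre_compute_flag
    # handed to the first optimize() call later, this appears to pre-compute an
    # initial weight term W inside MyFrame (implementation not shown here).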
    for epoch in range(1, 6):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        if epoch < 5:
            no_optim = 0
            t = 0
            for img, mask in data_loader_iter:
                t += 1
                solver.set_input(img, mask)
                solver.pre_compute_W(t)
        print('********', file=mylog)
        print('pre-train W:',
              epoch,
              '    time:',
              int(time() - tic),
              file=mylog)
        print('********')
        print('pre-train W:', epoch, '    time:', int(time() - tic))

    print("pretrain is OVER")
    print("pretrain is OVER", file=mylog)

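    # step_update switches between two lr policies in the loop below: a step decay
    # (update_lr(5.0, factor=True) once no_optim exceeds 3) and a per-epoch polynomial
    # decay via update_lr_poly; training stops early once no_optim exceeds 6 or the lr
    # falls below 5e-7.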
    step_update = False
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        for img, mask in data_loader_iter:
            imgs = solver.set_input(img, mask)
            train_loss = solver.optimize(pre_compute_flag)
            pre_compute_flag = 0
            train_epoch_loss += train_loss
        train_epoch_loss /= len(data_loader_iter)
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', SHAPE, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', SHAPE)
        if epoch % save_tensorboard_iter == 1:
            solver.update_tensorboard(epoch)
        # imgs=imgs.to(torch.device("cpu"))
        # solver.writer.add_graph(solver.model,imgs)
        print("train best loss is {}".format(train_epoch_best_loss))
        print("train best loss is {}".format(train_epoch_best_loss),
              file=mylog)
        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save(model_dir + model_name + '.th')
        if no_optim > 6:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        elif no_optim > 3:
            step_update = True
            solver.update_lr(5.0, factor=True, mylog=mylog)
            print("update lr by factor 5.0")
        elif no_optim > 2:
            if solver.old_lr < 5e-7:
                break
            solver.load(model_dir + model_name + '.th')
            # solver.update_lr(5.0, factor=True, mylog=mylog)
            if step_update:
                solver.update_lr(5.0, factor=True, mylog=mylog)
                step_update = False
            else:
                solver.update_lr_poly(epoch, total_epoch, mylog,
                                      total_epoch / 40)
        if not step_update:
            solver.update_lr_poly(epoch, total_epoch, mylog, total_epoch / 40)
        mylog.flush()

    solver.close_tensorboard()
    print('*********************Finish!***********************', file=mylog)
    print('Finish!')
    mylog.close()
Example #3
    print('train_loss:', train_epoch_loss, file=mylog)
    print('SHAPE:', SHAPE, file=mylog)
    print('********')
    print('epoch:', epoch, '    time:', int(time() - tic))
    print('train_loss:', train_epoch_loss)
    print('SHAPE:', SHAPE)

    if (epoch % 20 == 0 and epoch != 0):
        solver.save('weights/' + NAME + '/' + NAME + str(train_epoch_loss) +
                    '.th')

    if train_epoch_loss >= train_epoch_best_loss:
        no_optim += 1
    else:
        no_optim = 0
        train_epoch_best_loss = train_epoch_loss
        solver.save('weights/' + NAME + '.th')
    if no_optim > 20:
        print('early stop at %d epoch' % epoch, file=mylog)
        print('early stop at %d epoch' % epoch)
        break
    if no_optim > 10:
        if solver.old_lr < 5e-7:
            break
        solver.load('weights/' + NAME + '.th')
        solver.update_lr(0.8, factor=True, mylog=mylog)
    mylog.flush()

print('Finish!', file=mylog)
print('Finish!')
mylog.close()
Example #4
def vessel_main():
    SHAPE = (448, 448)
    # ROOT = 'dataset/RIM-ONE/'
    ROOT = './dataset/DRIVE'
    NAME = 'log01_dink34-UNet' + ROOT.split('/')[-1]
    BATCHSIZE_PER_CARD = 8

    # net = UNet(n_channels=3, n_classes=2)

    viz = Visualizer(env="Vessel_Unet_from_scratch")

    solver = MyFrame(UNet, dice_bce_loss, 2e-4)
    batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

    dataset = ImageFolder(root_path=ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()
    no_optim = 0
    total_epoch = 300
    train_epoch_best_loss = 10000.
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0

        index = 0

        for img, mask in data_loader_iter:
            solver.set_input(img, mask)

            train_loss, pred = solver.optimize()

            train_epoch_loss += train_loss

            index = index + 1

            # if index % 10 == 0:
            #     # train_epoch_loss /= index
            #     # viz.plot(name='loss', y=train_epoch_loss)
            #     show_image = (img + 1.6) / 3.2 * 255.
            #     viz.img(name='images', img_=show_image[0, :, :, :])
            #     viz.img(name='labels', img_=mask[0, :, :, :])
            #     viz.img(name='prediction', img_=pred[0, :, :, :])

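        # Scale the last batch back to roughly [0, 255] for display; the 1.6/3.2
        # constants presumably undo the normalisation applied by the dataloader.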
        show_image = (img + 1.6) / 3.2 * 255.
        viz.img(name='images', img_=show_image[0, :, :, :])
        viz.img(name='labels', img_=mask[0, :, :, :])
        viz.img(name='prediction', img_=pred[0, :, :, :])

        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', SHAPE, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', SHAPE)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save('./weights/' + NAME + '.th')
        if no_optim > 20:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        if no_optim > 15:
            if solver.old_lr < 5e-7:
                break
            solver.load('./weights/' + NAME + '.th')
            solver.update_lr(2.0, factor=True, mylog=mylog)
        mylog.flush()

    print('Finish!', file=mylog)
    print('Finish!')
    mylog.close()
Example #5
        # img_out = vutils.make_grid(pre, nrow=4, normalize=True)  # must be a tensor
        # write.add_image('predict_out', img_out, allstep)  # must be three-channel

        # visualize the loss curve
        train_epoch_loss += train_loss  # running sum of all losses
        write.add_scalar('train_loss', train_loss, allstep)
        # # visualize parameter histograms (seems to slow training down)
        # for name, param in solver.net.named_parameters():
        #     write.add_histogram(name, param.data.cpu().numpy(), allstep)
    train_epoch_loss /= len(train_load)  # average loss
    print('********')
    print('epoch:', epoch, 'time:', int(time() - tic) / 60)
    print('train_loss:', train_epoch_loss)

    if train_epoch_loss >= train_epoch_best_loss:
        no_optim += 1
    else:
        no_optim = 0
        train_epoch_best_loss = train_epoch_loss  # keep the best result
        solver.save(modefiles)
    if no_optim > 6:
        print('early stop at %d epoch' % epoch)
        break
    if no_optim > 3:
        if solver.old_lr < 5e-7:
            break
        solver.load('weights/' + NAME + '.th')
        solver.update_lr(5.0, factor=True)


Example #6
def CE_Net_Train(train_i=0):

    NAME = 'fold' + str(train_i + 1) + '_6CE-Net' + Constants.ROOT.split('/')[-1]

    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD  #4

    # For different 2D medical image segmentation tasks, please specify the dataset which you use
    # for example, you could specify dataset = 'DRIVE' for retinal vessel detection.
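    # e.g. the other CE_Net_Train examples below build the dataset with
    # ImageFolder(root_path=Constants.ROOT, datasets='DRIVE'); this fold-based variant
    # reads its train/test splits from per-fold CSV files instead.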

    txt_train = 'fold' + str(train_i + 1) + '_train.csv'
    txt_test = 'fold' + str(train_i + 1) + '_test.csv'
    dataset_train = MyDataset(txt_path=txt_train,
                              transform=transforms.ToTensor(),
                              target_transform=transforms.ToTensor())
    dataset_test = MyDataset(txt_path=txt_test,
                             transform=transforms.ToTensor(),
                             target_transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=batchsize,
                                               shuffle=True,
                                               num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=batchsize,
                                              shuffle=False,
                                              num_workers=2)

    # start the logging files
    mylog = open('logs/' + NAME + '.log', 'w')

    no_optim = 0
    total_epoch = Constants.TOTAL_EPOCH  # 300
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS  # 10000
    best_test_score = 0
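    # Note: this variant checkpoints on the best test Dice score (see below); the
    # train-loss branch only updates the plateau counter and best loss, and its
    # save call is commented out.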
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(train_loader)
        data_loader_test = iter(test_loader)
        train_epoch_loss = 0
        index = 0

        tic = time()

        # train
        for img, mask in data_loader_iter:
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_epoch_loss += train_loss
            index = index + 1

        # test
        test_sen = 0
        test_ppv = 0
        test_score = 0
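        # Accumulate per-batch Dice, sensitivity and positive predictive value over the
        # test loader, then report their batch-averaged values.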
        for img, mask in data_loader_test:
            solver.set_input(img, mask)
            pre_mask, _ = solver.test_batch()
            test_score += dice_coeff(mask, pre_mask, False)
            test_sen += sensitive(mask, pre_mask)
            # test_sen = test_sen.cpu().data.numpy()
            test_ppv += positivepv(mask, pre_mask)
            # test_ppv = test_ppv.cpu().data.numpy()
        print(test_sen / len(data_loader_test),
              test_ppv / len(data_loader_test),
              test_score / len(data_loader_test))
        # solver.set_input(x_test, y_test)
        # pre_mask, _ = solver.test_batch()
        # test_score = dice_coeff(y_test, pre_mask, False)
        # test_sen = sensitive(y_test, pre_mask)
        # test_sen = test_sen.cpu().data.numpy()
        # test_ppv = positivepv(y_test, pre_mask)
        # test_ppv = test_ppv.cpu().data.numpy()
        # print('111111111111111111111',type(test_score))

        # # show the original images, predication and ground truth on the visdom.
        # show_image = (img + 1.6) / 3.2 * 255.
        # viz.img(name='images', img_=show_image[0, :, :, :])
        # viz.img(name='labels', img_=mask[0, :, :, :])
        # viz.img(name='prediction', img_=pred[0, :, :, :])

        if test_score > best_test_score:
            print('1. the dice score up to ', test_score, 'from ',
                  best_test_score, 'saving the model')
            best_test_score = test_score
            solver.save('./weights/' + NAME + '.th')

        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        # print(mylog, '********')
        print('epoch:',
              epoch,
              '    time:',
              int(time() - tic),
              'train_loss:',
              train_epoch_loss.cpu().data.numpy(),
              file=mylog,
              flush=True)
        print('test_dice_score: ',
              test_score,
              'test_sen: ',
              test_sen,
              'test_ppv: ',
              test_ppv,
              'best_score is ',
              best_test_score,
              file=mylog,
              flush=True)

        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic), 'train_loss:',
              train_epoch_loss.cpu().data.numpy())
        print('test_dice_score: ', test_score, 'test_sen: ', test_sen,
              'test_ppv: ', test_ppv, 'best_score is ', best_test_score)
        # print('train_loss:', train_epoch_loss)
        # print('SHAPE:', Constants.Image_size)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            # solver.save('./weights/' + NAME + '.th')
        # if no_optim > Constants.NUM_EARLY_STOP:
        #     print(mylog, 'early stop at %d epoch' % epoch)
        #     print('early stop at %d epoch' % epoch)
        #     break
        if no_optim > Constants.NUM_UPDATE_LR:
            if solver.old_lr < 5e-7:
                break
            if solver.old_lr > 5e-4:
                solver.load('./weights/' + NAME + '.th')
                solver.update_lr(1.5, factor=True, mylog=mylog)

    print('Finish!', file=mylog, flush=True)
    print('Finish!')
    mylog.close()
Example #7
def CE_Net_Train():
    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]

    # run the Visdom
    viz = Visualizer(env=NAME)

    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD

    # For different 2D medical image segmentation tasks, please specify the dataset which you use
    # for example, you could specify dataset = 'DRIVE' for retinal vessel detection.
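    # e.g. datasets='DRIVE' is used here for retinal vessels; the variant further down
    # passes datasets='Cell' instead.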

    dataset = ImageFolder(root_path=Constants.ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

    # start the logging files
    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()

    no_optim = 0
    total_epoch = Constants.TOTAL_EPOCH
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        index = 0

        for img, mask in data_loader_iter:
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_epoch_loss += train_loss
            index = index + 1

        # show the original images, prediction and ground truth on visdom.
        show_image = (img + 1.6) / 3.2 * 255.
        viz.img(name='images', img_=show_image[0, :, :, :])
        viz.img(name='labels', img_=mask[0, :, :, :])
        viz.img(name='prediction', img_=pred[0, :, :, :])

        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', Constants.Image_size, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', Constants.Image_size)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save('./weights/' + NAME + '.th')
        if no_optim > Constants.NUM_EARLY_STOP:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        if no_optim > Constants.NUM_UPDATE_LR:
            if solver.old_lr < 5e-7:
                break
            solver.load('./weights/' + NAME + '.th')
            solver.update_lr(2.0, factor=True, mylog=mylog)
        mylog.flush()

    print('Finish!', file=mylog)
    print('Finish!')
    mylog.close()
def CE_Net_Train():
    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]

    # run the Visdom
    viz = Visualizer(env=NAME)

    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    print("count", Constants.BATCHSIZE_PER_CARD)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD
    print("batchsize", batchsize)

    # For different 2D medical image segmentation tasks, please specify the dataset which you use
    # for example, you could specify dataset = 'DRIVE' for retinal vessel detection.
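    # e.g. datasets='Cell' is used here; the first variant above uses datasets='DRIVE'
    # for retinal vessel segmentation.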

    dataset = ImageFolder(root_path=Constants.ROOT, datasets='Cell')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

    dataset_val = ImageFolder(root_path='./test_data/DRIVE_dot_dash_training',
                              datasets='Cell')
    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  batch_size=8,
                                                  shuffle=True,
                                                  num_workers=4)

    # start the logging files
    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()

    no_optim = 0
    total_epoch = Constants.TOTAL_EPOCH
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        index = 0

        for img, mask in data_loader_iter:
            # solver.load('./weights/' + NAME + '.th')
            # print("iterating the dataloader")
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_epoch_loss += train_loss
            index = index + 1

        # show the original images, prediction and ground truth on visdom.
        show_image = (img + 1.6) / 3.2 * 255.
        viz.img(name='images', img_=show_image[0, :, :, :])
        viz.img(name='labels', img_=mask[0, :, :, :])
        viz.img(name='prediction', img_=pred[0, :, :, :])

        torchvision.utils.save_image(img[0, :, :, :],
                                     "images/image_" + str(epoch) + ".jpg",
                                     nrow=1,
                                     padding=2,
                                     normalize=True,
                                     range=None,
                                     scale_each=False,
                                     pad_value=0)
        torchvision.utils.save_image(mask[0, :, :, :],
                                     "images/mask_" + str(epoch) + ".jpg",
                                     nrow=1,
                                     padding=2,
                                     normalize=True,
                                     range=None,
                                     scale_each=False,
                                     pad_value=0)
        torchvision.utils.save_image(pred[0, :, :, :],
                                     "images/pred_" + str(epoch) + ".jpg",
                                     nrow=1,
                                     padding=2,
                                     normalize=True,
                                     range=None,
                                     scale_each=False,
                                     pad_value=0)

        # x = torch.tensor([[1,2,3],[4,5,6]], dtype = torch.uint8)
        # x = show_image[0,:,:,:]
        # print(x.shape)
        # pil_im = transforms.ToPILImage(mode = 'RGB')(x)
        # pil_im.save('/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/image_' + str(epoch) +  '.jpg')

        # x = mask[0,:,:,:]
        # print(x.shape)
        # pil_im = transforms.ToPILImage(mode = 'L')(x)
        # pil_im.save('/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/mask_' + str(epoch) +  '.jpg')

        # x = pred[0,:,:,:]
        # print(x.shape)
        # pil_im = transforms.ToPILImage(mode = 'HSV')(x.detach().cpu().numpy())
        # pil_im.save('/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/prediction_' + str(epoch) +  '.jpg')
        # (x.detach().numpy()).save("/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/image_" + str(epoch) + ".png")
        # cv2.imwrite('imagename.jpg', x.detach().numpy().astype('uint8')).transpose(2,1,0)
        # x = mask[0,:,:,:]
        # # F.to_pil_image(x.detach().numpy()).save("/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/mask_" + str(epoch) + ".png")
        # x = pred[0,:,:,:]
        # print(x.shape)
        # cv2.imwrite('imagename2.jpg', x.detach().numpy().astype('uint8'))

        # F.to_pil_image(x.detach().numpy()).save("/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/prediction_" + str(epoch) + ".png")
        print("saving images")
        print("Train_loss_for_all ", train_epoch_loss)
        print("length of (data_loader_iter) ", len(data_loader_iter))
        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', Constants.Image_size, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', Constants.Image_size)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            print("Saving the Weights")
            solver.save('./weights/' + NAME + '.th')
            if epoch % 100 == 0:
                solver.save('./weights/' + NAME + str(epoch) + '.th')
        if no_optim > Constants.NUM_EARLY_STOP:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        if no_optim > Constants.NUM_UPDATE_LR:
            if solver.old_lr < 5e-7:
                break
            solver.load('./weights/' + NAME + '.th')
            solver.update_lr(2.0, factor=True, mylog=mylog)
        mylog.flush()

        if (epoch % 1 == 0):
            # validation save image
            print('in VALIDATION')
            # for
            data_loader_iter_val = iter(data_loader_val)
            train_epoch_loss = 0
            index = 0
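            # Run the held-out loader through optimize_test(); judging by the name this
            # only evaluates the loss on the validation batches, and the batch-averaged
            # loss is printed below.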

            for img, mask in data_loader_iter_val:
                # solver.load('./weights/' + NAME + '.th')
                solver.set_input(img, mask)

                train_loss, pred = solver.optimize_test()
                train_epoch_loss += train_loss
                index = index + 1
                # torchvision.utils.save_image(img[0, :, :, :], "test_data/results2/image_"+str(epoch) + '_' + str(index) + ".jpg", nrow=1, padding=2, normalize=True, range=None, scale_each=False, pad_value=0)
                # torchvision.utils.save_image(mask[0, :, :, :], "test_data/results2/mask_"+str(epoch) + '_' + str(index) + ".jpg", nrow=1, padding=2, normalize=True, range=None, scale_each=False, pad_value=0)
                # torchvision.utils.save_image(pred[0, :, :, :], "test_data/results2/pred_"+str(epoch) + '_' + str(index) + ".jpg", nrow=1, padding=2, normalize=True, range=None, scale_each=False, pad_value=0)
            print("Train_loss_for_all ", train_epoch_loss)
            print("length of (data_loader_iter_val) ",
                  len(data_loader_iter_val))
            print(train_epoch_loss / len(data_loader_iter_val))
            print('++++++++++++++++++++++++++++++++++')
            # show the original images, prediction and ground truth on visdom.
            # show_image = (img + 1.6) / 3.2 * 255.
            # viz.img(name='images', img_=show_image[0, :, :, :])
            # viz.img(name='labels', img_=mask[0, :, :, :])
            # viz.img(name='prediction', img_=pred[0, :, :, :])

            torchvision.utils.save_image(img[0, :, :, :],
                                         "test_data/results4/image_" +
                                         str(epoch) + ".jpg",
                                         nrow=1,
                                         padding=2,
                                         normalize=True,
                                         range=None,
                                         scale_each=False,
                                         pad_value=0)
            torchvision.utils.save_image(mask[0, :, :, :],
                                         "test_data/results4/mask_" +
                                         str(epoch) + ".jpg",
                                         nrow=1,
                                         padding=2,
                                         normalize=True,
                                         range=None,
                                         scale_each=False,
                                         pad_value=0)
            torchvision.utils.save_image(pred[0, :, :, :],
                                         "test_data/results4/pred_" +
                                         str(epoch) + ".jpg",
                                         nrow=1,
                                         padding=2,
                                         normalize=True,
                                         range=None,
                                         scale_each=False,
                                         pad_value=0)

            # x = torch.tensor([[1,2,3],[4,5,6]], dtype = torch.uint8)
            # x = show_image[0,:,:,:]
            # print(x.shape)
            # pil_im = transforms.ToPILImage(mode = 'RGB')(x)
            # pil_im.save('/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/image_' + str(epoch) +  '.jpg')

            # x = mask[0,:,:,:]
            # print(x.shape)
            # pil_im = transforms.ToPILImage(mode = 'L')(x)
            # pil_im.save('/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/mask_' + str(epoch) +  '.jpg')

            # x = pred[0,:,:,:]
            # print(x.shape)
            # pil_im = transforms.ToPILImage(mode = 'HSV')(x.detach().cpu().numpy())
            # pil_im.save('/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/prediction_' + str(epoch) +  '.jpg')
            # (x.detach().numpy()).save("/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/image_" + str(epoch) + ".png")
            # cv2.imwrite('imagename.jpg', x.detach().numpy().astype('uint8')).transpose(2,1,0)
            # x = mask[0,:,:,:]
            # # F.to_pil_image(x.detach().numpy()).save("/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/mask_" + str(epoch) + ".png")
            # x = pred[0,:,:,:]
            # print(x.shape)
            # cv2.imwrite('imagename2.jpg', x.detach().numpy().astype('uint8'))

            # F.to_pil_image(x.detach().numpy()).save("/home/videsh/Downloads/Chandan/paper_implementation/CE-Net-master/images/prediction_" + str(epoch) + ".png")
            # print("saving images")

            # train_epoch_loss = train_epoch_loss/len(data_loader_iter)
            # print(mylog, '********')
            # print(mylog, 'epoch:', epoch, '    time:', int(time() - tic))
            # print(mylog, 'train_loss:', train_epoch_loss)
            # print(mylog, 'SHAPE:', Constants.Image_size)
            # print('********')
            # print('epoch:', epoch, '    time:', int(time() - tic))
            # print('train_loss:', train_epoch_loss)
            # print('SHAPE:', Constants.Image_size)

            # if train_epoch_loss >= train_epoch_best_loss:
            #     no_optim += 1
            # else:
            #     no_optim = 0
            #     train_epoch_best_loss = train_epoch_loss
            #     solver.save('./weights/' + NAME + '.th')
            # if no_optim > Constants.NUM_EARLY_STOP:
            #     print(mylog, 'early stop at %d epoch' % epoch)
            #     print('early stop at %d epoch' % epoch)
            #     break
            # if no_optim > Constants.NUM_UPDATE_LR:
            #     if solver.old_lr < 5e-7:
            #         break
            #     solver.load('./weights/' + NAME + '.th')
            #     solver.update_lr(2.0, factor=True, mylog=mylog)
            # mylog.flush()

    print('Finish!', file=mylog)
    print('Finish!')
    mylog.close()