Ejemplo n.º 1
0
def train(train_i=0):
    """Train FPN_Net on one fold of the 5-fold ultrasound split.

    Args:
        train_i: 0-based fold index; the ``N5fold<train_i+1>_{train,test}.csv``
            split files define the fold.

    Side effects: writes per-epoch metrics to ``logs/<NAME>.log`` and saves
    the weights with the best test dice score to ``./weights/<NAME>.th``.
    """
    NAME = 'D2F5_fold' + str(train_i + 1) + '_FPN.th'
    print(NAME)

    batchsize = 4

    txt_train = 'N5fold' + str(train_i + 1) + '_train.csv'
    txt_test = 'N5fold' + str(train_i + 1) + '_test.csv'
    dataset_train = MyDataset(root='/home/wangke/ultrasound_data2/', txt_path=txt_train,
                              transform=transforms.ToTensor(),
                              target_transform=transforms.ToTensor())
    dataset_test = MyDataset(root='/home/wangke/ultrasound_data2/', txt_path=txt_test,
                             transform=transforms.ToTensor(),
                             target_transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=batchsize,
                                               shuffle=True, num_workers=2)
    # drop_last keeps every test batch full-size; the metric sums below are
    # divided by the number of test batches.
    test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=batchsize,
                                              shuffle=False, num_workers=2,
                                              drop_last=True)

    solver = MyFrame(FPN_Net, dice_bce_loss, 2e-4)

    total_epoch = 100
    no_optim = 0
    train_epoch_best_loss = 10000
    best_test_score = 0
    # BUGFIX: the log file was opened without a context manager and stayed
    # open if training raised; `with` guarantees it is closed.
    with open('logs/' + NAME + '.log', 'w') as mylog:
        for epoch in range(1, total_epoch + 1):
            data_loader_iter = iter(train_loader)
            data_loader_test = iter(test_loader)
            train_epoch_loss = 0
            index = 0

            tic = time()

            # ---- training pass ----
            train_score = 0
            for img, mask in data_loader_iter:
                solver.set_input(img, mask)
                train_loss, pred = solver.optimize()
                train_score += dice_coeff(mask, pred.cpu().data, False)
                train_epoch_loss += train_loss
                index += 1

            # ---- evaluation pass: accumulate per-batch metrics ----
            test_sen = 0
            test_ppv = 0
            test_score = 0
            test_acc = 0
            test_spe = 0
            test_f1s = 0
            for img, mask in data_loader_test:
                solver.set_input(img, mask)
                pre_mask, _ = solver.test_batch()
                test_score += dice_coeff(mask, pre_mask, False)
                test_sen += sensitive(mask, pre_mask)
                test_ppv += precision(mask, pre_mask)
                test_acc += accuracy(mask, pre_mask)
                test_spe += specificity(mask, pre_mask)
                test_f1s += f1_score(mask, pre_mask)

            # Average the sums over the number of test batches.
            n_test = len(data_loader_test)
            test_sen /= n_test
            test_ppv /= n_test
            test_score /= n_test
            test_acc /= n_test
            test_spe /= n_test
            test_f1s /= n_test

            if test_score > best_test_score:
                print('1. the dice score up to ', test_score, 'from ', best_test_score, 'saving the model', file=mylog, flush=True)
                print('1. the dice score up to ', test_score, 'from ', best_test_score, 'saving the model')
                best_test_score = test_score
                solver.save('./weights/' + NAME + '.th')

            train_epoch_loss = train_epoch_loss / len(data_loader_iter)
            train_score = train_score / len(data_loader_iter)
            print('epoch:', epoch, '    time:', int(time() - tic), 'train_loss:', train_epoch_loss.cpu().data.numpy(), 'train_score:', train_score, file=mylog, flush=True)
            print('test_dice_loss: ', test_score, 'test_sen: ', test_sen.numpy(), 'test_ppv: ', test_ppv.numpy(), 'test_acc: ', test_acc.numpy(), 'test_spe: ', test_spe.numpy(), 'test_f1s: ', test_f1s.numpy(), 'best_score is ', best_test_score, file=mylog, flush=True)

            print('********')
            print('epoch:', epoch, '    time:', int(time() - tic), 'train_loss:', train_epoch_loss.cpu().data.numpy(), 'train_score:', train_score)
            print('test_dice_loss: ', test_score, 'test_sen: ', test_sen.numpy(), 'test_ppv: ', test_ppv.numpy(), 'test_acc: ', test_acc.numpy(), 'test_spe: ', test_spe.numpy(), 'test_f1s: ', test_f1s.numpy(), 'best_score is ', best_test_score)

            # Track plateau length (kept for parity with the sibling training
            # scripts; no early stop is wired up in this variant).
            if train_epoch_loss >= train_epoch_best_loss:
                no_optim += 1
            else:
                no_optim = 0
                train_epoch_best_loss = train_epoch_loss

        print('Finish!', file=mylog, flush=True)
    print('Finish!')
# NOTE(review): top-level fragment — `dataset`, `batchsize`, `solver`, `mylog`,
# `SHAPE`, `total_epoch`, `train_epoch_best_loss`, `tqdm` and `time` are
# defined elsewhere, and the trailing `if` is truncated (no `else` visible).
data_loader = torch.utils.data.DataLoader(dataset,
                                          batch_size=batchsize,
                                          shuffle=True,
                                          num_workers=4)

tic = time()
for epoch in range(1, total_epoch + 1):
    data_loader_iter = iter(data_loader)
    train_epoch_loss = 0
    print('---------- Epoch:' + str(epoch) + ' ----------')
    print('Train:')
    for img, mask in tqdm(data_loader_iter,
                          ncols=20,
                          total=len(data_loader_iter)):
        solver.set_input(img, mask)
        train_loss = solver.optimize()
        train_epoch_loss += train_loss
    # Mean loss over the epoch's batches.
    train_epoch_loss /= len(data_loader_iter)

    mylog.write('********' + '\n')
    mylog.write('epoch:' + str(epoch) + '    time:' + str(int(time() - tic)) +
                '\n')
    mylog.write('train_loss:' + str(train_epoch_loss) + '\n')
    mylog.write('SHAPE:' + str(SHAPE) + '\n')
    print('********')
    print('epoch:', epoch, '    time:', int(time() - tic))
    print('train_loss:', train_epoch_loss)
    print('SHAPE:', SHAPE)

    # Count epochs without loss improvement (continuation cut off here).
    if train_epoch_loss >= train_epoch_best_loss:
        no_optim += 1
Ejemplo n.º 3
0
def vessel_main():
    """Train UNet on the DRIVE vessel dataset with early stopping and LR decay.

    Visualizes the last batch of each epoch on visdom, logs to
    ``logs/<NAME>.log``, and saves improving weights to ``./weights/<NAME>.th``.
    """
    SHAPE = (448, 448)
    # ROOT = 'dataset/RIM-ONE/'
    ROOT = './dataset/DRIVE'
    NAME = 'log01_dink34-UNet' + ROOT.split('/')[-1]
    BATCHSIZE_PER_CARD = 8

    viz = Visualizer(env="Vessel_Unet_from_scratch")

    solver = MyFrame(UNet, dice_bce_loss, 2e-4)
    # Scale the batch across all visible GPUs.
    batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

    dataset = ImageFolder(root_path=ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()
    no_optim = 0
    total_epoch = 300
    train_epoch_best_loss = 10000.
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0

        index = 0

        for img, mask in data_loader_iter:
            solver.set_input(img, mask)

            train_loss, pred = solver.optimize()

            train_epoch_loss += train_loss

            index = index + 1

        # Show the last batch of the epoch (images / labels / prediction).
        show_image = (img + 1.6) / 3.2 * 255.
        viz.img(name='images', img_=show_image[0, :, :, :])
        viz.img(name='labels', img_=mask[0, :, :, :])
        viz.img(name='prediction', img_=pred[0, :, :, :])

        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        # BUGFIX: `print(mylog, ...)` (Python-2 style) printed the file
        # object's repr to stdout and wrote nothing to the log; the log
        # lines must be routed with `file=mylog`.
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', SHAPE, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', SHAPE)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save('./weights/' + NAME + '.th')
        # Plateau > 20 epochs: stop. Plateau > 15: reload best weights and
        # decay the learning rate, unless it is already below 5e-7.
        if no_optim > 20:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        if no_optim > 15:
            if solver.old_lr < 5e-7:
                break
            solver.load('./weights/' + NAME + '.th')
            solver.update_lr(2.0, factor=True, mylog=mylog)
        mylog.flush()

    print('Finish!', file=mylog)
    print('Finish!')
    mylog.close()
Ejemplo n.º 4
0
def train_operation(train_paras):
    """Train DUNet driven by a parameter dictionary.

    Expected keys in ``train_paras``: ``image_dir``, ``gt_dir``, ``train_id``,
    ``logfile_dir``, ``model_dir``, ``model_name``, ``learning_rate``,
    ``total_epoch``.

    Runs a short W pre-computation phase, then the main loop with early
    stopping and a mixed step/poly learning-rate schedule. Logs to
    ``<logfile_dir><model_name>.log`` and checkpoints improving weights to
    ``<model_dir><model_name>.th``.
    """
    sat_dir = train_paras["image_dir"]
    lab_dir = train_paras["gt_dir"]
    train_id = train_paras["train_id"]
    logfile_dir = train_paras["logfile_dir"]
    model_dir = train_paras["model_dir"]
    model_name = train_paras["model_name"]
    learning_rate = train_paras["learning_rate"]

    imagelist = os.listdir(sat_dir)

    # Strip an 8-character filename suffix to get the sample ids
    # (presumably something like '_sat.jpg' — TODO confirm against the data).
    trainlist = list(map(lambda x: x[:-8], imagelist))
    # trainlist = trainlist[:1000]
    BATCHSIZE_PER_CARD = 2
    solver = MyFrame(DUNet, learning_rate, model_name)
    # solver = MyFrame(Unet, dice_bce_loss, 2e-4)
    batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

    dataset = ImageFolder(trainlist, sat_dir, lab_dir)
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=0)

    mylog = open(logfile_dir + model_name + '.log', 'w')
    print("**************" + model_name + "******************", file=mylog)
    print("**************" + model_name + "******************")
    print("current train id:{}".format(train_id), file=mylog)
    print("current train id:{}".format(train_id))
    print("batch size:{}".format(batchsize), file=mylog)
    print("total images: {}".format(len(trainlist)))
    print("total images: {}".format(len(trainlist)), file=mylog)

    tic = time()
    no_optim = 0
    total_epoch = train_paras["total_epoch"]
    train_epoch_best_loss = 100.

    # solver.load('weights/dlinknet_new_lr_decoder.th')
    # print('* load existing model *')

    epoch_iter = 0
    print("learning rate is {}".format(learning_rate), file=mylog)
    print("Precompute weight for 5 epoches", file=mylog)
    print("Precompute weight for 5 epoches")
    save_tensorboard_iter = 5
    pre_compute_flag = 1
    # solver.load(model_dir + model_name + '.th')
    # pretrain W
    # Warm-up phase: feed batches through pre_compute_W only (no optimizer
    # step). NOTE(review): the `epoch < 5` guard makes the 5th iteration
    # log-only — looks intentional, but confirm against MyFrame.pre_compute_W.
    for epoch in range(1, 6):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        if epoch < 5:
            no_optim = 0
            t = 0
            for img, mask in data_loader_iter:
                t += 1
                solver.set_input(img, mask)
                solver.pre_compute_W(t)
        print('********', file=mylog)
        print('pre-train W::',
              epoch,
              '    time:',
              int(time() - tic),
              file=mylog)
        print('********')
        print('pre-train W:', epoch, '    time:', int(time() - tic))

    print("pretrain is OVER")
    print("pretrain is OVER", file=mylog)

    # Main optimization loop.
    step_update = False
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        for img, mask in data_loader_iter:
            imgs = solver.set_input(img, mask)
            train_loss = solver.optimize(pre_compute_flag)
            # Only the very first optimize() call sees the flag set.
            pre_compute_flag = 0
            train_epoch_loss += train_loss
        train_epoch_loss /= len(data_loader_iter)
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', SHAPE, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', SHAPE)
        # Periodic tensorboard snapshot (epochs 1, 6, 11, ...).
        if epoch % save_tensorboard_iter == 1:
            solver.update_tensorboard(epoch)
        # imgs=imgs.to(torch.device("cpu"))
        # solver.writer.add_graph(solver.model,imgs)
        print("train best loss is {}".format(train_epoch_best_loss))
        print("train best loss is {}".format(train_epoch_best_loss),
              file=mylog)
        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save(model_dir + model_name + '.th')
        # Plateau schedule: >6 epochs -> early stop; >3 -> switch to step
        # decay; ==3 -> reload best weights, then step or poly decay
        # depending on the current mode.
        if no_optim > 6:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        elif no_optim > 3:
            step_update = True
            solver.update_lr(5.0, factor=True, mylog=mylog)
            print("update lr by ratio 0.5")
        elif no_optim > 2:
            if solver.old_lr < 5e-7:
                break
            solver.load(model_dir + model_name + '.th')
            # solver.update_lr(5.0, factor=True, mylog=mylog)
            if step_update:
                solver.update_lr(5.0, factor=True, mylog=mylog)
                step_update = False
            else:
                solver.update_lr_poly(epoch, total_epoch, mylog,
                                      total_epoch / 40)
        # Poly decay every epoch while not in step mode.
        if not step_update:
            solver.update_lr_poly(epoch, total_epoch, mylog, total_epoch / 40)
        mylog.flush()

    solver.close_tensorboard()
    print('*********************Finish!***********************', file=mylog)
    print('Finish!')
    mylog.close()
# NOTE(review): top-level fragment of a distributed training loop — `solver`,
# `data_loader`, `mylog`, `config`, `args`, `total_epoch`, `SummaryWriter` and
# `time` come from elsewhere; the `if epoch%10 == 0:` body is truncated below.
tic = time()

writer = SummaryWriter('runs/'+ config.EXPNAME)

# TODO: only rank 0 needs to store the model and outputs


for epoch in range(1, total_epoch + 1):
    print('---epoch start-----')
    #data_loader_iter = iter(data_loader)
    train_epoch_loss = 0
    for img, mask in data_loader:

        solver.set_input(img, mask)
        train_loss = solver.optimize(config)
        train_epoch_loss += train_loss
    train_epoch_loss /= len(data_loader)
    # Only rank 0 writes the log file and tensorboard scalars.
    if args.local_rank == 0:
        print('********',file=mylog)
        print('epoch:'+ str(epoch+config.TRAIN.RESUME_START) + '    time:'+ str(time()-tic), file=mylog)
        print('train_loss: {}'.format(train_epoch_loss), file=mylog)
        writer.add_scalar('scalar/train',train_epoch_loss,epoch+config.TRAIN.RESUME_START)
        print('********')
        print('epoch:'+str(epoch+config.TRAIN.RESUME_START)+'    time:'+ str(time()-tic))
        print('train_loss: {}'.format(train_epoch_loss))
        split = False
        # Periodic evaluation setup every 10 epochs (body truncated here).
        if epoch%10 == 0:

            BATCHSIZE_PER_CARD = config.TEST.BATCH_SIZE_PER_GPU
            label_list = config.TEST.LABEL_LIST
Ejemplo n.º 6
0

# Train the network.
# Visualization: tensorboard --logdir=weights --host=127.0.0.1
# NOTE(review): top-level fragment — `solver`, `train_load`, `tqdm` and
# `time` come from elsewhere; the inner loop body is truncated below.
tic = time()
no_optim = 0
total_epoch = 20#number of training epochs
train_epoch_best_loss = 100.
for epoch in tqdm.tqdm(range(1, total_epoch + 1)):
    train_epoch_loss = 0
    for i,(img,mask) in enumerate(train_load):
        # Global step counter across epochs.
        allstep=epoch*len(train_load)+i+1
        solver.set_input(img, mask)
        # One optimization step; returns the loss and the network output.
        train_loss,netout = solver.optimize()

        # # visualize training images
        # img_x = vutils.make_grid(img,nrow=4,normalize=True)
        # write.add_image('train_images',img_x,allstep)

        # # visualize labels
        # mask_pic=IamColor(mask)
        # mask_pic = torch.from_numpy(mask_pic)
        # mask_pic = vutils.make_grid(mask_pic,nrow=4,normalize=True)
        # write.add_image('label_images',mask_pic,allstep)

        # # visualize network output
        # pre = torch.argmax(netout.cpu(),1)
        # img_out = np.zeros(pre.shape + (3,))
        # for ii in range(num_class):
Ejemplo n.º 7
0
def CE_Net_Train(train_i=0):
    """Train CE-Net on one cross-validation fold.

    Args:
        train_i: 0-based fold index; ``fold<train_i+1>_{train,test}.csv``
            define the split.

    Side effects: logs to ``logs/<NAME>.log`` and saves the weights with the
    best test dice score to ``./weights/<NAME>.th``.
    """
    # BUGFIX: the original referenced an undefined `i` here; the fold index
    # parameter is `train_i`.
    NAME = 'fold' + str(train_i + 1) + '_6CE-Net' + Constants.ROOT.split('/')[-1]

    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD  #4

    # For different 2D medical image segmentation tasks, please specify the dataset which you use
    # for examples: you could specify "dataset = 'DRIVE' " for retinal vessel detection.

    txt_train = 'fold' + str(train_i + 1) + '_train.csv'
    txt_test = 'fold' + str(train_i + 1) + '_test.csv'
    dataset_train = MyDataset(txt_path=txt_train,
                              transform=transforms.ToTensor(),
                              target_transform=transforms.ToTensor())
    dataset_test = MyDataset(txt_path=txt_test,
                             transform=transforms.ToTensor(),
                             target_transform=transforms.ToTensor())
    # BUGFIX: the loaders were built from an undefined `dataset` and used the
    # invalid keyword `batchsize=`; DataLoader's keyword is `batch_size`, and
    # the train/test loaders must wrap their respective datasets.
    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=batchsize,
                                               shuffle=True,
                                               num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=batchsize,
                                              shuffle=False,
                                              num_workers=2)

    # start the logging files
    mylog = open('logs/' + NAME + '.log', 'w')

    no_optim = 0
    total_epoch = Constants.TOTAL_EPOCH  # 300
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS  # 10000
    best_test_score = 0
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(train_loader)
        data_loader_test = iter(test_loader)
        train_epoch_loss = 0
        index = 0

        tic = time()

        # train
        for img, mask in data_loader_iter:
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_epoch_loss += train_loss
            index = index + 1

        # test: accumulate per-batch dice / sensitivity / PPV
        test_sen = 0
        test_ppv = 0
        test_score = 0
        for img, mask in data_loader_test:
            solver.set_input(img, mask)
            pre_mask, _ = solver.test_batch()
            # BUGFIX: metrics were computed against an undefined `y_test`;
            # the ground truth for the current batch is `mask`.
            test_score += dice_coeff(mask, pre_mask, False)
            test_sen += sensitive(mask, pre_mask)
            test_ppv += positivepv(mask, pre_mask)
        # Normalize to per-batch means so the logged "dice score" and the
        # best-score comparison use averages, consistent with the sibling
        # fold-training function in this file.
        test_sen /= len(data_loader_test)
        test_ppv /= len(data_loader_test)
        test_score /= len(data_loader_test)
        print(test_sen, test_ppv, test_score)

        if test_score > best_test_score:
            print('1. the dice score up to ', test_score, 'from ',
                  best_test_score, 'saving the model')
            best_test_score = test_score
            solver.save('./weights/' + NAME + '.th')

        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        print('epoch:',
              epoch,
              '    time:',
              int(time() - tic),
              'train_loss:',
              train_epoch_loss.cpu().data.numpy(),
              file=mylog,
              flush=True)
        print('test_dice_loss: ',
              test_score,
              'test_sen: ',
              test_sen,
              'test_ppv: ',
              test_ppv,
              'best_score is ',
              best_test_score,
              file=mylog,
              flush=True)

        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic), 'train_loss:',
              train_epoch_loss.cpu().data.numpy())
        print('test_dice_score: ', test_score, 'test_sen: ', test_sen,
              'test_ppv: ', test_ppv, 'best_score is ', best_test_score)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
        # On a long plateau: stop if LR already tiny; otherwise reload the
        # best weights and decay the LR while it is still above 5e-4.
        if no_optim > Constants.NUM_UPDATE_LR:
            if solver.old_lr < 5e-7:
                break
            if solver.old_lr > 5e-4:
                solver.load('./weights/' + NAME + '.th')
                solver.update_lr(1.5, factor=True, mylog=mylog)

    print('Finish!', file=mylog, flush=True)
    print('Finish!')
    mylog.close()
Ejemplo n.º 8
0
def CE_Net_Train():
    """Train CE-Net on the DRIVE dataset with visdom monitoring, early
    stopping, and LR decay on plateau.

    Writes per-epoch stats to ``logs/<NAME>.log`` and saves improving
    weights to ``./weights/<NAME>.th``.
    """
    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]

    # run the Visdom
    viz = Visualizer(env=NAME)

    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    # Scale the batch across all visible GPUs.
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD

    # For different 2D medical image segmentation tasks, please specify the dataset which you use
    # for examples: you could specify "dataset = 'DRIVE' " for retinal vessel detection.

    dataset = ImageFolder(root_path=Constants.ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

    # start the logging files
    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()

    no_optim = 0
    total_epoch = Constants.TOTAL_EPOCH
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        index = 0

        for img, mask in data_loader_iter:
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_epoch_loss += train_loss
            index = index + 1

        # show the original images, predication and ground truth on the visdom.
        show_image = (img + 1.6) / 3.2 * 255.
        viz.img(name='images', img_=show_image[0, :, :, :])
        viz.img(name='labels', img_=mask[0, :, :, :])
        viz.img(name='prediction', img_=pred[0, :, :, :])

        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        # BUGFIX: `print(mylog, ...)` (Python-2 style) printed the file
        # object's repr to stdout and wrote nothing to the log; route the
        # log lines with `file=mylog`.
        print('********', file=mylog)
        print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', Constants.Image_size, file=mylog)
        print('********')
        print('epoch:', epoch, '    time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', Constants.Image_size)

        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save('./weights/' + NAME + '.th')
        if no_optim > Constants.NUM_EARLY_STOP:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        if no_optim > Constants.NUM_UPDATE_LR:
            if solver.old_lr < 5e-7:
                break
            solver.load('./weights/' + NAME + '.th')
            solver.update_lr(2.0, factor=True, mylog=mylog)
        mylog.flush()

    print('Finish!', file=mylog)
    print('Finish!')
    mylog.close()
Ejemplo n.º 9
0
# NOTE(review): top-level fragment — `generate_loader`, `Configure`,
# `BiGruCNN`, `CrossEntropyLoss`, `MyFrame`, `device`, `path`, `tqdm`,
# `logging`, `total_epochs` and `valid_best_score` come from elsewhere;
# the loop's tail (e.g. the no-improvement branch) is truncated.
train_loader, valid_loader, testa_loader = generate_loader(train_bs=32)
opt = Configure()
net = BiGruCNN
loss_func = CrossEntropyLoss(size_average=True)
solver = MyFrame(net=net, loss=loss_func, opt=opt, lr=1e-3, device=device)
solver.load(path)
# solver.net.embedding.weight.requires_grad = True

no_optim_round = 0
for epoch in range(total_epochs):
    # train
    solver.train_mode()
    train_loss = 0.
    for X, y in tqdm(train_loader):
        solver.set_input(X, y)
        step_loss = solver.optimize()
        train_loss += step_loss
    train_epoch_loss = train_loss / len(train_loader)
    logging.info('epoch  %d train_loss:%.5f' % (epoch + 1, train_epoch_loss))
    # eval
    valid_score = solver.eval(valid_loader)

    # save the checkpoint whenever the validation score improves
    if valid_score > valid_best_score:
        logging.info(
            'epoch %d valid_score improve %.5f >>>>>>>>> %.5f save model at : %s'
            % (epoch + 1, valid_best_score, valid_score, path))
        solver.save(path)
        valid_best_score = valid_score
        no_optim_round = 0
Ejemplo n.º 10
0
def train(train_i=0):
    """Train DeepLabV3+ on cross-validation fold ``train_i + 1``.

    Args:
        train_i: 0-based fold index; ``fold<train_i+1>_{train,test}.csv``
            define the split.

    Side effects: logs metrics to ``logs/<NAME>.log`` and saves the weights
    with the best test dice score to ``./weights/<NAME>.th``.
    """
    NAME = 'fold' + str(train_i + 1) + '_deeplabv3_plus.th'
    # solver = MyFrame(FSP_Net, dice_bce_loss, 2e-4)
    solver = MyFrame(DeepLabV3Plus, dice_bce_loss, 5e-4)

    batchsize = 4

    txt_train = 'fold' + str(train_i + 1) + '_train.csv'
    txt_test = 'fold' + str(train_i + 1) + '_test.csv'
    dataset_train = MyDataset(txt_path=txt_train,
                              transform=transforms.ToTensor(),
                              target_transform=transforms.ToTensor())
    dataset_test = MyDataset(txt_path=txt_test,
                             transform=transforms.ToTensor(),
                             target_transform=transforms.ToTensor())
    train_loader = torch.utils.data.DataLoader(dataset_train,
                                               batch_size=batchsize,
                                               shuffle=True,
                                               num_workers=2)
    test_loader = torch.utils.data.DataLoader(dataset_test,
                                              batch_size=batchsize,
                                              shuffle=False,
                                              num_workers=2)

    mylog = open('logs/' + NAME + '.log', 'w')

    total_epoch = 100
    no_optim = 0
    train_epoch_best_loss = 10000
    best_test_score = 0
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(train_loader)
        data_loader_test = iter(test_loader)
        train_epoch_loss = 0
        index = 0

        tic = time()

        # ---- training pass ----
        train_score = 0
        for img, mask in data_loader_iter:
            solver.set_input(img, mask)
            train_loss, pred = solver.optimize()
            train_score += dice_coeff(mask, pred, False)
            train_epoch_loss += train_loss
            index += 1

        # ---- evaluation pass: per-batch metrics averaged below ----
        test_sen = 0
        test_ppv = 0
        test_score = 0
        for img, mask in data_loader_test:
            solver.set_input(img, mask)
            pre_mask, _ = solver.test_batch()
            test_score += dice_coeff(mask, pre_mask, False)
            test_sen += sensitive(mask, pre_mask)
            test_ppv += positivepv(mask, pre_mask)
        test_sen /= len(data_loader_test)
        test_ppv /= len(data_loader_test)
        test_score /= len(data_loader_test)

        if test_score > best_test_score:
            print('1. the dice score up to ',
                  test_score,
                  'from ',
                  best_test_score,
                  'saving the model',
                  file=mylog,
                  flush=True)
            print('1. the dice score up to ', test_score, 'from ',
                  best_test_score, 'saving the model')
            best_test_score = test_score
            solver.save('./weights/' + NAME + '.th')

        train_epoch_loss = train_epoch_loss / len(data_loader_iter)
        train_score = train_score / len(data_loader_iter)
        print('epoch:',
              epoch,
              '    time:',
              int(time() - tic),
              'train_loss:',
              train_epoch_loss.cpu().data.numpy(),
              'train_score:',
              train_score,
              file=mylog,
              flush=True)
        print('test_dice_loss: ',
              test_score,
              'test_sen: ',
              test_sen,
              'test_ppv: ',
              test_ppv,
              'best_score is ',
              best_test_score,
              file=mylog,
              flush=True)

        print('epoch:', epoch, '    time:', int(time() - tic), 'train_loss:',
              train_epoch_loss.cpu().data.numpy(), 'train_score:', train_score)
        # CONSISTENCY FIX: this metrics line was also sent to `mylog`,
        # duplicating the log entry and leaving the console without it;
        # print it to stdout like the epoch line above and like the sibling
        # training functions in this file.
        print('test_dice_loss: ', test_score, 'test_sen: ', test_sen,
              'test_ppv: ', test_ppv, 'best_score is ', best_test_score)

        # Track plateau length (no early stop wired up in this variant).
        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
    print('Finish!', file=mylog, flush=True)
    print('Finish!')
    mylog.close()
def _save_image_triplet(img, mask, pred, out_dir, epoch):
    """Dump the first sample of each batch (input / ground truth / prediction)
    as '<out_dir>/<tag>_<epoch>.jpg' for visual inspection.

    Args:
        img, mask, pred: 4-D batch tensors (N, C, H, W) — only index 0 is saved.
        out_dir: directory prefix, without trailing slash (must already exist).
        epoch: epoch number embedded in the file names.
    """
    for tag, batch in (("image", img), ("mask", mask), ("pred", pred)):
        torchvision.utils.save_image(
            batch[0, :, :, :],
            out_dir + "/" + tag + "_" + str(epoch) + ".jpg",
            nrow=1,
            padding=2,
            normalize=True,
            range=None,  # legacy torchvision kwarg; removed in newer releases
            scale_each=False,
            pad_value=0)


def CE_Net_Train():
    """Train CE-Net on the dataset rooted at Constants.ROOT.

    Each epoch: one optimization pass over the training loader, Visdom
    visualization of the last batch, image dumps under images/, loss logging
    to 'logs/<NAME>.log', checkpointing to weights/ when the epoch loss
    improves, LR decay after Constants.NUM_UPDATE_LR stagnant epochs, early
    stop after Constants.NUM_EARLY_STOP, and a validation pass whose images
    are dumped under test_data/results4/.
    """
    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]

    # run the Visdom
    viz = Visualizer(env=NAME)

    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    print("count", Constants.BATCHSIZE_PER_CARD)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD
    print("batchsize", batchsize)

    # For different 2D medical image segmentation tasks, please specify the
    # dataset which you use, e.g. datasets='DRIVE' for retinal vessel detection.
    dataset = ImageFolder(root_path=Constants.ROOT, datasets='Cell')
    data_loader = torch.utils.data.DataLoader(dataset,
                                              batch_size=batchsize,
                                              shuffle=True,
                                              num_workers=4)

    dataset_val = ImageFolder(root_path='./test_data/DRIVE_dot_dash_training',
                              datasets='Cell')
    data_loader_val = torch.utils.data.DataLoader(dataset_val,
                                                  batch_size=8,
                                                  shuffle=True,
                                                  num_workers=4)

    # start the logging file; closed in the finally block even on error
    mylog = open('logs/' + NAME + '.log', 'w')
    tic = time()

    no_optim = 0
    total_epoch = Constants.TOTAL_EPOCH
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS
    try:
        for epoch in range(1, total_epoch + 1):
            train_epoch_loss = 0
            index = 0

            for img, mask in iter(data_loader):
                solver.set_input(img, mask)
                train_loss, pred = solver.optimize()
                train_epoch_loss += train_loss
                index = index + 1

            # show the original images, prediction and ground truth on Visdom
            # (img/mask/pred are the last batch of the epoch)
            show_image = (img + 1.6) / 3.2 * 255.
            viz.img(name='images', img_=show_image[0, :, :, :])
            viz.img(name='labels', img_=mask[0, :, :, :])
            viz.img(name='prediction', img_=pred[0, :, :, :])

            _save_image_triplet(img, mask, pred, "images", epoch)

            print("saving images")
            print("Train_loss_for_all ", train_epoch_loss)
            print("length of (data_loader_iter) ", len(data_loader))
            train_epoch_loss = train_epoch_loss / len(data_loader)
            # BUG FIX: these were print(mylog, ...), which printed the file
            # object's repr to stdout instead of writing to the log file.
            print('********', file=mylog)
            print('epoch:', epoch, '    time:', int(time() - tic), file=mylog)
            print('train_loss:', train_epoch_loss, file=mylog)
            print('SHAPE:', Constants.Image_size, file=mylog)
            print('********')
            print('epoch:', epoch, '    time:', int(time() - tic))
            print('train_loss:', train_epoch_loss)
            print('SHAPE:', Constants.Image_size)

            if train_epoch_loss >= train_epoch_best_loss:
                no_optim += 1
            else:
                no_optim = 0
                train_epoch_best_loss = train_epoch_loss
                print("Saving the Weights")
                solver.save('./weights/' + NAME + '.th')
                if epoch % 100 == 0:
                    # keep a periodic snapshot in addition to the best weights
                    solver.save('./weights/' + NAME + str(epoch) + '.th')
            if no_optim > Constants.NUM_EARLY_STOP:
                print('early stop at %d epoch' % epoch, file=mylog)
                print('early stop at %d epoch' % epoch)
                break
            if no_optim > Constants.NUM_UPDATE_LR:
                if solver.old_lr < 5e-7:
                    break
                # reload the best checkpoint before decaying the learning rate
                solver.load('./weights/' + NAME + '.th')
                solver.update_lr(2.0, factor=True, mylog=mylog)
            mylog.flush()

            # validation pass, run every epoch (the original guarded this with
            # the vacuous `if epoch % 1 == 0`); optimize_test takes no
            # gradient step
            print('in VALIDATION')
            val_epoch_loss = 0
            for img, mask in iter(data_loader_val):
                solver.set_input(img, mask)
                train_loss, pred = solver.optimize_test()
                val_epoch_loss += train_loss
            print("Train_loss_for_all ", val_epoch_loss)
            print("length of (data_loader_iter_val) ",
                  len(data_loader_val))
            print(val_epoch_loss / len(data_loader_val))
            print('++++++++++++++++++++++++++++++++++')

            _save_image_triplet(img, mask, pred, "test_data/results4", epoch)

        print('Finish!', file=mylog)
        print('Finish!')
    finally:
        mylog.close()