mylog.write('********' + '\n') mylog.write('epoch:' + str(epoch) + ' time:' + str(int(time() - tic)) + '\n') mylog.write('train_loss:' + str(train_epoch_loss) + '\n') mylog.write('SHAPE:' + str(SHAPE) + '\n') print('********') print('epoch:', epoch, ' time:', int(time() - tic)) print('train_loss:', train_epoch_loss) print('SHAPE:', SHAPE) if train_epoch_loss >= train_epoch_best_loss: no_optim += 1 else: no_optim = 0 train_epoch_best_loss = train_epoch_loss solver.save('weights/' + NAME + '.th') if no_optim > 6: print(mylog, 'early stop at %d epoch' % epoch) print('early stop at %d epoch' % epoch) break if no_optim > 3: if solver.old_lr < 5e-7: break solver.load('weights/' + NAME + '.th') solver.update_lr(5.0, factor=True, mylog=mylog) mylog.flush() mylog.write('Finish!') print('Finish!') mylog.close()
def train_operation(train_paras):
    """Train a DUNet model described by the ``train_paras`` dict.

    Five warm-up epochs call ``solver.pre_compute_W`` on every batch; then
    ``train_paras["total_epoch"]`` epochs of ``solver.optimize`` follow. The
    mean training loss is tracked and the weights are saved to
    ``model_dir + model_name + '.th'`` whenever it improves. ``no_optim``
    counts consecutive non-improving epochs: above 6 training stops; above 3
    (but not above 6) ``solver.update_lr(5.0, factor=True, ...)`` is applied
    and "step mode" (``step_update``) is entered; at exactly 3 the saved
    weights are reloaded (or training stops if ``solver.old_lr`` < 5e-7).
    Outside step mode ``solver.update_lr_poly`` is called each epoch.
    Progress goes to stdout and to ``logfile_dir + model_name + '.log'``.

    Args:
        train_paras: dict providing "image_dir", "gt_dir", "train_id",
            "logfile_dir", "model_dir", "model_name", "learning_rate" and
            "total_epoch".

    NOTE(review): the source's line breaks were lost; the nesting below is
    reconstructed from the token order -- confirm it, especially for the
    learning-rate branches at the end of the epoch loop.
    NOTE(review): ``SHAPE`` is not defined in this function and must be a
    module-level name -- verify it exists, otherwise the first epoch raises
    NameError.
    """
    sat_dir = train_paras["image_dir"]
    lab_dir = train_paras["gt_dir"]
    train_id = train_paras["train_id"]
    logfile_dir = train_paras["logfile_dir"]
    model_dir = train_paras["model_dir"]
    model_name = train_paras["model_name"]
    learning_rate = train_paras["learning_rate"]
    imagelist = os.listdir(sat_dir)
    # Sample ids = image filenames minus their last 8 characters (presumably a
    # fixed suffix such as "_sat.jpg" -- confirm against ImageFolder).
    trainlist = list(map(lambda x: x[:-8], imagelist))
    # trainlist = trainlist[:1000]
    BATCHSIZE_PER_CARD = 2
    solver = MyFrame(DUNet, learning_rate, model_name)
    # solver = MyFrame(Unet, dice_bce_loss, 2e-4)
    # Scales with the number of visible GPUs (0 on a CPU-only machine).
    batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD
    dataset = ImageFolder(trainlist, sat_dir, lab_dir)
    data_loader = torch.utils.data.DataLoader(dataset, batch_size=batchsize, shuffle=True, num_workers=0)
    # NOTE(review): opened without `with`; the log stays open if training raises.
    mylog = open(logfile_dir + model_name + '.log', 'w')
    print("**************" + model_name + "******************", file=mylog)
    print("**************" + model_name + "******************")
    print("current train id:{}".format(train_id), file=mylog)
    print("current train id:{}".format(train_id))
    print("batch size:{}".format(batchsize), file=mylog)
    print("total images: {}".format(len(trainlist)))
    print("total images: {}".format(len(trainlist)), file=mylog)
    tic = time()
    no_optim = 0  # consecutive epochs without a new best training loss
    total_epoch = train_paras["total_epoch"]
    train_epoch_best_loss = 100.
    # solver.load('weights/dlinknet_new_lr_decoder.th')
    # print('* load existing model *')
    epoch_iter = 0  # unused
    print("learning rate is {}".format(learning_rate), file=mylog)
    print("Precompute weight for 5 epoches", file=mylog)
    print("Precompute weight for 5 epoches")
    save_tensorboard_iter = 5
    # Handed to the first optimize() call only, then cleared to 0.
    pre_compute_flag = 1
    # solver.load(model_dir + model_name + '.th')
    # pretrain W: five passes of solver.pre_compute_W, no loss tracking.
    for epoch in range(1, 6):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0  # unused during pre-training
        if epoch < 5:
            no_optim = 0
        t = 0  # 1-based batch counter handed to pre_compute_W
        for img, mask in data_loader_iter:
            t += 1
            solver.set_input(img, mask)
            solver.pre_compute_W(t)
        print('********', file=mylog)
        print('pre-train W::', epoch, ' time:', int(time() - tic), file=mylog)
        print('********')
        print('pre-train W:', epoch, ' time:', int(time() - tic))
    print("pretrain is OVER")
    print("pretrain is OVER", file=mylog)
    # True after a plateau-triggered update_lr; suppresses the poly schedule.
    step_update = False
    for epoch in range(1, total_epoch + 1):
        data_loader_iter = iter(data_loader)
        train_epoch_loss = 0
        for img, mask in data_loader_iter:
            imgs = solver.set_input(img, mask)  # `imgs` only used by commented-out add_graph code
            train_loss = solver.optimize(pre_compute_flag)
            pre_compute_flag = 0
            train_epoch_loss += train_loss
        train_epoch_loss /= len(data_loader_iter)
        print('********', file=mylog)
        print('epoch:', epoch, ' time:', int(time() - tic), file=mylog)
        print('train_loss:', train_epoch_loss, file=mylog)
        print('SHAPE:', SHAPE, file=mylog)
        print('********')
        print('epoch:', epoch, ' time:', int(time() - tic))
        print('train_loss:', train_epoch_loss)
        print('SHAPE:', SHAPE)
        # Fires on epochs 1, 6, 11, ... (epoch % 5 == 1).
        if epoch % save_tensorboard_iter == 1:
            solver.update_tensorboard(epoch)
            # imgs=imgs.to(torch.device("cpu"))
            # solver.writer.add_graph(solver.model,imgs)
        print("train best loss is {}".format(train_epoch_best_loss))
        print("train best loss is {}".format(train_epoch_best_loss), file=mylog)
        if train_epoch_loss >= train_epoch_best_loss:
            no_optim += 1
        else:
            no_optim = 0
            train_epoch_best_loss = train_epoch_loss
            solver.save(model_dir + model_name + '.th')
        if no_optim > 6:
            print('early stop at %d epoch' % epoch, file=mylog)
            print('early stop at %d epoch' % epoch)
            break
        elif no_optim > 3:
            # Plateau of 4-6 epochs: step-update the lr and enter step mode.
            step_update = True
            solver.update_lr(5.0, factor=True, mylog=mylog)
            # NOTE(review): message says ratio 0.5 but the factor passed is
            # 5.0 -- confirm what update_lr does with `factor`.
            print("update lr by ratio 0.5")
        elif no_optim > 2:
            # Exactly 3 non-improving epochs: reload the best checkpoint.
            if solver.old_lr < 5e-7:
                break
            solver.load(model_dir + model_name + '.th')
            # solver.update_lr(5.0, factor=True, mylog=mylog)
            if step_update:
                solver.update_lr(5.0, factor=True, mylog=mylog)
                step_update = False
            else:
                solver.update_lr_poly(epoch, total_epoch, mylog, total_epoch / 40)
        if not step_update:
            solver.update_lr_poly(epoch, total_epoch, mylog, total_epoch / 40)
        mylog.flush()
    solver.close_tensorboard()
    print('*********************Finish!***********************', file=mylog)
    print('Finish!')
    mylog.close()
print('train_loss:', train_epoch_loss, file=mylog) print('SHAPE:', SHAPE, file=mylog) print('********') print('epoch:', epoch, ' time:', int(time() - tic)) print('train_loss:', train_epoch_loss) print('SHAPE:', SHAPE) if (epoch % 20 == 0 and epoch != 0): solver.save('weights/' + NAME + '/' + NAME + str(train_epoch_loss) + '.th') if train_epoch_loss >= train_epoch_best_loss: no_optim += 1 else: no_optim = 0 train_epoch_best_loss = train_epoch_loss solver.save('weights/' + NAME + '.th') if no_optim > 20: print('early stop at %d epoch' % epoch, file=mylog) print('early stop at %d epoch' % epoch) break if no_optim > 10: if solver.old_lr < 5e-7: break solver.load('weights/' + NAME + '.th') solver.update_lr(0.8, factor=True, mylog=mylog) mylog.flush() print('Finish!', file=mylog) print('Finish!') mylog.close()
def vessel_main():
    """Train a U-Net from scratch on the DRIVE retinal-vessel dataset.

    Each epoch makes one pass over ``ImageFolder(root_path='./dataset/DRIVE',
    datasets='DRIVE')``. It shows the last batch's image, label and prediction
    on Visdom and logs the mean training loss to ``logs/<NAME>.log`` (and
    stdout). Weights are saved to ``./weights/<NAME>.th`` whenever the loss
    improves.

    After more than 15 consecutive non-improving epochs, the saved weights are
    reloaded and ``solver.update_lr(2.0, factor=True, ...)`` is called, once
    per such epoch. After more than 20, or once ``solver.old_lr`` falls below
    5e-7, training stops.
    """
    SHAPE = (448, 448)
    ROOT = './dataset/DRIVE'
    NAME = 'log01_dink34-UNet' + ROOT.split('/')[-1]
    BATCHSIZE_PER_CARD = 8

    viz = Visualizer(env="Vessel_Unet_from_scratch")
    solver = MyFrame(UNet, dice_bce_loss, 2e-4)
    # Scales with the number of visible GPUs (0 on a CPU-only machine).
    batchsize = torch.cuda.device_count() * BATCHSIZE_PER_CARD

    dataset = ImageFolder(root_path=ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batchsize, shuffle=True, num_workers=4)

    # `with` guarantees the log is closed even if training raises.
    with open('logs/' + NAME + '.log', 'w') as mylog:
        def echo(*fields):
            """Print *fields* both to the training log and to stdout."""
            print(*fields, file=mylog)
            print(*fields)

        tic = time()
        no_optim = 0  # consecutive epochs without a new best training loss
        total_epoch = 300
        train_epoch_best_loss = 10000.
        for epoch in range(1, total_epoch + 1):
            data_loader_iter = iter(data_loader)
            train_epoch_loss = 0
            for img, mask in data_loader_iter:
                solver.set_input(img, mask)
                train_loss, pred = solver.optimize()
                train_epoch_loss += train_loss

            # Visualise the last batch of the epoch (the loop variables still
            # hold it here).
            show_image = (img + 1.6) / 3.2 * 255.
            viz.img(name='images', img_=show_image[0, :, :, :])
            viz.img(name='labels', img_=mask[0, :, :, :])
            viz.img(name='prediction', img_=pred[0, :, :, :])

            train_epoch_loss = train_epoch_loss / len(data_loader_iter)
            # The original ``print(mylog, ...)`` (a Python 2 ``print >>f``
            # leftover) sent the file object's repr to stdout and never wrote
            # these lines to the log; echo() writes them to both.
            elapsed = int(time() - tic)
            echo('********')
            echo('epoch:', epoch, ' time:', elapsed)
            echo('train_loss:', train_epoch_loss)
            echo('SHAPE:', SHAPE)

            if train_epoch_loss >= train_epoch_best_loss:
                no_optim += 1
            else:
                no_optim = 0
                train_epoch_best_loss = train_epoch_loss
                solver.save('./weights/' + NAME + '.th')
            if no_optim > 20:
                echo('early stop at %d epoch' % epoch)
                break
            if no_optim > 15:
                if solver.old_lr < 5e-7:
                    break
                # Restart from the best checkpoint with an updated lr.
                solver.load('./weights/' + NAME + '.th')
                solver.update_lr(2.0, factor=True, mylog=mylog)
            mylog.flush()

        echo('Finish!')
# img_out = vutils.make_grid(pre,nrow=4,normalize=True)#必须是tensor # write.add_image('predict_out',img_out,allstep)#必须是三个通道的 #可视化损失函数输出 train_epoch_loss += train_loss#所有的loss和 write.add_scalar('train_loss',train_loss,allstep) # #可视化网络参数直方图感觉影响速度 # for name,param in solver.net.named_parameters(): # write.add_histogram(name,param.data.cpu().numpy(),allstep) train_epoch_loss /= len(train_load)#平均loss print('********') print('epoch:',epoch,'time:',int(time()-tic)/60) print('train_loss:',train_epoch_loss) if train_epoch_loss >= train_epoch_best_loss: no_optim += 1 else: no_optim = 0 train_epoch_best_loss = train_epoch_loss #保留结果 solver.save(modefiles) if no_optim > 6: print('early stop at %d epoch' % epoch) break if no_optim > 3: if solver.old_lr < 5e-7: break solver.load('weights/'+NAME+'.th') solver.update_lr(5.0, factor = True)
def CE_Net_Train(train_i=0):
    """Train CE-Net on cross-validation fold ``train_i`` and evaluate each epoch.

    The fold is read from ``fold<k>_train.csv`` / ``fold<k>_test.csv`` with
    ``k = train_i + 1``. After every training epoch, the test fold is scored
    with ``dice_coeff``, ``sensitive`` and ``positivepv``, each averaged over
    test batches. Whenever the mean dice improves, the weights are saved to
    ``./weights/<NAME>.th``.

    The training loss only drives learning-rate updates; early stopping is
    disabled in this variant. Per-epoch results are written to
    ``logs/<NAME>.log`` and stdout.

    Args:
        train_i: zero-based fold index.
    """
    # The fold index is ``train_i``; the original referenced an undefined ``i``.
    NAME = 'fold' + str(train_i + 1) + '_6CE-Net' + Constants.ROOT.split('/')[-1]
    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD  # 4

    # For other 2D medical segmentation tasks, point the fold CSVs at the
    # dataset in use (e.g. DRIVE for retinal vessel detection).
    txt_train = 'fold' + str(train_i + 1) + '_train.csv'
    txt_test = 'fold' + str(train_i + 1) + '_test.csv'
    dataset_train = MyDataset(txt_path=txt_train, transform=transforms.ToTensor(),
                              target_transform=transforms.ToTensor())
    dataset_test = MyDataset(txt_path=txt_test, transform=transforms.ToTensor(),
                             target_transform=transforms.ToTensor())
    # Each loader wraps its own dataset (the original passed an undefined
    # ``dataset``), and DataLoader's keyword is ``batch_size``: ``batchsize=``
    # raises TypeError.
    train_loader = torch.utils.data.DataLoader(
        dataset_train, batch_size=batchsize, shuffle=True, num_workers=2)
    test_loader = torch.utils.data.DataLoader(
        dataset_test, batch_size=batchsize, shuffle=False, num_workers=2)

    no_optim = 0  # consecutive epochs without a new best training loss
    total_epoch = Constants.TOTAL_EPOCH  # 300
    train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS  # 10000
    best_test_score = 0

    # `with` guarantees the log is closed even if training raises.
    with open('logs/' + NAME + '.log', 'w') as mylog:
        for epoch in range(1, total_epoch + 1):
            tic = time()

            # ---- train ----
            data_loader_iter = iter(train_loader)
            train_epoch_loss = 0
            for img, mask in data_loader_iter:
                solver.set_input(img, mask)
                train_loss, _ = solver.optimize()
                train_epoch_loss += train_loss
            train_epoch_loss = train_epoch_loss / max(len(data_loader_iter), 1)

            # ---- test ----
            data_loader_test = iter(test_loader)
            test_sen = 0
            test_ppv = 0
            test_score = 0
            for img, mask in data_loader_test:
                solver.set_input(img, mask)
                pre_mask, _ = solver.test_batch()
                # Score against this batch's own ground truth (the original
                # referenced an undefined ``y_test``).
                test_score += dice_coeff(mask, pre_mask, False)
                test_sen += sensitive(mask, pre_mask)
                test_ppv += positivepv(mask, pre_mask)
            # Average per batch so the logged values are metrics, not sums;
            # max() guards an empty test split.
            n_test = max(len(data_loader_test), 1)
            test_score = test_score / n_test
            test_sen = test_sen / n_test
            test_ppv = test_ppv / n_test
            print(test_sen, test_ppv, test_score)

            if test_score > best_test_score:
                print('1. the dice score up to ', test_score, 'from ', best_test_score, 'saving the model')
                best_test_score = test_score
                solver.save('./weights/' + NAME + '.th')

            # optimize() may return either a tensor or a plain number.
            if torch.is_tensor(train_epoch_loss):
                loss_value = train_epoch_loss.detach().cpu().numpy()
            else:
                loss_value = train_epoch_loss
            elapsed = int(time() - tic)
            print('epoch:', epoch, ' time:', elapsed, 'train_loss:', loss_value,
                  file=mylog, flush=True)
            print('test_dice_score: ', test_score, 'test_sen: ', test_sen, 'test_ppv: ', test_ppv,
                  'best_score is ', best_test_score, file=mylog, flush=True)
            print('********')
            print('epoch:', epoch, ' time:', elapsed, 'train_loss:', loss_value)
            print('test_dice_score: ', test_score, 'test_sen: ', test_sen, 'test_ppv: ', test_ppv,
                  'best_score is ', best_test_score)

            if train_epoch_loss >= train_epoch_best_loss:
                no_optim += 1
            else:
                no_optim = 0
                train_epoch_best_loss = train_epoch_loss
            # Early stopping on Constants.NUM_EARLY_STOP is deliberately
            # disabled here; only the learning rate reacts to plateaus.
            if no_optim > Constants.NUM_UPDATE_LR:
                if solver.old_lr < 5e-7:
                    break
                # NOTE(review): original indentation was lost; here only the
                # checkpoint reload is gated on the lr and update_lr always runs.
                if solver.old_lr > 5e-4:
                    solver.load('./weights/' + NAME + '.th')
                solver.update_lr(1.5, factor=True, mylog=mylog)

        print('Finish!', file=mylog, flush=True)
    print('Finish!')
def CE_Net_Train():
    """Train CE-Net on the DRIVE data found under ``Constants.ROOT``.

    Each epoch makes one pass over the training loader. It shows the last
    batch's image, label and prediction on Visdom and logs the mean training
    loss to ``logs/<NAME>.log`` (and stdout). Weights are saved to
    ``./weights/<NAME>.th`` whenever the loss improves.

    Once the count of consecutive non-improving epochs exceeds
    ``Constants.NUM_UPDATE_LR``, the saved weights are reloaded and
    ``solver.update_lr(2.0, factor=True, ...)`` is called. Once it exceeds
    ``Constants.NUM_EARLY_STOP``, or ``solver.old_lr`` falls below 5e-7,
    training stops.
    """
    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]
    # Visdom environment that displays training samples.
    viz = Visualizer(env=NAME)
    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD

    # For other 2D medical segmentation tasks change ``datasets`` accordingly,
    # e.g. ``datasets='DRIVE'`` for retinal vessel detection.
    dataset = ImageFolder(root_path=Constants.ROOT, datasets='DRIVE')
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batchsize, shuffle=True, num_workers=4)

    # `with` guarantees the log is closed even if training raises.
    with open('logs/' + NAME + '.log', 'w') as mylog:
        def echo(*fields):
            """Print *fields* both to the training log and to stdout."""
            print(*fields, file=mylog)
            print(*fields)

        tic = time()
        no_optim = 0  # consecutive epochs without a new best training loss
        total_epoch = Constants.TOTAL_EPOCH
        train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS
        for epoch in range(1, total_epoch + 1):
            data_loader_iter = iter(data_loader)
            train_epoch_loss = 0
            for img, mask in data_loader_iter:
                solver.set_input(img, mask)
                train_loss, pred = solver.optimize()
                train_epoch_loss += train_loss

            # Show the last batch's image, ground truth and prediction on Visdom.
            show_image = (img + 1.6) / 3.2 * 255.
            viz.img(name='images', img_=show_image[0, :, :, :])
            viz.img(name='labels', img_=mask[0, :, :, :])
            viz.img(name='prediction', img_=pred[0, :, :, :])

            train_epoch_loss = train_epoch_loss / len(data_loader_iter)
            # The original ``print(mylog, ...)`` printed the file object's repr
            # to stdout and never wrote these lines to the log.
            elapsed = int(time() - tic)
            echo('********')
            echo('epoch:', epoch, ' time:', elapsed)
            echo('train_loss:', train_epoch_loss)
            echo('SHAPE:', Constants.Image_size)

            if train_epoch_loss >= train_epoch_best_loss:
                no_optim += 1
            else:
                no_optim = 0
                train_epoch_best_loss = train_epoch_loss
                solver.save('./weights/' + NAME + '.th')
            if no_optim > Constants.NUM_EARLY_STOP:
                echo('early stop at %d epoch' % epoch)
                break
            if no_optim > Constants.NUM_UPDATE_LR:
                if solver.old_lr < 5e-7:
                    break
                # Restart from the best checkpoint with an updated lr.
                solver.load('./weights/' + NAME + '.th')
                solver.update_lr(2.0, factor=True, mylog=mylog)
            mylog.flush()

        echo('Finish!')
def CE_Net_Train():
    """Train CE-Net on the 'Cell' dataset and run a validation pass each epoch.

    Training data is ``ImageFolder(root_path=Constants.ROOT, datasets='Cell')``.
    After every epoch, the last training batch is pushed to Visdom and written
    to ``images/``. The mean training loss drives:

    - checkpointing: ``./weights/<NAME>.th`` on improvement, plus
      ``./weights/<NAME><epoch>.th`` every 100 epochs;
    - best-weight reloading with ``solver.update_lr``;
    - early stopping.

    A validation pass over ``./test_data/DRIVE_dot_dash_training`` then prints
    the mean ``solver.optimize_test()`` loss and saves the last validation
    batch to ``test_data/results4/``. Log lines go to ``logs/<NAME>.log``.
    """
    import os  # local import: only needed to create the image output folders

    NAME = 'CE-Net' + Constants.ROOT.split('/')[-1]
    # Visdom environment that displays training samples.
    viz = Visualizer(env=NAME)
    solver = MyFrame(CE_Net_, dice_bce_loss, 2e-4)
    print("count", Constants.BATCHSIZE_PER_CARD)
    batchsize = torch.cuda.device_count() * Constants.BATCHSIZE_PER_CARD
    print("batchsize", batchsize)

    # For other 2D medical segmentation tasks change ``datasets`` accordingly,
    # e.g. ``datasets='DRIVE'`` for retinal vessel detection.
    dataset = ImageFolder(root_path=Constants.ROOT, datasets='Cell')
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batchsize, shuffle=True, num_workers=4)
    dataset_val = ImageFolder(root_path='./test_data/DRIVE_dot_dash_training', datasets='Cell')
    data_loader_val = torch.utils.data.DataLoader(
        dataset_val, batch_size=8, shuffle=True, num_workers=4)

    # Saving an image into a folder that does not exist fails, so create them.
    os.makedirs('images', exist_ok=True)
    os.makedirs('test_data/results4', exist_ok=True)

    def save_first_sample(batch, path):
        """Save the first sample of *batch* to *path* as a normalized image."""
        # ``range=None`` was dropped: it equals the default, and recent
        # torchvision releases renamed that keyword to ``value_range``.
        torchvision.utils.save_image(batch[0, :, :, :], path, nrow=1, padding=2,
                                     normalize=True, scale_each=False, pad_value=0)

    # `with` guarantees the log is closed even if training raises.
    with open('logs/' + NAME + '.log', 'w') as mylog:
        def echo(*fields):
            """Print *fields* both to the training log and to stdout."""
            print(*fields, file=mylog)
            print(*fields)

        tic = time()
        no_optim = 0  # consecutive epochs without a new best training loss
        total_epoch = Constants.TOTAL_EPOCH
        train_epoch_best_loss = Constants.INITAL_EPOCH_LOSS
        for epoch in range(1, total_epoch + 1):
            # ---- training pass ----
            data_loader_iter = iter(data_loader)
            train_epoch_loss = 0
            for img, mask in data_loader_iter:
                solver.set_input(img, mask)
                train_loss, pred = solver.optimize()
                train_epoch_loss += train_loss

            # Show and save the last training batch of the epoch.
            show_image = (img + 1.6) / 3.2 * 255.
            viz.img(name='images', img_=show_image[0, :, :, :])
            viz.img(name='labels', img_=mask[0, :, :, :])
            viz.img(name='prediction', img_=pred[0, :, :, :])
            save_first_sample(img, "images/image_" + str(epoch) + ".jpg")
            save_first_sample(mask, "images/mask_" + str(epoch) + ".jpg")
            save_first_sample(pred, "images/pred_" + str(epoch) + ".jpg")
            print("saving images")

            print("Train_loss_for_all ", train_epoch_loss)
            print("length of (data_loader_iter) ", len(data_loader_iter))
            train_epoch_loss = train_epoch_loss / len(data_loader_iter)
            # The original ``print(mylog, ...)`` printed the file object's repr
            # to stdout and never wrote these lines to the log.
            elapsed = int(time() - tic)
            echo('********')
            echo('epoch:', epoch, ' time:', elapsed)
            echo('train_loss:', train_epoch_loss)
            echo('SHAPE:', Constants.Image_size)

            if train_epoch_loss >= train_epoch_best_loss:
                no_optim += 1
            else:
                no_optim = 0
                train_epoch_best_loss = train_epoch_loss
                print("Saving the Weights")
                solver.save('./weights/' + NAME + '.th')
            if epoch % 100 == 0:
                solver.save('./weights/' + NAME + str(epoch) + '.th')
            if no_optim > Constants.NUM_EARLY_STOP:
                echo('early stop at %d epoch' % epoch)
                break
            if no_optim > Constants.NUM_UPDATE_LR:
                if solver.old_lr < 5e-7:
                    break
                # Restart from the best checkpoint with an updated lr.
                solver.load('./weights/' + NAME + '.th')
                solver.update_lr(2.0, factor=True, mylog=mylog)
            mylog.flush()

            # ---- validation pass (runs every epoch) ----
            # Separate names keep the training loss/batch from being clobbered.
            print('in VALIDATION')
            data_loader_iter_val = iter(data_loader_val)
            val_epoch_loss = 0
            for val_img, val_mask in data_loader_iter_val:
                solver.set_input(val_img, val_mask)
                val_loss, val_pred = solver.optimize_test()
                val_epoch_loss += val_loss
            print("Train_loss_for_all ", val_epoch_loss)
            print("length of (data_loader_iter_val) ", len(data_loader_iter_val))
            print(val_epoch_loss / len(data_loader_iter_val))
            print('++++++++++++++++++++++++++++++++++')
            # Save the last validation batch of the epoch.
            save_first_sample(val_img, "test_data/results4/image_" + str(epoch) + ".jpg")
            save_first_sample(val_mask, "test_data/results4/mask_" + str(epoch) + ".jpg")
            save_first_sample(val_pred, "test_data/results4/pred_" + str(epoch) + ".jpg")

        echo('Finish!')