def __init__(self, args, writer, device):
    """
    Set up the meta learner: hyper-parameters, the UNet, losses and optimizer.

    :param args: option namespace. Read here: update_lr, meta_lr, n_way,
        k_spt, k_qry, task_num, update_step, update_step_test, imgc,
        output_channel, load, optimizer and (for "rmsprop") weight_decay.
    :param writer: summary writer, stored on the instance unchanged.
    :param device: device used to map a loaded checkpoint and to hold the
        parameter snapshot.
    :raises ValueError: if args.optimizer is neither "adam" nor "rmsprop".
    """
    super(Meta, self).__init__()
    self.device = device
    self.writer = writer

    # Meta-learning hyper-parameters, copied verbatim from the options.
    self.update_lr = args.update_lr
    self.meta_lr = args.meta_lr
    self.n_way = args.n_way
    self.k_spt = args.k_spt
    self.k_qry = args.k_qry
    self.task_num = args.task_num
    self.update_step = args.update_step
    self.update_step_test = args.update_step_test
    self.in_channels = args.imgc
    self.out_channels = args.output_channel
    self.epsilon = 1e-10

    self.net = UNet(in_channels=args.imgc,
                    out_channels=args.output_channel)
    logging.info(f'Network:\n'
                 f'\t{self.net.in_channels} input channels\n'
                 f'\t{self.net.out_channels} output channels (classes)')

    # The checkpoint must go into self.net; the bare name `net` is not
    # defined inside this method.
    if args.load:
        self.net.load_state_dict(
            torch.load(args.load, map_location=device))
        logging.info(f'Model loaded from {args.load}')

    # Snapshot the parameters only after an optional checkpoint load, so the
    # copy reflects the weights the network actually starts with.
    self.net_param = [
        param.detach().clone().to(device)
        for param in self.net.parameters()
    ]

    # define loss
    self.mse_loss_fn = torch.nn.MSELoss(reduction='none')
    self.ssim_loss_fc = pytorch_msssim.SSIM(window_size=7)
    self.contour_loss = Contour_loss(K=5)

    # define optimizer for meta learning
    if args.optimizer == "adam":
        self.meta_optim = optim.Adam(self.net.parameters(),
                                     lr=self.meta_lr)
    elif args.optimizer == "rmsprop":
        # weight_decay lives on the options object; it is never stored on
        # self, so self.weight_decay would raise AttributeError.
        self.meta_optim = optim.RMSprop(self.net.parameters(),
                                        lr=self.meta_lr,
                                        weight_decay=args.weight_decay)
    else:
        raise ValueError("Wrong Optimzer !")
def main():
    """
    Evaluate a UNet on the synthesis test split.

    Reads the module-level ``args`` (dir_checkpoint, dataset, batchsize,
    test_data, imgc, output_channel, load), builds the losses and the
    network, optionally restores a checkpoint, then calls ``test``.
    """
    # set logging and writer
    writer = SummaryWriter(log_dir=args.dir_checkpoint)
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    synthesis_test = Synthesis_Image(args.dataset,
                                     mode='test',
                                     batchsz=args.batchsize,
                                     test_data=args.test_data)

    # define loss
    mse_loss_fn = torch.nn.MSELoss(reduction='none')
    ssim_loss_fc = pytorch_msssim.SSIM(window_size=7)
    contour_loss = Contour_loss(K=5)

    # define net
    net = UNet(in_channels=args.imgc, out_channels=args.output_channel)
    net.to(device)
    if args.load:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine.
        net.load_state_dict(torch.load(args.load, map_location=device))
        print('Model loaded from {}'.format(args.load))

    try:
        test(net, mse_loss_fn, ssim_loss_fc, contour_loss, synthesis_test,
             device, writer)
    except KeyboardInterrupt:
        # Exit cleanly on Ctrl-C; fall back to a hard exit if SystemExit is
        # itself intercepted.
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
def __init__(self, args, conv=common.default_conv):
    """
    Assemble the EDSR head/body/tail stacks plus an auxiliary UNet.

    Reads n_resblocks, n_feats, scale[0], rgb_range, n_colors and res_scale
    from ``args``; ``conv`` is the convolution factory used for every layer.
    """
    super(EDSR, self).__init__()

    num_blocks = args.n_resblocks
    width = args.n_feats
    ksize = 3
    factor = args.scale[0]
    activation = nn.ReLU(True)

    # Pre-trained weight URL for this configuration, if one is registered.
    key = 'r{}f{}x{}'.format(num_blocks, width, factor)
    self.url = url[key] if key in url else None

    self.sub_mean = common.MeanShift(args.rgb_range)
    self.add_mean = common.MeanShift(args.rgb_range, sign=1)

    # head: a single convolution from image channels to feature width
    head_layers = [conv(args.n_colors, width, ksize)]

    # body: residual blocks (sharing one activation module) + closing conv
    body_layers = []
    for _ in range(num_blocks):
        body_layers.append(
            common.ResBlock(conv,
                            width,
                            ksize,
                            act=activation,
                            res_scale=args.res_scale))
    body_layers.append(conv(width, width, ksize))

    # tail: upsampling followed by a projection back to image channels
    tail_layers = [
        common.Upsampler(conv, factor, width, act=False),
        conv(width, args.n_colors, ksize),
    ]

    self.head = nn.Sequential(*head_layers)
    self.body = nn.Sequential(*body_layers)
    self.tail = nn.Sequential(*tail_layers)
    self.unet = UNet(args.n_colors, args.n_colors)
# Prepare the output directories for this task.
config.result_path = os.path.join(config.result_path, config.Task_name)
print(config.result_path)
config.model_path = os.path.join(config.result_path, 'models')
config.log_dir = os.path.join(config.result_path, 'logs')
# exist_ok=True creates each directory independently, so a rerun (or a
# result directory that exists without its sub-directories) no longer fails.
for directory in (config.result_path, config.model_path, config.log_dir):
    os.makedirs(directory, exist_ok=True)

# Record the training configuration.
# NOTE(review): the original layout was collapsed onto one line; this assumes
# config.txt is rewritten on every run rather than only when the result
# directory is first created - confirm.
with open(os.path.join(config.result_path, 'config.txt'), 'w') as f:
    for key in config.__dict__:
        print('%s: %s' % (key, getattr(config, key)), file=f)

# Make sure the training record file exists (append mode keeps old content).
config.record_file = os.path.join(config.result_path, 'record.txt')
with open(config.record_file, 'a'):
    pass

# Pick the device: CUDA when available, otherwise CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(device)

# Build the network: single-channel input, one output class.
train_net = UNet(n_channels=1, n_classes=1)
# Move the network to the selected device.
train_net.to(device)
train(train_net, device, config)
import glob
import numpy as np
import torch
import os
import cv2
from model.unet_model import UNet

if __name__ == '__main__':
    # Pick the device and restore the trained single-channel network.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    net = UNet(n_channels=1, n_classes=1)
    net.to(device=device)
    net.load_state_dict(torch.load('model.pth', map_location=device))
    net.eval()

    result_suffix = '_res.jpg'
    tests_path = glob.glob('D:/Research/Dataset/3DOH50K/testset/*.jpg')

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for test_path in tests_path:
            # Results are written next to the inputs and match '*.jpg' too;
            # skip them so a rerun does not segment its own outputs.
            if test_path.endswith(result_suffix):
                continue
            # splitext only strips the extension, so dots elsewhere in the
            # path no longer truncate the output name.
            save_res_path = os.path.splitext(test_path)[0] + result_suffix

            img = cv2.imread(test_path)
            if img is None:
                # cv2.imread returns None for unreadable files.
                print('Skipping unreadable image: {}'.format(test_path))
                continue
            img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
            # Shape (batch=1, channel=1, height, width).
            img = img.reshape(1, 1, img.shape[0], img.shape[1])
            img_tensor = torch.from_numpy(img)
            img_tensor = img_tensor.to(device=device, dtype=torch.float32)

            pred = net(img_tensor)
            pred = pred.cpu().numpy()[0, 0]
            # Binarise at 0.5 and store as an 8-bit mask.
            mask = np.where(pred >= 0.5, 255, 0).astype(np.uint8)
            cv2.imwrite(save_res_path, mask)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') criterion = nn.CrossEntropyLoss() patch_size = 64 batch_size = 8 num_class = 13 save_dir = "./results" # 加载数据集 data_dir = "/home/cym/Datasets/StData-12/F3_block/" dataset = F3DS(data_dir, ptsize=patch_size, train=False) test_loader = torch.utils.data.DataLoader(dataset=dataset, batch_size=batch_size, shuffle=False) hw = dataset.hw origin_mask = np.zeros((hw[0], hw[1])) net = UNet(n_channels=1, n_classes=num_class, bilinear=False) net.load_state_dict(torch.load('models2/best_model.pth')) net.to(device=device) net.eval() print("net prepare done") if not os.path.exists(save_dir): os.makedirs(f"{save_dir}/img") os.makedirs(f"{save_dir}/label") os.makedirs(f"{save_dir}/predlabel") all_images_num = 0. all_acc = 0. img_idx = 0 for batch_idx, (image, label, hys, wxs) in enumerate(test_loader):
# 加载测试集 # TestData_dataset = TestData_Loader('D:/Research/Dataset/3DOH50K/testset/') # tests_path = 'D:/Research/Dataset/3DOH50K/testset/masks/' TestData_dataset = TestData_Loader('D:/Research/Dataset/3DOH50K/notset/') tests_path = 'D:/Research/Dataset/3DOH50K/notset/masks_test/' print("数据个数: ", len(TestData_dataset)) test_loader = torch.utils.data.DataLoader(dataset=TestData_dataset, batch_size=batch_size, shuffle=False) out_dir = "model_BCE_bs16.pkl" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # load model model = UNet(n_channels=3, n_classes=1).to(device=device) model.load_state_dict(torch.load(out_dir)) model.eval() batchcnt = 0 with torch.no_grad(): for image in test_loader: image = image.to(device=device, dtype=torch.float32) pred = model(image) # 写入batch中的每个数据 cnt = 0 for mask in pred: # ii = (image[cnt] * 255.0).to(device=device, dtype=torch.uint8) # imShow = np.array(ii.data.cpu())
predShow = predShow.reshape(512, 512, 1) cv2.imshow("image", imShow) cv2.imshow("mask", maskShow) cv2.imshow("pred", predShow) cv2.waitKey() # 计算loss loss = criterion(pred, label) print('Loss/train: ', loss.item()) # 保存loss值最小的参数 if loss < best_loss: best_loss = loss torch.save(net.state_dict(), 'model.pth') # 更新参数 loss.backward() optimizer.step() if __name__ == '__main__': # 选择设备 device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 加载网络 net = UNet(n_channels=3, n_classes=1) # 网络拷贝到device中 net.to(device=device) #指定训练集 data_path = 'data/train/' # 开始训练 train_net(net, device, data_path)
shuffle=True) # TrainData_dataset = TrainData_Loader('D:/Research/Dataset/3DOH50K/notset/') # print("数据个数: ", len(TrainData_dataset)) # train_loader = torch.utils.data.DataLoader( # dataset=TrainData_dataset, # batch_size=3, # shuffle=False # ) print(len(train_loader)) out_dir = "model_Dice_ep40_bs16.pkl" device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # load model model = UNet(n_channels=3, n_classes=1).to(device=device) # 定义下降算法 optimizer = optim.Adam(model.parameters(), lr=3e-4) # optimizer = optim.RMSprop(model.parameters(), lr=1e-3, weight_decay=1e-8, momentum=0.9) # 定义loss # criterion = nn.L1Loss(reduction='mean') # criterion = nn.BCEWithLogitsLoss() criterion = DiceLoss() best_loss = float('inf') # 训练 epochs = 40 for epoch in range(epochs): model.train()
def predict(in_channel, model_path, data_path, light=False):
    """
    Run a trained UNet over the validation images and save binary masks.

    :param in_channel: number of input channels; 1 feeds a grayscale image,
        anything else feeds a half-resolution, normalised RGB image.
    :param model_path: path of the state dict to load.
    :param data_path: directory containing ``valid.pkl`` (an iterable of
        image paths relative to ``data_path``).
    :param light: build ``UNet_light`` instead of ``UNet`` when True.
    :raises FileNotFoundError: if an image listed in valid.pkl cannot be read.
    :raises IOError: if a result mask cannot be written.
    """
    # Pick the device: CUDA when available, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Build the network (one output class) and restore its weights.
    net_cls = UNet_light if light else UNet
    net = net_cls(n_channels=in_channel, n_classes=1)
    net.to(device=device)
    net.load_state_dict(torch.load(model_path, map_location=device))
    net.eval()

    # Collect the image paths.
    # NOTE(review): pickle.load can execute arbitrary code; only use a
    # valid.pkl that comes from a trusted source.
    with open(os.path.join(data_path, 'valid.pkl'), "rb") as f:
        valid = pickle.load(f)
    tests_path = [os.path.join(data_path, path) for path in valid]

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for test_path in tqdm(tests_path):
            # Output path: every "train" in the input path is replaced.
            save_res_path = test_path.replace("train", "valid_predict")

            img = cv2.imread(test_path)
            if img is None:
                # cv2.imread returns None instead of raising.
                raise FileNotFoundError(
                    "Could not read image: {}".format(test_path))
            img_shape = img.shape

            if in_channel == 1:
                # Grayscale -> tensor scaled to [0, 1], shape (1, H, W).
                img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
                img = transforms.ToTensor()(img)
            else:
                # RGB -> half resolution -> [0, 1] -> [-1, 1], shape (3, H/2, W/2).
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                PIL_image = Image.fromarray(img)
                transform = transforms.Compose([
                    transforms.Resize(
                        (img_shape[0] // 2, img_shape[1] // 2)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.5, 0.5, 0.5),
                                         std=(0.5, 0.5, 0.5)),
                ])
                img = transform(PIL_image)
            img = img.unsqueeze(0)  # add the batch dimension
            img = img.to(device=device, dtype=torch.float32)

            pred = net(img)
            pred = pred.cpu().numpy()[0, 0]
            # Binarise at 0.5 and store as an 8-bit mask.
            mask = np.where(pred >= 0.5, 255, 0).astype(np.uint8)

            # cv2.imwrite fails silently when the directory is missing.
            out_dir = os.path.dirname(save_res_path)
            if out_dir:
                os.makedirs(out_dir, exist_ok=True)
            if not cv2.imwrite(save_res_path, mask):
                raise IOError(
                    "Could not write result: {}".format(save_res_path))
import glob import numpy as np import torch import os import cv2 from model.unet_model import UNet if __name__ == "__main__": # 选择设备,有cuda用cuda,没有就用cpu device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 加载网络,图片单通道,分类为1。 net = UNet(n_channels=1, n_classes=1) # 将网络拷贝到deivce中 net.to(device=device) # 加载模型参数 net.load_state_dict( torch.load('/Users/manmi/Documents/GitHub/unet/best_model.pth', map_location=device)) # 测试模式 net.eval() # 读取所有图片路径 tests_path = glob.glob( '/Users/manmi/Documents/GitHub/unet/data/test/*.png') # 遍历素有图片 for test_path in tests_path: # 保存结果地址 save_res_path = test_path.split('.')[0] + '_res.png' # 读取图片 img = cv2.imread(test_path) # 转为灰度图 img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
def main(
    model_name,
    from_file=True,
    to_file=False,
    validation=True,
    test=["precision", "recall", "f1", "accuracy", "jaccard"],
    concat=False,
    plot=True,
):
    """
    Validate (find the best threshold) and/or test a trained model.

    Parameters
    ----------
    model_name : str
        Which model to work with. This is assumed to be both the name of the
        directory in which parameters are stored, and the name of the
        parameters file.
    from_file : bool
        Whether data should be loaded from a file; otherwise it is generated.
        The file should:
            - be in the same directory as the model parameters
            - be called "data.npz"
            - contain 4 arrays: "val_predictions", "val_labels",
              "test_predictions" and "test_labels".
        Labels are expected to be floats and will be thresholded.
        Predictions are expected to be raw (not probabilities).
    to_file : bool
        If data is generated (not loaded from a file), whether to save it to
        a file in the model directory, in the form described in from_file.
        Also controls whether the validation summary is written.
    validation : bool
        Whether to go through validation steps (to find the best threshold).
    test : list of str
        Which metrics to use for testing. The results are stored in a txt
        file in the model directory. If empty then testing is skipped.
        The default list is never mutated here.
    concat : bool
        During testing, whether to compute each metric once, on the
        concatenation of the whole test set.
    plot : bool
        Whether to show plots during run.

    Raises
    ------
    ValueError
        If testing is requested while validation is disabled: the test
        threshold is only produced by the validation step.
    """
    # Fail fast: without validation there is no threshold to test with
    # (previously this surfaced late as an unbound-variable error).
    if test and not validation:
        raise ValueError(
            "Testing needs the threshold found during validation; "
            "enable validation or pass an empty `test` list.")

    model_dir = os.path.join(dir_models, model_name)
    params_file = os.path.join(model_dir, model_name)

    new_section = "=" * 50

    print("Importing model parameters from {}".format(params_file))
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    model = model.to(device)
    model.load_state_dict(
        torch.load(params_file, map_location=torch.device("cpu")))

    # File where data are stored, or will be if they aren't already
    data_file = os.path.join(model_dir, "data.npz")

    print(new_section)
    if from_file:
        print("Loading data")
        # Read the arrays by name (not by position) and close the archive.
        with np.load(data_file) as arrays:
            val_predictions = arrays["val_predictions"]
            val_labels = arrays["val_labels"]
            test_predictions = arrays["test_predictions"]
            test_labels = arrays["test_labels"]
    else:
        print("Generating data")
        dir_data_validation = os.path.join(dir_data, "validation")
        dir_data_test = os.path.join(dir_data, "test")
        _, validation_dl, test_dl = load_data(
            dir_data_validation=dir_data_validation,
            dir_data_test=dir_data_test,
            prop_noPV_training=0,  # Has no impact
            min_rescale_images=0,  # Has no impact
            batch_size=100,  # All of them
        )

        model.eval()
        with torch.no_grad():
            # Get images and labels from both DataLoaders
            val_images, val_labels = next(iter(validation_dl))
            test_images, test_labels = next(iter(test_dl))
            val_images = val_images.to(device, dtype=torch.float32)
            test_images = test_images.to(device, dtype=torch.float32)

            # Make predictions (not probabilities at this stage)
            print("Running model on data")
            val_predictions = model(val_images)
            test_predictions = model(test_images)

        # Convert to numpy arrays for computing
        val_predictions = np.squeeze(val_predictions.cpu().numpy())
        val_labels = np.squeeze(val_labels.cpu().numpy())
        test_predictions = np.squeeze(test_predictions.cpu().numpy())
        test_labels = np.squeeze(test_labels.cpu().numpy())

        # Save to file as numpy arrays
        if to_file:
            print("Saving results to file")
            np.savez_compressed(
                data_file,
                val_predictions=val_predictions,
                val_labels=val_labels,
                test_predictions=test_predictions,
                test_labels=test_labels,
            )

    # Binarise the float labels
    threshold_true = 0.5
    val_labels = np.where(val_labels > threshold_true, 1, 0)
    test_labels = np.where(test_labels > threshold_true, 1, 0)

    if validation:
        n_thresholds = 101
        print(new_section)
        print("Validation starting")
        precision, recall, f1_scores, best_threshold = find_best_threshold(
            val_predictions, val_labels, n_thresholds, concat=concat,
            plot=plot)
        print(f"Found best threshold to be {best_threshold:.4f}")

        if to_file:
            precision_lower, precision_mid, precision_upper = summary_stats(
                precision)
            f1_lower, f1_mid, f1_upper = summary_stats(f1_scores)
            _, recall_mid, _ = summary_stats(recall)
            results_summary = np.c_[np.linspace(0, 1, n_thresholds),
                                    precision_lower, precision_mid,
                                    precision_upper, recall_mid, f1_lower,
                                    f1_mid, f1_upper]
            results_file = os.path.join(model_dir, "prec_rec_f1.txt")
            print("Saving results to {}".format(results_file))
            np.savetxt(
                results_file,
                results_summary,
                delimiter=" ",
                header=
                f"Threshold: {best_threshold:.3f}\nthresholds precision_lower precision_mid precision_upper recall_mid f1_lower f1_mid f1_upper"
            )

    if test:
        print(new_section)
        print("Testing starting with metrics:")
        print(", ".join(test))
        results = test_model(test_predictions, test_labels, best_threshold,
                             concat, *test)
        print(results)
        summary_type = "median"

        results_file = os.path.join(
            model_dir,
            "test_{}results.txt".format("concat_" if concat else ""))

        if concat:
            results_summary = np.transpose(results)
            print("Results:")
        else:
            results_summary = np.transpose(
                summary_stats(results, type=summary_type))
            print(f"Summary statistics are based on the {summary_type}")
            print("Results (lower, mid-point, upper):")

        print(f"\tBest threshold = {best_threshold:.4f}")
        for i, measure in enumerate(test):
            print("\t{}: {}".format(measure, results_summary[i, :]))

        print(new_section)
        print("Saving results to {}".format(results_file))
        np.savetxt(
            results_file,
            results_summary,
            fmt="%.4f",
            delimiter=" ",
            header=f"Threshold: {best_threshold:.3f}\n{' '.join(test)}",
        )
    print("\n")
def main():
    """
    Train a UNet on the synthesis dataset.

    Reads the module-level ``args`` (dir_checkpoint, meta_lr, batchsize,
    imgsz, imgc, output_channel, load, optimizer, weight_decay, step_size,
    step_adjust, dataset), builds the network, losses, optimizer, scheduler
    and datasets, then calls ``train``. On Ctrl-C the current weights are
    saved to ``<dir_checkpoint>INTERRUPTED.pth`` before exiting.

    :raises ValueError: if args.optimizer is neither "adam" nor "rmsprop".
    """
    # set logging and writer
    writer = SummaryWriter(
        log_dir=args.dir_checkpoint + "run",
        comment=
        f'Learning Rate_{args.meta_lr}_Batch size_{args.batchsize}_Image Scale_{args.imgsz}'
    )
    logging.basicConfig(level=logging.INFO,
                        format='%(levelname)s: %(message)s')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    logging.info(f'Using device {device}')

    # define net
    net = UNet(in_channels=args.imgc, out_channels=args.output_channel)
    net.to(device)
    if args.load:
        # map_location lets a checkpoint saved on a GPU be restored on a
        # CPU-only machine.
        net.load_state_dict(torch.load(args.load, map_location=device))
        print('Model loaded from {}'.format(args.load))

    # define loss
    mse_loss_fn = torch.nn.MSELoss(reduction='none')
    ssim_loss_fc = pytorch_msssim.SSIM(window_size=7)
    contour_loss = Contour_loss(K=5)

    # define optimizer for meta learning
    if args.optimizer == "adam":
        meta_optim = optim.Adam(net.parameters(), lr=args.meta_lr)
    elif args.optimizer == "rmsprop":
        meta_optim = optim.RMSprop(net.parameters(),
                                   lr=args.meta_lr,
                                   weight_decay=args.weight_decay)
    else:
        raise ValueError("Wrong Optimzer !")

    # define step scheduler
    scheduler = optim.lr_scheduler.StepLR(meta_optim,
                                          step_size=args.step_size,
                                          gamma=args.step_adjust)

    # initial value for the best loss tracked by train()
    best_model_loss = 10000000

    # training and validation data
    synthesis_train = Synthesis_Image(args.dataset,
                                      mode='train_normal',
                                      batchsz=args.batchsize)
    synthesis_val = Synthesis_Image(args.dataset,
                                    mode='val_normal',
                                    batchsz=args.batchsize)

    try:
        train(net, synthesis_train, synthesis_val, device, best_model_loss,
              writer, scheduler, mse_loss_fn, ssim_loss_fc, contour_loss,
              meta_optim)
    except KeyboardInterrupt:
        torch.save(net.state_dict(), args.dir_checkpoint + 'INTERRUPTED.pth')
        logging.info('Saved interrupt')
        # Exit cleanly; fall back to a hard exit if SystemExit is
        # itself intercepted.
        try:
            sys.exit(0)
        except SystemExit:
            os._exit(0)
"gtFine", categories, transform=transforms.Compose([ Resize(_IMAGE_SIZE_), Normalize(), ToTensor(), #TODO: Apply random color changes #TODO: Apply random spatial changes (rotation, flip etc) ])) trainloader = DataLoader(cityscapes_dataset, batch_size=8, shuffle=True, num_workers=4) model = UNet(n_classes=len(categories), in_channels=_NUM_CHANNELS_, writer=writer) if torch.cuda.device_count() >= 1: print("Training model on ", torch.cuda.device_count(), "GPUs!") model = nn.DataParallel(model) # softmax = nn.Softmax2d() # criterion = nn.BCELoss() # criterion = nn.CrossEntropyLoss() criterion = nn.BCEWithLogitsLoss() optimizer = optim.Adam(model.parameters(), lr=0.001, weight_decay=0.0001) # Network training epoch_data = {} float_type = torch.cuda.FloatTensor if torch.cuda.is_available( ) else torch.FloatTensor
dir_train = args.dir_train
dir_test = args.dir_test
bs = args.batchsize

# Data loader (the training loader is currently disabled)
# dataset_train = img_seg_ldr(data_dir=dir_train)
# train_loader = DataLoader(dataset_train, batch_size=bs, shuffle=True)
dataset_test = img_seg_ldr(data_dir=dir_test)
test_loader = DataLoader(dataset_test, batch_size=1, shuffle=True)

# Use CUDA when it is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Model: exactly one architecture is built. An unknown name is rejected
# here instead of leaving `net` undefined for the code that follows.
if model == "unet":
    net = UNet(n_channels=3, n_classes=4).to(device)
elif model == "unet3":
    net = UNet3(n_channels=3, n_classes=4).to(device)
elif model == "unet2":
    net = UNet2(n_channels=3, n_classes=4).to(device)
elif model == "resunet":
    net = Unet_Resnet(in_channels=3).to(device)
elif model == "unetpp":
    net = UNetpp(in_ch=3, out_ch=4).to(device)
elif model == "denseunet":
    net = FCDenseNet57(n_classes=4).to(device)
else:
    raise ValueError("Unknown model: {!r}".format(model))

# Loss function: summed cross entropy with per-class weights 1, 2, 2, 1.
class_weights = torch.tensor([1, 2, 2, 1],
                             dtype=torch.float32,
                             device=device)
criterion = nn.CrossEntropyLoss(weight=class_weights, reduction='sum')
def main(
    num_epochs: int = 80,
    learning_rate: float = 1e-3,
    optimizer_type: str = "ADAM",
    loss: str = "BCE",
    use_scheduler: bool = True,
    milestones_scheduler: list = [50],
    gamma_scheduler: float = 0.1,
    batch_size: int = 32,
    dir_data_training: str = "../data/train",
    dir_data_validation: str = "../data/validation",
    prop_noPV_training: float = 0.5,
    min_rescale_images: float = 0.6,
    file_losses: str = "losses.txt",
    saving_frequency: int = 2,
    weight_for_positive_class: float = 1.0,
    save_model_parameters: bool = False,
    load_model_parameters: bool = False,
    dir_for_model_parameters: str = "../saved_models",
    filename_model_parameters_to_load: str = None,
    filename_model_parameters_to_save: str = None,
):
    """
    Main training function with tunable parameters.

    Parameters
    ----------
    num_epochs : int, optional
        Number of epochs to train. The default is 80.
    learning_rate : float, optional
        Learning rate of the optimizer. The default is 1e-3.
    optimizer_type : str, optional
        Can be "ADAM" or "SGD". Only checked when num_epochs > 0.
        The default is "ADAM".
    loss : str, optional
        Can be "BCE" or "L1". The default is "BCE".
    use_scheduler : bool
        If True, use a MultiStepLR configured by the next two parameters.
        The default is True.
    milestones_scheduler : list
        List of epochs at which to adapt the learning rate. The default
        is [50]. The list is copied before use, never mutated.
    gamma_scheduler : float
        Value by which to multiply the learning rate at each milestone
        epoch. The default is 0.1.
    batch_size : int, optional
        Number of samples per batch in the DataLoaders. The default is 32.
    dir_data_training : str, optional
        Directory where the folders "images/", "labels/" and "noPV/" are
        for the training set.
    dir_data_validation : str, optional
        Directory where the folders "images/", "labels/" and "noPV/" are
        for the validation set.
    prop_noPV_training : float, optional
        Proportion of noPV images to add compared to the total amount of PV
        images in the train set. The default is 0.5.
    min_rescale_images : float, optional
        Minimum proportion of the image to keep for the RandomResizedCrop
        transform. The default is 0.6.
    file_losses : str, optional
        Name of the file where the train and validation losses are written
        during training. The default is "losses.txt".
    saving_frequency : int, optional
        Frequency (in number of epochs) at which to write the losses to the
        file. The default is 2.
    weight_for_positive_class : float, optional
        Weight for the positive class in the binary cross entropy loss.
        The default is 1.0.
    save_model_parameters : bool, optional
        If True saves the model at the end of training. The default is False.
    load_model_parameters : bool, optional
        If True loads saved parameters into the model before training.
        The default is False.
    dir_for_model_parameters : str, optional
        Directory where saved parameters are stored.
        The default is "../saved_models".
    filename_model_parameters_to_load : str, optional
        Filename of the parameters to load before training. Should be
        specified if load_model_parameters is True. The default is None.
    filename_model_parameters_to_save : str, optional
        Filename of the parameters to save after training. Should be
        specified if save_model_parameters is True. The default is None.

    Returns
    -------
    model : torch.nn.Module
        Model after training.
    avg_train_error : list of float
        Train errors or losses after each epoch (empty if num_epochs <= 0).
    avg_validation_error : list of float
        Validation errors or losses after each epoch (empty if
        num_epochs <= 0).

    Raises
    ------
    NotImplementedError
        If `loss` is unknown, or `optimizer_type` is unknown while
        num_epochs > 0.
    """
    # Validate the string options first, so a typo fails before any data is
    # loaded. (The optimizer message used to interpolate the not-yet-bound
    # `optimizer` variable instead of `optimizer_type`.)
    if loss not in ("BCE", "L1"):
        raise NotImplementedError(f"{loss} is not implemented.")
    if num_epochs > 0 and optimizer_type not in ("ADAM", "SGD"):
        raise NotImplementedError(f"{optimizer_type} is not implemented.")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print("GPU is {}available.".format(
        "" if torch.cuda.is_available() else "NOT "))

    # Instantiate the DataLoaders
    roof_dataloader_train, roof_dataloader_validation, roof_dataloader_test = load_data(
        prop_noPV_training,
        min_rescale_images,
        batch_size,
        dir_data_training,
        dir_data_validation,
    )

    if loss == "BCE":
        # Binary cross entropy weighted according to positive pixels.
        # pos_weight > 1 increases recall.
        # pos_weight < 1 increases precision.
        pos_weight = torch.tensor([weight_for_positive_class]).to(device)
        criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    else:  # "L1"
        criterion = torch.nn.L1Loss()

    model = UNet(n_channels=3, n_classes=1, bilinear=False)
    model = model.to(device)

    # If we're not starting from scratch
    if load_model_parameters:
        path_model_parameters_to_load = os.path.join(
            dir_for_model_parameters, filename_model_parameters_to_load)
        # map_location lets parameters saved on a GPU load on a CPU host.
        model.load_state_dict(
            torch.load(path_model_parameters_to_load, map_location=device))

    # Defined up front so the return value is valid when no epoch is run.
    avg_train_error, avg_validation_error = [], []

    # If we're training or retraining a model
    if num_epochs > 0:
        if optimizer_type == "ADAM":
            optimizer = torch.optim.Adam(model.parameters(),
                                         lr=learning_rate)
        else:  # "SGD"
            optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)

        scheduler = None
        if use_scheduler:
            scheduler = MultiStepLR(optimizer,
                                    milestones=list(milestones_scheduler),
                                    gamma=gamma_scheduler)

        avg_train_error, avg_validation_error = train(
            model,
            criterion,
            roof_dataloader_train,
            roof_dataloader_validation,
            optimizer,
            use_scheduler,
            scheduler,
            num_epochs,
            device,
            file_losses,
            saving_frequency,
        )

    # NOTE(review): the original layout was collapsed onto one line; saving
    # and printing are assumed to sit outside the training branch - confirm.
    if save_model_parameters:
        path_model_parameters_to_save = os.path.join(
            dir_for_model_parameters, filename_model_parameters_to_save)
        torch.save(model.state_dict(), path_model_parameters_to_save)

    print(avg_train_error, avg_validation_error)

    return model, avg_train_error, avg_validation_error
if __name__ == "__main__":
    # Restrict the process to the first GPU.
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    # Pick the device: CUDA when available, otherwise CPU.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Training hyper-parameters.
    patch_size, batch_size = 44, 64
    num_class, epochs = 13, 250

    # Load the training dataset.
    data_dir = "/home/cym/Datasets/StData-12/F3_block/"
    dataset = F3DS(data_dir, ptsize=patch_size, train=True)
    train_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
    )

    # Build the network: single-channel input, 13 output classes.
    net = UNet(n_channels=1, n_classes=num_class, bilinear=False)
    net.apply(weight_init)
    # Move the network to the selected device.
    net.to(device=device)

    # Start training; checkpoints go to ./models2, summaries to ./logs2.
    writer = SummaryWriter('./logs2')
    train_net(net,
              train_loader,
              device,
              writer,
              epochs,
              save_dir="./models2")
def set_model(sample,
              device,
              args,
              train=True,
              experiment_name=None,
              finetune=False,
              model_path=None):
    """
    Create the model and criterion selected by ``args.model``.

    :param sample: example input; ``sample.shape[0]`` gives the channel count
        for the UNet variants and ``sample.shape[1]`` for 'convlstm'.
    :param device: device handed to the criterion and used to map a
        fine-tuning checkpoint.
    :param args: options; ``args.model`` is one of 'unet', 'attn_unet',
        'suc_unet' or 'convlstm', plus the hyper-parameters of that model.
    :param train: unused here; kept for interface compatibility.
    :param experiment_name: forwarded to the criterion.
    :param finetune: when True, load weights from ``model_path``
        (non-strict).
    :param model_path: checkpoint path used when ``finetune`` is True.
    :return: ``(model, criterion)``.
    :raises ValueError: if ``args.model`` is not a supported name (this used
        to surface as an unbound-variable error).
    """
    num_classes = 2

    if 'unet' in args.model:
        if args.model == 'unet':
            model_cls = UNet
        elif args.model == 'attn_unet':
            model_cls = AttentionUNet
        elif args.model == 'suc_unet':
            model_cls = SuccessiveUNet
        else:
            raise ValueError("Unknown model: {!r}".format(args.model))

        # All UNet variants take the same constructor arguments.
        model = model_cls(n_channels=sample.shape[0],
                          n_classes=num_classes,
                          n_blocks=args.n_blocks,
                          start_channels=args.start_channels,
                          pos_loc=args.pos_loc,
                          pos_dim=args.pos_dim,
                          bilinear=args.bilinear,
                          batch_size=args.batch_size)
    elif args.model == 'convlstm':
        model = EncoderForecaster(input_channels=sample.shape[1],
                                  hidden_dim=args.hidden_dim,
                                  num_classes=num_classes)
    else:
        raise ValueError("Unknown model: {!r}".format(args.model))

    # Both model families use the same criterion.
    criterion = NIMSCrossEntropyLoss(args=args,
                                     device=device,
                                     num_classes=num_classes,
                                     use_weights=args.cross_entropy_weight,
                                     experiment_name=experiment_name)

    if finetune:
        # map_location lets a checkpoint saved on a GPU load on a CPU host;
        # strict=False tolerates missing/unexpected keys.
        checkpoint = torch.load(model_path, map_location=device)
        model.load_state_dict(checkpoint, strict=False)

    return model, criterion
import argparse import os import torch from model.unet_model import UNet from train import train torch.cuda.set_device(0) parser = argparse.ArgumentParser() unet = UNet() use_CUDA = True if __name__ == "__main__": parser.add_argument("-v", "--visual", action="store_true") parser.add_argument("-l", "--lr", type=float, default=1e-5) parser.add_argument("-e", "--epochs", type=int, default=10) parser.add_argument("-b", "--batch", type=int, default=40) parser.add_argument("-r", "--retrain", type=bool, default=False) args = parser.parse_args() if args.visual: from visual import show_pred_mask from loader import train_loader trl = train_loader(1, shuffle=True) img, msk = next(trl) unet.load_state_dict(torch.load("unet.pkl")) show_pred_mask(unet, img, msk) else:
class Meta(nn.Module):
    """
    Meta Learner

    Wraps a UNet that predicts a "reflection" map; the restored image is the
    input divided by the first channel of that map. ``forward`` runs
    gradient-based inner updates on each task's support set, accumulates the
    query losses and takes one meta-optimizer step. ``finetunning`` runs the
    same inner updates on a single task without touching the meta weights.

    Fast weights are written into the network through ``param.data``, so no
    gradient flows through the inner update itself.
    """

    # Constants shared by every loss evaluation.
    _EPSILON = 1e-10          # added to the reflection map before clamping
    _ALPHA = 0.12             # weight of the (1 - SSIM) term
    _REFLECTION_MIN = 0.1     # clamp range of the predicted reflection map
    _REFLECTION_MAX = 5.0

    def __init__(self, args, writer, device):
        """
        :param args: parsed options (learning rates, step counts, channel
            counts, ``load`` checkpoint path, ``optimizer`` name, and
            optionally ``weight_decay`` for rmsprop)
        :param writer: summary writer, stored as ``self.writer``
        :param device: device the network and intermediate tensors live on
        :raises ValueError: if ``args.optimizer`` is neither "adam" nor
            "rmsprop"
        """
        super().__init__()
        self.device = device
        self.update_lr = args.update_lr
        self.meta_lr = args.meta_lr
        self.n_way = args.n_way
        self.k_spt = args.k_spt
        self.k_qry = args.k_qry
        self.task_num = args.task_num
        self.update_step = args.update_step
        self.update_step_test = args.update_step_test
        self.writer = writer
        self.in_channels = args.imgc
        self.out_channels = args.output_channel
        self.epsilon = 1e-10
        # Only the rmsprop branch reads this; 0 matches the optimizer default.
        self.weight_decay = getattr(args, 'weight_decay', 0)

        self.net = UNet(in_channels=args.imgc,
                        out_channels=args.output_channel)

        logging.info(f'Network:\n'
                     f'\t{self.net.in_channels} input channels\n'
                     f'\t{self.net.out_channels} output channels (classes)')

        if args.load:
            self.net.load_state_dict(
                torch.load(args.load, map_location=device))
            logging.info(f'Model loaded from {args.load}')

        # Snapshot of the meta weights, taken AFTER any checkpoint load so
        # that restoring it never discards the loaded weights.
        self.net_param = [
            param.clone().data.to(device) for param in self.net.parameters()
        ]

        # define loss
        self.mse_loss_fn = torch.nn.MSELoss(reduction='none')
        self.ssim_loss_fc = pytorch_msssim.SSIM(window_size=7)
        self.contour_loss = Contour_loss(K=5)

        # define optimizer for meta learning
        if args.optimizer == "adam":
            self.meta_optim = optim.Adam(self.net.parameters(),
                                         lr=self.meta_lr)
        elif args.optimizer == "rmsprop":
            self.meta_optim = optim.RMSprop(self.net.parameters(),
                                            lr=self.meta_lr,
                                            weight_decay=self.weight_decay)
        else:
            raise ValueError("Wrong Optimzer !")

    def _restoration_loss(self, inputs, targets, clamp_output=False):
        """Return the contour-weighted MSE + alpha * (1 - SSIM) loss.

        :param inputs: image batch indexed as ``[n, c, h, w]``
        :param targets: ground-truth batch; must be compatible with the
            restored images and the contour weight map
        :param clamp_output: clamp the restored images to ``[0, 1]`` before
            weighting
        """
        reflection = self.net(inputs) + self._EPSILON
        # Peter: may be we can try to remove this item
        reflection = torch.clamp(reflection, self._REFLECTION_MIN,
                                 self._REFLECTION_MAX)

        # Only the first three input channels are restored, each divided by
        # the first reflection channel; any further channels stay zero.
        restoration = torch.zeros(inputs.shape, device=self.device)
        restoration[:, :3] = torch.div(inputs[:, :3], reflection[:, :1])
        if clamp_output:
            restoration = torch.clamp(restoration, 0, 1)

        weight_map = self.contour_loss.forward(targets)
        restoration = torch.mul(restoration, weight_map)
        weighted_targets = torch.mul(targets, weight_map)

        mse_loss = self.mse_loss_fn(restoration, weighted_targets)
        ssim_loss = self.ssim_loss_fc(restoration, weighted_targets)
        return torch.mean(mse_loss) + self._ALPHA * (1 - torch.mean(ssim_loss))

    def _support_gradients(self, inputs, targets, clamp_output=False):
        """Return d(loss)/d(net parameters) for one support batch."""
        loss = self._restoration_loss(inputs, targets, clamp_output)
        return torch.autograd.grad(loss, self.net.parameters())

    def _apply_fast_weights(self, grads):
        """Take one SGD step of size ``update_lr`` on the network weights.

        ``param.data`` is rebound to a new tensor rather than modified in
        place, so the tensors stored in ``self.net_param`` are never changed.
        """
        for param, grad in zip(self.net.parameters(), grads):
            param.data = param.data - self.update_lr * grad

    def _restore_meta_weights(self):
        """Point every network parameter back at the saved meta weights."""
        for saved, param in zip(self.net_param, self.net.parameters()):
            param.data = saved

    def forward(self, sun_imgs_x_spt, sun_imgs_y_spt, sun_imgs_x_qry,
                sun_imgs_y_qry, step):
        """
        Run one meta-training step over a batch of tasks.

        :param sun_imgs_x_spt: support inputs, [b, setsz, c_, h, w]
        :param sun_imgs_y_spt: support targets, one image batch per task
        :param sun_imgs_x_qry: query inputs, [b, querysz, c_, h, w]
        :param sun_imgs_y_qry: query targets, one image batch per task
        :param step: unused
        :return: query loss summed over all update steps and averaged over
            tasks (the value that was back-propagated)
        """
        task_num = sun_imgs_x_spt.size(0)
        # At least one inner update is always performed.
        n_updates = max(1, self.update_step)
        loss_q = [0 for _ in range(n_updates + 1)
                  ]  # losses_q[i] is the loss on step i

        self.net = self.net.to(self.device)
        self.net.train()

        for i in range(task_num):
            x_spt, y_spt = sun_imgs_x_spt[i], sun_imgs_y_spt[i]
            x_qry, y_qry = sun_imgs_x_qry[i], sun_imgs_y_qry[i]

            # Order matters: the step-0 query loss is evaluated with the meta
            # weights, after the support gradients but before they are applied.
            grads = self._support_gradients(x_spt, y_spt)
            loss_q[0] += self._restoration_loss(x_qry, y_qry)
            self._apply_fast_weights(grads)
            loss_q[1] += self._restoration_loss(x_qry, y_qry)
            del grads

            for k in range(1, self.update_step):
                grads = self._support_gradients(x_spt, y_spt)
                self._apply_fast_weights(grads)
                loss_q[k + 1] += self._restoration_loss(x_qry, y_qry)
                del grads

            # Every task starts from the same meta weights.
            self._restore_meta_weights()

        # end of all tasks
        # sum over all losses on query set across all tasks
        loss_q = torch.sum(torch.stack(loss_q)) / task_num

        # optimize theta parameters
        self.meta_optim.zero_grad()
        loss_q.backward()
        self.meta_optim.step()

        # Fresh snapshot of the updated meta weights.
        self.net_param = [
            param.clone().data for param in self.net.parameters()
        ]

        torch.cuda.empty_cache()
        return loss_q

    def finetunning(self, sun_imgs_x_spt, sun_imgs_y_spt, sun_imgs_x_qry,
                    sun_imgs_y_qry, step):
        """
        Adapt to a single task and report the query loss along the way.

        The network is adapted on the support set for ``update_step_test``
        steps (at least one) and the saved meta weights are restored before
        returning. The network is left in training mode.

        :param sun_imgs_x_spt: support inputs, [setsz, c_, h, w]
        :param sun_imgs_y_spt: support targets
        :param sun_imgs_x_qry: query inputs, [querysz, c_, h, w]
        :param sun_imgs_y_qry: query targets
        :param step: unused
        :return: sum of the query losses before adaptation and after every
            update step
        """
        # Sized by the test-time step count (not the training one), so every
        # slot is filled and none is indexed out of range.
        n_updates = max(1, self.update_step_test)
        loss_q = [0 for _ in range(n_updates + 1)
                  ]  # losses_q[i] is the loss on step i

        self.net = self.net.to(self.device)
        self.net.train()

        grads = self._support_gradients(sun_imgs_x_spt, sun_imgs_y_spt)

        # this is the loss before the first update
        with torch.no_grad():
            loss_q[0] += self._restoration_loss(sun_imgs_x_qry,
                                                sun_imgs_y_qry)

        self._apply_fast_weights(grads)
        del grads

        with torch.no_grad():
            loss_q[1] += self._restoration_loss(sun_imgs_x_qry,
                                                sun_imgs_y_qry)

        for k in range(1, self.update_step_test):
            # Later steps clamp the restored image to [0, 1]. The clamp used
            # to be applied to the reflection map by mistake, which threw the
            # restored image away.
            grads = self._support_gradients(sun_imgs_x_spt,
                                            sun_imgs_y_spt,
                                            clamp_output=True)
            self._apply_fast_weights(grads)
            del grads

            with torch.no_grad():
                loss_q[k + 1] += self._restoration_loss(sun_imgs_x_qry,
                                                        sun_imgs_y_qry,
                                                        clamp_output=True)

        # Leave the meta weights exactly as they were before adaptation.
        self._restore_meta_weights()

        # sum over all losses on the query set
        loss_q = torch.sum(torch.stack(loss_q))

        torch.cuda.empty_cache()
        return loss_q