    def __init__(self, args):
        self.args = args

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        _, _, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

        self.model = None
        # Define network
        if self.args.backbone == 'unet':
            self.model = UNet(in_channels=4, n_classes=self.nclass)
            print("using UNet")
        if self.args.backbone == 'unetNested':
            self.model = UNetNested(in_channels=4, n_classes=self.nclass)
            print("using UNetNested")

        # Using cuda
        if args.cuda:
            self.model = self.model.cuda()

        if not os.path.isfile(args.checkpoint_file):
            raise RuntimeError("=> no checkpoint found at '{}'".format(
                args.checkpoint_file))
        checkpoint = torch.load(args.checkpoint_file)
        self.model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded checkpoint '{}'".format(args.checkpoint_file))
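# Illustrative usage sketch (not part of the original file): one way the tester built by
# the __init__ above might be run over its test loader. It assumes each batch yielded by
# make_data_loader is a dict with an 'image' tensor; that batch format is an assumption
# about the project's dataloader, not something shown above.
def run_inference_sketch(tester):
    import numpy as np
    import torch

    tester.model.eval()
    predictions = []
    with torch.no_grad():
        for sample in tester.test_loader:
            image = sample['image']  # assumed batch format
            if tester.args.cuda:
                image = image.cuda()
            output = tester.model(image)
            # Per-pixel class index: argmax over the channel dimension
            predictions.append(np.argmax(output.cpu().numpy(), axis=1))
    return predictions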
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        # Define Dataloader
        kwargs = {'num_workers': args.workers, 'pin_memory': True}
        self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(
            args, **kwargs)

        model = None
        # Define network
        if self.args.backbone == 'unet':
            model = UNet(in_channels=4, n_classes=self.nclass, sync_bn=args.sync_bn)
            print("using UNet")
        if self.args.backbone == 'unetNested':
            model = UNetNested(in_channels=4, n_classes=self.nclass, sync_bn=args.sync_bn)
            print("using UNetNested")

        # train_params = [{'params': model.get_params(), 'lr': args.lr}]
        train_params = [{'params': model.get_params()}]

        # Define Optimizer
        # optimizer = torch.optim.SGD(train_params, momentum=args.momentum,
        #                             weight_decay=args.weight_decay, nesterov=args.nesterov)
        optimizer = torch.optim.Adam(train_params,
                                     self.args.learn_rate,
                                     weight_decay=args.weight_decay,
                                     amsgrad=True)

        # Define Criterion
        # Whether to use class-balanced weights
        if args.use_balanced_weights:
            classes_weights_path = os.path.join(
                Path.db_root_dir(args.dataset),
                args.dataset + '_classes_weights.npy')
            if os.path.isfile(classes_weights_path):
                weight = np.load(classes_weights_path)
            else:
                weight = calculate_weigths_labels(args.dataset, self.train_loader,
                                                  self.nclass)
            weight = torch.from_numpy(weight.astype(np.float32))
        else:
            weight = None
        self.criterion = SegmentationLosses(
            weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
        self.model, self.optimizer = model, optimizer

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Define lr scheduler
        # self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader))

        # Using cuda
        if args.cuda:
            self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
            patch_replication_callback(self.model)
            self.model = self.model.cuda()

        # Resuming checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            if args.cuda:
                self.model.module.load_state_dict(checkpoint['state_dict'])
            else:
                self.model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0
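# Illustrative sketch (assumption, not the project's calculate_weigths_labels): one common
# way to derive class-balanced weights for a segmentation loss is to count label frequencies
# over the training set and dampen them with a log term. The function name and the
# 1 / ln(1.02 + freq) formula are assumptions for illustration only.
import numpy as np


def balanced_class_weights_sketch(label_arrays, num_classes):
    counts = np.zeros(num_classes, dtype=np.float64)
    for labels in label_arrays:  # each item: an integer label map of shape (H, W)
        hist, _ = np.histogram(labels, bins=np.arange(num_classes + 1))
        counts += hist
    frequencies = counts / counts.sum()
    # Rare classes get larger weights; the log keeps the dynamic range moderate.
    weights = 1.0 / np.log(1.02 + frequencies)
    return (weights / weights.sum()).astype(np.float32)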
class Trainer(object):
    def __init__(self, args):
        self.args = args

        # Define Saver
        self.saver = Saver(args)
        self.saver.save_experiment_config()

        # Define Tensorboard Summary
        self.summary = TensorboardSummary(self.saver.experiment_dir)
        self.writer = self.summary.create_summary()

        self.nclass = 16
        # Define network
        self.unet_model = UNet(in_channels=4, n_classes=self.nclass)
        self.unetNested_model = UNetNested(in_channels=4, n_classes=self.nclass)
        self.combine_net_model = CombineNet(in_channels=192, n_classes=self.nclass)

        train_params = [{'params': self.combine_net_model.get_params()}]

        # Define Optimizer
        self.optimizer = torch.optim.Adam(train_params,
                                          self.args.learn_rate,
                                          weight_decay=args.weight_decay,
                                          amsgrad=True)

        self.criterion = SegmentationLosses(
            weight=None, cuda=args.cuda).build_loss(mode=args.loss_type)

        # Define Evaluator
        self.evaluator = Evaluator(self.nclass)

        # Using cuda
        if args.cuda:
            self.unet_model = self.unet_model.cuda()
            self.unetNested_model = self.unetNested_model.cuda()
            self.combine_net_model = self.combine_net_model.cuda()

        # Load UNet checkpoint
        if not os.path.isfile(args.unet_checkpoint_file):
            raise RuntimeError("=> no UNet checkpoint found at '{}'".format(
                args.unet_checkpoint_file))
        checkpoint = torch.load(args.unet_checkpoint_file)
        self.unet_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded UNet checkpoint '{}'".format(args.unet_checkpoint_file))

        # Load UNetNested checkpoint
        if not os.path.isfile(args.unetNested_checkpoint_file):
            raise RuntimeError("=> no UNetNested checkpoint found at '{}'".format(
                args.unetNested_checkpoint_file))
        checkpoint = torch.load(args.unetNested_checkpoint_file)
        self.unetNested_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded UNetNested checkpoint '{}'".format(
            args.unetNested_checkpoint_file))

        # Resuming CombineNet checkpoint
        self.best_pred = 0.0
        if args.resume is not None:
            if not os.path.isfile(args.resume):
                raise RuntimeError("=> no CombineNet checkpoint found at '{}'".format(
                    args.resume))
            checkpoint = torch.load(args.resume)
            args.start_epoch = checkpoint['epoch']
            # The CombineNet is not wrapped in DataParallel here, so the state dict is
            # loaded directly onto the model in both the CUDA and CPU cases.
            self.combine_net_model.load_state_dict(checkpoint['state_dict'])
            if not args.ft:
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            self.best_pred = checkpoint['best_pred']
            print("=> loaded CombineNet checkpoint '{}' (epoch {})".format(
                args.resume, checkpoint['epoch']))

        # Clear start epoch if fine-tuning
        if args.ft:
            args.start_epoch = 0

    def training(self, epoch):
        print('[Epoch: %d, previous best = %.4f]' % (epoch, self.best_pred))
        train_loss = 0.0
        self.combine_net_model.train()
        self.evaluator.reset()
        num_img_tr = len(train_files)
        tbar = tqdm(train_files, desc='\r')
        for i, filename in enumerate(tbar):
            image = Image.open(os.path.join(train_dir, filename))
            label = Image.open(
                os.path.join(train_label_dir,
                             os.path.basename(filename)[:-4] + '_labelTrainIds.png'))
            label = np.array(label).astype(np.float32)
            label = label.reshape((1, 400, 400))
            label = torch.from_numpy(label).float()
            label = label.cuda()

            # UNet multi-scale (test-time augmentation) prediction
            unet_pred = self.unet_multi_scale_predict(image)
            # UNetNested multi-scale (test-time augmentation) prediction
            unetnested_pred = self.unetnested_multi_scale_predict(image)

            net_input = torch.cat([unet_pred, unetnested_pred], 1)

            self.optimizer.zero_grad()
            output = self.combine_net_model(net_input)
            loss = self.criterion(output, label)
            loss.backward()
            self.optimizer.step()
            train_loss += loss.item()
            tbar.set_description('Train loss: %.5f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_img_tr * epoch)

            pred = output.data.cpu().numpy()
            label = label.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(label, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('train/mIoU', mIoU, epoch)
        self.writer.add_scalar('train/Acc', Acc, epoch)
        self.writer.add_scalar('train/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('train/fwIoU', FWIoU, epoch)
        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print('train validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % train_loss)
        print('---------------------------------')

    def validation(self, epoch):
        test_loss = 0.0
        self.combine_net_model.eval()
        self.evaluator.reset()
        tbar = tqdm(val_files, desc='\r')
        num_img_val = len(val_files)
        for i, filename in enumerate(tbar):
            image = Image.open(os.path.join(val_dir, filename))
            label = Image.open(
                os.path.join(val_label_dir,
                             os.path.basename(filename)[:-4] + '_labelTrainIds.png'))
            label = np.array(label).astype(np.float32)
            label = label.reshape((1, 400, 400))
            label = torch.from_numpy(label).float()
            label = label.cuda()

            # UNet multi-scale (test-time augmentation) prediction
            unet_pred = self.unet_multi_scale_predict(image)
            # UNetNested multi-scale (test-time augmentation) prediction
            unetnested_pred = self.unetnested_multi_scale_predict(image)

            net_input = torch.cat([unet_pred, unetnested_pred], 1)

            with torch.no_grad():
                output = self.combine_net_model(net_input)
            loss = self.criterion(output, label)
            test_loss += loss.item()
            tbar.set_description('Test loss: %.5f' % (test_loss / (i + 1)))
            self.writer.add_scalar('val/total_loss_iter', loss.item(),
                                   i + num_img_val * epoch)

            pred = output.data.cpu().numpy()
            label = label.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            # Add batch sample into evaluator
            self.evaluator.add_batch(label, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('test validation:')
        print("Acc:{}, Acc_class:{}, mIoU:{}, fwIoU: {}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print('Loss: %.3f' % test_loss)
        print('====================================')

        new_pred = mIoU
        if new_pred > self.best_pred:
            is_best = True
            self.best_pred = new_pred
            self.saver.save_checkpoint(
                {
                    'epoch': epoch + 1,
                    'state_dict': self.combine_net_model.state_dict(),
                    'optimizer': self.optimizer.state_dict(),
                    'best_pred': self.best_pred,
                }, is_best)

    def unet_multi_scale_predict(self, image_ori: Image):
        self.unet_model.eval()

        # Predict the original image
        sample_ori = image_ori.copy()
        output_ori = self.unet_predict(sample_ori)

        # Predict three rotated copies (90/180/270 degrees)
        angle_list = [90, 180, 270]
        for angle in angle_list:
            img_rotate = image_ori.rotate(angle, Image.BILINEAR)
            output = self.unet_predict(img_rotate)
            pred = output.data.cpu().numpy()[0]
            pred = pred.transpose((1, 2, 0))
            m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
            pred = cv2.warpAffine(pred, m_rotate, (400, 400))
            pred = pred.transpose((2, 0, 1))
            output = torch.from_numpy(np.array([pred])).float()
            output_ori = torch.cat([output_ori, output.cuda()], 1)

        # Predict the vertically flipped image
        img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM)
        output = self.unet_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 0)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(np.array([pred])).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        # Predict the horizontally flipped image
        img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT)
        output = self.unet_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 1)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(np.array([pred])).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        return output_ori

    def unet_predict(self, img: Image) -> torch.Tensor:
        img = self.transform_test(img)
        if self.args.cuda:
            img = img.cuda()
        with torch.no_grad():
            output = self.unet_model(img)
        return output

    def unetnested_predict(self, img: Image) -> torch.Tensor:
        img = self.transform_test(img)
        if self.args.cuda:
            img = img.cuda()
        with torch.no_grad():
            output = self.unetNested_model(img)
        return output

    def unetnested_multi_scale_predict(self, image_ori: Image):
        self.unetNested_model.eval()

        # Predict the original image
        sample_ori = image_ori.copy()
        output_ori = self.unetnested_predict(sample_ori)

        # Predict three rotated copies (90/180/270 degrees)
        angle_list = [90, 180, 270]
        for angle in angle_list:
            img_rotate = image_ori.rotate(angle, Image.BILINEAR)
            output = self.unetnested_predict(img_rotate)
            pred = output.data.cpu().numpy()[0]
            pred = pred.transpose((1, 2, 0))
            m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
            pred = cv2.warpAffine(pred, m_rotate, (400, 400))
            pred = pred.transpose((2, 0, 1))
            output = torch.from_numpy(np.array([pred])).float()
            output_ori = torch.cat([output_ori, output.cuda()], 1)

        # Predict the vertically flipped image
        img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM)
        output = self.unetnested_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 0)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(np.array([pred])).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        # Predict the horizontally flipped image
        img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT)
        output = self.unetnested_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 1)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(np.array([pred])).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        return output_ori

    @staticmethod
    def transform_test(img):
        # Normalize (per-channel mean/std for the 4-band imagery)
        mean = (0.544650, 0.352033, 0.384602, 0.352311)
        std = (0.249456, 0.241652, 0.228824, 0.227583)
        img = np.array(img).astype(np.float32)
        img /= 255.0
        img -= mean
        img /= std
        # ToTensor: HWC -> CHW, then add a batch dimension
        img = img.transpose((2, 0, 1))
        img = np.array([img])
        img = torch.from_numpy(img).float()
        return img
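# Illustrative sketch (not part of the original files): the test-time-augmentation pattern
# used by unet_multi_scale_predict / unetnested_multi_scale_predict above, written as a
# standalone function. It assumes a 400x400 input and a predict_fn that maps a PIL image to
# a (1, C, 400, 400) CUDA tensor. The six per-view predictions (original, three rotations,
# two flips) are concatenated along the channel axis, which is how the 192-channel
# CombineNet input is assembled (2 backbones x 6 views x 16 classes).
import cv2
import numpy as np
import torch
from PIL import Image


def tta_predict_sketch(predict_fn, image_ori: Image.Image) -> torch.Tensor:
    outputs = [predict_fn(image_ori.copy())]

    # Rotations: predict on the rotated image, then rotate the prediction back.
    for angle in (90, 180, 270):
        pred = predict_fn(image_ori.rotate(angle, Image.BILINEAR))
        pred = pred.data.cpu().numpy()[0].transpose((1, 2, 0))     # CHW -> HWC
        m = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)  # inverse rotation
        pred = cv2.warpAffine(pred, m, (400, 400)).transpose((2, 0, 1))
        outputs.append(torch.from_numpy(pred[np.newaxis]).float().cuda())

    # Flips: predict on the flipped image, then flip the prediction back.
    for pil_flip, cv_flip in ((Image.FLIP_TOP_BOTTOM, 0), (Image.FLIP_LEFT_RIGHT, 1)):
        pred = predict_fn(image_ori.transpose(pil_flip))
        pred = pred.data.cpu().numpy()[0].transpose((1, 2, 0))
        pred = cv2.flip(pred, cv_flip).transpose((2, 0, 1))
        outputs.append(torch.from_numpy(pred[np.newaxis]).float().cuda())

    return torch.cat(outputs, dim=1)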
class Visualization:
    def __init__(self, args):
        self.args = args
        self.nclass = 16
        # Define network
        self.unet_model = UNet(in_channels=4, n_classes=self.nclass)
        self.unetNested_model = UNetNested(in_channels=4, n_classes=self.nclass)
        self.combine_net_model = CombineNet(in_channels=192, n_classes=self.nclass)

        # Using cuda
        if args.cuda:
            self.unet_model = self.unet_model.cuda()
            self.unetNested_model = self.unetNested_model.cuda()
            self.combine_net_model = self.combine_net_model.cuda()

        # Load UNet model
        if not os.path.isfile(args.unet_checkpoint_file):
            raise RuntimeError("=> no UNet checkpoint found at '{}'".format(
                args.unet_checkpoint_file))
        checkpoint = torch.load(args.unet_checkpoint_file)
        self.unet_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded UNet checkpoint '{}'".format(args.unet_checkpoint_file))

        # Load UNetNested model
        if not os.path.isfile(args.unetNested_checkpoint_file):
            raise RuntimeError("=> no UNetNested checkpoint found at '{}'".format(
                args.unetNested_checkpoint_file))
        checkpoint = torch.load(args.unetNested_checkpoint_file)
        self.unetNested_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded UNetNested checkpoint '{}'".format(
            args.unetNested_checkpoint_file))

        # Load CombineNet model
        if not os.path.isfile(args.combine_net_checkpoint_file):
            raise RuntimeError("=> no CombineNet checkpoint found at '{}'".format(
                args.combine_net_checkpoint_file))
        checkpoint = torch.load(args.combine_net_checkpoint_file)
        self.combine_net_model.load_state_dict(checkpoint['state_dict'])
        print("=> loaded CombineNet checkpoint '{}'".format(
            args.combine_net_checkpoint_file))

    def visualization(self):
        self.combine_net_model.eval()
        tbar = tqdm(test_files, desc='\r')
        for i, filename in enumerate(tbar):
            image = Image.open(os.path.join(test_dir, filename))

            # UNet multi-scale (test-time augmentation) prediction
            unet_pred = self.unet_multi_scale_predict(image)
            # UNetNested multi-scale (test-time augmentation) prediction
            unetnested_pred = self.unetnested_multi_scale_predict(image)

            net_input = torch.cat([unet_pred, unetnested_pred], 1)
            with torch.no_grad():
                output = self.combine_net_model(net_input)

            pred = output.data.cpu().numpy()[0]
            pred = np.argmax(pred, axis=0)
            rgb = decode_segmap(pred, self.args.dataset)

            # PIL expects uint8 data for 'L' (grayscale) images, so cast the label map
            pred_img = Image.fromarray(pred.astype(np.uint8), mode='L')
            rgb_img = Image.fromarray(rgb, mode='RGB')
            pred_img.save(
                os.path.join(self.args.vis_logdir, 'raw_train_id', filename))
            rgb_img.save(
                os.path.join(self.args.vis_logdir, 'vis_color', filename))

    def unet_multi_scale_predict(self, image_ori: Image):
        self.unet_model.eval()

        # Predict the original image
        sample_ori = image_ori.copy()
        output_ori = self.unet_predict(sample_ori)

        # Predict three rotated copies (90/180/270 degrees)
        angle_list = [90, 180, 270]
        for angle in angle_list:
            img_rotate = image_ori.rotate(angle, Image.BILINEAR)
            output = self.unet_predict(img_rotate)
            pred = output.data.cpu().numpy()[0]
            pred = pred.transpose((1, 2, 0))
            m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
            pred = cv2.warpAffine(pred, m_rotate, (400, 400))
            pred = pred.transpose((2, 0, 1))
            output = torch.from_numpy(np.array([pred])).float()
            output_ori = torch.cat([output_ori, output.cuda()], 1)

        # Predict the vertically flipped image
        img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM)
        output = self.unet_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 0)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(np.array([pred])).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        # Predict the horizontally flipped image
        img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT)
        output = self.unet_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 1)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(np.array([pred])).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        return output_ori

    def unet_predict(self, img: Image) -> torch.Tensor:
        img = self.transform_test(img)
        if self.args.cuda:
            img = img.cuda()
        with torch.no_grad():
            output = self.unet_model(img)
        return output

    def unetnested_predict(self, img: Image) -> torch.Tensor:
        img = self.transform_test(img)
        if self.args.cuda:
            img = img.cuda()
        with torch.no_grad():
            output = self.unetNested_model(img)
        return output

    def unetnested_multi_scale_predict(self, image_ori: Image):
        self.unetNested_model.eval()

        # Predict the original image
        sample_ori = image_ori.copy()
        output_ori = self.unetnested_predict(sample_ori)

        # Predict three rotated copies (90/180/270 degrees)
        angle_list = [90, 180, 270]
        for angle in angle_list:
            img_rotate = image_ori.rotate(angle, Image.BILINEAR)
            output = self.unetnested_predict(img_rotate)
            pred = output.data.cpu().numpy()[0]
            pred = pred.transpose((1, 2, 0))
            m_rotate = cv2.getRotationMatrix2D((200, 200), 360.0 - angle, 1)
            pred = cv2.warpAffine(pred, m_rotate, (400, 400))
            pred = pred.transpose((2, 0, 1))
            output = torch.from_numpy(np.array([pred])).float()
            output_ori = torch.cat([output_ori, output.cuda()], 1)

        # Predict the vertically flipped image
        img_flip = image_ori.transpose(Image.FLIP_TOP_BOTTOM)
        output = self.unetnested_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 0)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(np.array([pred])).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        # Predict the horizontally flipped image
        img_flip = image_ori.transpose(Image.FLIP_LEFT_RIGHT)
        output = self.unetnested_predict(img_flip)
        pred = output.data.cpu().numpy()[0]
        pred = pred.transpose((1, 2, 0))
        pred = cv2.flip(pred, 1)
        pred = pred.transpose((2, 0, 1))
        output = torch.from_numpy(np.array([pred])).float()
        output_ori = torch.cat([output_ori, output.cuda()], 1)

        return output_ori

    @staticmethod
    def transform_test(img):
        # Normalize (per-channel mean/std for the 4-band imagery)
        mean = (0.544650, 0.352033, 0.384602, 0.352311)
        std = (0.249456, 0.241652, 0.228824, 0.227583)
        img = np.array(img).astype(np.float32)
        img /= 255.0
        img -= mean
        img /= std
        # ToTensor: HWC -> CHW, then add a batch dimension
        img = img.transpose((2, 0, 1))
        img = np.array([img])
        img = torch.from_numpy(img).float()
        return img
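# Illustrative driver sketch (assumption, not shown in the original files): one way the
# Visualization class above might be invoked. The argument names mirror the attributes the
# class reads (cuda, dataset, vis_logdir, and the three *_checkpoint_file paths); the
# argparse wiring itself is hypothetical.
import argparse
import torch


def build_vis_args_sketch():
    parser = argparse.ArgumentParser(description='CombineNet visualization (sketch)')
    parser.add_argument('--dataset', default='rssrai2019')
    parser.add_argument('--vis_logdir', required=True)
    parser.add_argument('--unet_checkpoint_file', required=True)
    parser.add_argument('--unetNested_checkpoint_file', required=True)
    parser.add_argument('--combine_net_checkpoint_file', required=True)
    parser.add_argument('--no_cuda', action='store_true')
    args = parser.parse_args()
    args.cuda = not args.no_cuda and torch.cuda.is_available()
    return args


# Example (sketch):
#   args = build_vis_args_sketch()
#   Visualization(args).visualization()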
"""
@function: Load separately saved model parameters (a state_dict checkpoint) and re-save
           them together with the model structure as a single file.
@author: HuiYi or 会意
@file: vis.py.py
@time: 2019/7/30 7:00 PM
"""
import torch

from models.backbone.UNet import UNet

model_path_list = [
    '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_0/checkpoint.pth.tar',
    '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_1/checkpoint.pth.tar',
    '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_2/checkpoint.pth.tar'
]

if __name__ == '__main__':
    model = UNet(in_channels=4, n_classes=16, sync_bn=False)
    model = model.cuda()

    param = '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_0/checkpoint.pth.tar'
    checkpoint = torch.load(param)
    model.load_state_dict(checkpoint['state_dict'])

    # Save the full model object (structure + parameters) in one file
    torch.save(
        model,
        '/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_0/model_and_param.pth.tar'
    )
    print('save finish')

    # load
    # model = torch.load('/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/run/rssrai2019/unet/experiment_1/model_and_param.pth.tar')
    # params = model.state_dict()
    # print('load')
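# Illustrative sketch (assumption): loading the combined model-plus-parameters file saved
# above for inference. torch.load of a whole model object only works if the UNet class
# definition is importable under the same module path it had when the file was saved.
# The path reuses the experiment_0 file written by the script above.
def load_full_model_sketch():
    import torch

    path = ('/home/lab/ygy/rssrai2019/rssrai2019_semantic_segmentation/'
            'run/rssrai2019/unet/experiment_0/model_and_param.pth.tar')
    model = torch.load(path)  # rebuilds the UNet object together with its trained weights
    model.eval()              # switch to inference mode
    return model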