class RAN():
    def __init__(self, weight, gpu_ids):
        self.model = DeepLab(num_classes=2, backbone='mobilenet', output_stride=16)
        torch.cuda.set_device(gpu_ids)
        self.model = self.model.cuda()
        assert weight is not None
        if not os.path.isfile(weight):
            raise RuntimeError("=> no checkpoint found at '{}'".format(weight))
        checkpoint = torch.load(weight)
        self.model.load_state_dict(checkpoint['state_dict'])
        self.model.eval()
        self.mean = (0.485, 0.456, 0.406)
        self.std = (0.229, 0.224, 0.225)

    def inference(self, img):
        # normalize
        img = cv2.resize(img, (480, 480))
        img = img.astype(np.float32)
        img /= 255.0
        img -= self.mean
        img /= self.std
        img = img.transpose((2, 0, 1))
        img = img[np.newaxis, :, :, :]
        # to tensor
        img = torch.from_numpy(img).float().cuda()
        with torch.no_grad():
            output = self.model(img)
        return output
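# Usage sketch for RAN (assumptions: 'checkpoint.pth.tar' is a trained 2-class
# checkpoint and 'frame.png' is an image on disk; both names are illustrative).
ran = RAN(weight='checkpoint.pth.tar', gpu_ids=0)
frame = cv2.imread('frame.png')
logits = ran.inference(frame)                         # (1, 2, 480, 480) class scores
pred = logits.argmax(dim=1).squeeze().cpu().numpy()   # per-pixel class ids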
def load_model(model_path, num_classes=14, backbone='resnet', output_stride=16):
    print(f"Loading model from {model_path}")
    model = DeepLab(num_classes=num_classes,
                    backbone=backbone,
                    output_stride=output_stride)
    pretrained_dict = torch.load(model_path, map_location=lambda storage, loc: storage)
    model_dict = model.state_dict()
    # 1. filter out unnecessary keys and mismatching sizes
    pretrained_dict = {k: v for k, v in pretrained_dict.items()
                       if (k in model_dict) and (model_dict[k].shape == pretrained_dict[k].shape)}
    # 2. overwrite entries in the existing state dict
    model_dict.update(pretrained_dict)
    # 3. load the new state dict
    model.load_state_dict(model_dict)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    if torch.cuda.device_count() > 1:
        print("Load model in ", torch.cuda.device_count(), " GPUs!")
        model = nn.DataParallel(model)
    model.to(device)
    model.eval()
    return model
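# Usage sketch for load_model (assumption: 'deeplab.pth' stores a bare
# state_dict, which is what load_model expects; the path is illustrative).
model = load_model('deeplab.pth', num_classes=14)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
x = torch.randn(1, 3, 513, 513).to(device)
with torch.no_grad():
    logits = model(x)  # (1, 14, 513, 513) class scores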
def __init__(self, args):
    super(MPNet, self).__init__()
    self.enable_mp_layer = (args.mpnet_mrf_mode in ['TRWP', 'ISGMR', 'MeanField', 'SGM'])
    self.args = args
    BatchNorm = SynchronizedBatchNorm2d if args.sync_bn else nn.BatchNorm2d
    self.enable_score_scale = args.enable_score_scale
    self.deeplab = DeepLab(num_classes=args.n_classes,
                           backbone=args.deeplab_backbone,
                           output_stride=args.deeplab_outstride,
                           sync_bn=args.deeplab_sync_bn,
                           freeze_bn=args.deeplab_freeze_bn,
                           enable_interpolation=args.deeplab_enable_interpolation,
                           pretrained_path=args.resnet_pretrained_path,
                           norm_layer=BatchNorm,
                           enable_aspp=not self.args.disable_aspp)
    if self.enable_mp_layer:
        if self.args.mpnet_mrf_mode == 'TRWP':
            self.mp_layer = MPModule_TRWP(self.args,
                                          enable_create_label_context=True,
                                          enable_saving_label=False)
        elif self.args.mpnet_mrf_mode in {'ISGMR', 'SGM'}:
            self.mp_layer = MPModule_ISGMR(self.args,
                                           enable_create_label_context=True,
                                           enable_saving_label=False)
        elif self.args.mpnet_mrf_mode == 'MeanField':
            self.mp_layer = MeanField(self.args, enable_create_label_context=True)
        else:
            assert False
def main():
    args = arguments()
    seed(args)
    model = DeepLab(backbone='mobilenet', output_stride=16, num_classes=21, sync_bn=False)
    model.eval()

    from aimet_torch import batch_norm_fold
    from aimet_torch import utils
    args.input_shape = (1, 3, 513, 513)
    batch_norm_fold.fold_all_batch_norms(model, args.input_shape)
    utils.replace_modules_of_type1_with_type2(model, torch.nn.ReLU6, torch.nn.ReLU)

    if args.checkpoint_path:
        model.load_state_dict(torch.load(args.checkpoint_path))
    else:
        raise ValueError('checkpoint path {} must be specified'.format(args.checkpoint_path))

    data_loader_kwargs = {'worker_init_fn': work_init, 'num_workers': 0}
    train_loader, val_loader, test_loader, num_class = make_data_loader(args, **data_loader_kwargs)
    eval_func_quant = model_eval(args, val_loader)
    eval_func = model_eval(args, val_loader)

    from aimet_common.defs import QuantScheme
    from aimet_torch.quantsim import QuantizationSimModel
    if hasattr(args, 'quant_scheme'):
        if args.quant_scheme == 'range_learning_tf':
            quant_scheme = QuantScheme.training_range_learning_with_tf_init
        elif args.quant_scheme == 'range_learning_tfe':
            quant_scheme = QuantScheme.training_range_learning_with_tf_enhanced_init
        elif args.quant_scheme == 'tf':
            quant_scheme = QuantScheme.post_training_tf
        elif args.quant_scheme == 'tf_enhanced':
            quant_scheme = QuantScheme.post_training_tf_enhanced
        else:
            raise ValueError("Got unrecognized quant_scheme: " + args.quant_scheme)
    kwargs = {
        'quant_scheme': quant_scheme,
        'default_param_bw': args.default_param_bw,
        'default_output_bw': args.default_output_bw,
        'config_file': args.config_file
    }
    print(kwargs)
    sim = QuantizationSimModel(model.cpu(), input_shapes=args.input_shape, **kwargs)
    sim.compute_encodings(eval_func_quant, (1024, True))
    post_quant_top1 = eval_func(sim.model.cuda(), (99999999, True))
    print("Post Quant mIoU :", post_quant_top1)
def main(checkpoint_filename, input_image, output_image):
    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    # Define network
    model = DeepLab(num_classes=3,
                    backbone='resnet',
                    output_stride=16,
                    sync_bn=False,
                    freeze_bn=False)
    checkpoint = torch.load(checkpoint_filename, map_location=device)
    state_dict = checkpoint['state_dict']
    # Because the model was saved with DataParallel, the stored checkpoint
    # contains a "module." prefix that we want to strip.
    state_dict = {
        key[7:] if key.startswith('module.') else key: val
        for key, val in state_dict.items()
    }
    model.load_state_dict(state_dict)
    model.eval()
    image = Image.open(input_image).convert('RGB')
    mask = predict(model, image)
    mask.save(output_image)
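# The predict() helper called above is not shown in this snippet; a minimal
# sketch of what it might look like (assumptions: ImageNet normalization,
# argmax over class logits, mask returned as a PIL image).
def predict(model, image):
    import torchvision.transforms as T
    tfm = T.Compose([T.ToTensor(),
                     T.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))])
    x = tfm(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        logits = model(x)
    pred = logits.argmax(dim=1).squeeze(0).byte().cpu().numpy()
    return Image.fromarray(pred)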
def create_segmentation_models(encoder, arch, num_classes=4, encoder_weights=None, activation=None):
    '''
    segmentation_models_pytorch
    https://github.com/qubvel/segmentation_models.pytorch
    has the following architectures:
    - Unet
    - Linknet
    - FPN
    - PSPNet
    encoders: a lot! See the GitHub page above.

    Deeplabv3+
    https://github.com/jfzhang95/pytorch-deeplab-xception
    has these encoders:
    - resnet (resnet101)
    - mobilenet
    - xception
    - drn
    '''
    if arch == "Unet":
        return smp.Unet(encoder, encoder_weights=encoder_weights,
                        classes=num_classes, activation=activation)
    elif arch == "Linknet":
        return smp.Linknet(encoder, encoder_weights=encoder_weights,
                           classes=num_classes, activation=activation)
    elif arch == "FPN":
        return smp.FPN(encoder, encoder_weights=encoder_weights,
                       classes=num_classes, activation=activation)
    elif arch == "PSPNet":
        return smp.PSPNet(encoder, encoder_weights=encoder_weights,
                          classes=num_classes, activation=activation)
    elif arch == "deeplabv3plus":
        if "deeplabv3plus_PATH" in os.environ:
            sys.path.append(os.environ["deeplabv3plus_PATH"])
            from modeling.deeplab import DeepLab
            return DeepLab(encoder, num_classes=num_classes)
        else:
            raise ValueError('Set deeplabv3plus path by environment variable.')
    else:
        raise ValueError('arch {} is not found, set the correct arch'.format(arch))
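# Usage sketch (assumption: segmentation_models_pytorch is installed and
# imported as smp, as the docstring above describes).
model = create_segmentation_models('resnet34', 'Unet', num_classes=4,
                                   encoder_weights='imagenet')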
def inference_A_sample_image(img_path, model_path, num_classes, backbone,
                             output_stride, sync_bn, freeze_bn):
    # read image
    image = cv2.imread(img_path)
    image = np.array(image).astype(np.float32)

    # Normalize the image (mean and std are from pascal.py)
    mean = (0.485, 0.456, 0.406)
    std = (0.229, 0.224, 0.225)
    image /= 255
    image -= mean
    image /= std

    # swap color axis because
    # numpy image: H x W x C
    # torch image: C x H x W
    image = image.transpose((2, 0, 1))
    # to 4D, N=1
    image = image.reshape(1, image.shape[0], image.shape[1], image.shape[2])
    image = torch.from_numpy(image)

    model = DeepLab(num_classes=num_classes,
                    backbone=backbone,
                    output_stride=output_stride,
                    sync_bn=sync_bn,
                    freeze_bn=freeze_bn,
                    pretrained=True)

    device = torch.device('cpu') if not torch.cuda.is_available() else None
    checkpoint = torch.load(model_path, map_location=device)
    model.load_state_dict(checkpoint['state_dict'])

    # Set dropout and batch normalization layers to evaluation mode before
    # running inference; failing to do this yields inconsistent results.
    model.eval()
    with torch.no_grad():
        output = model(image)

    out_np = output.cpu().data.numpy()
    pred = np.argmax(out_np, axis=1)
    pred = pred.reshape(pred.shape[1], pred.shape[2])

    # save result
    cv2.imwrite('output.jpg', pred)
def main(args):
    vali_dataset = MRIBrainSegmentation(root_folder=args.root_folder,
                                        image_label=args.data_label,
                                        is_train=False)
    vali_loader = torch.utils.data.DataLoader(vali_dataset,
                                              batch_size=16,
                                              shuffle=False,
                                              num_workers=4,
                                              drop_last=False)
    # Init and load model
    model = DeepLab(num_classes=1,
                    backbone='resnet',
                    output_stride=8,
                    sync_bn=None,
                    freeze_bn=False)
    checkpoint = torch.load(args.checkpoint)
    state_dict = checkpoint['state_dict']
    model.load_state_dict(state_dict)
    device = torch.device(args.device if torch.cuda.is_available() else 'cpu')
    model = model.to(device)
    model.eval()
    with torch.no_grad():
        for i, sample in enumerate(vali_loader):
            print(i)
            data = sample['image']
            target = sample['mask']
            data, target = data.to(device), target.to(device)
            output = model(data)
            target = target.data.cpu().numpy()
            data = data.data.cpu().numpy()
            output = output.data.cpu().numpy()
            pred = np.zeros_like(output)
            pred[output > 0.5] = 1
            pred = pred[:, 0]
            for j in range(len(target)):
                output_image = pred[j] * 255
                target_image = target[j] * 255
                cv2.imwrite("{}/{:06d}_{:06d}_predict.png".format(args.output_folder, i, j),
                            output_image.astype(np.uint8))
                cv2.imwrite("{}/{:06d}_{:06d}_target.png".format(args.output_folder, i, j),
                            target_image.astype(np.uint8))
                # undo ImageNet normalization before saving the input image
                img = data[j].transpose([1, 2, 0])
                img *= (0.229, 0.224, 0.225)
                img += (0.485, 0.456, 0.406)
                img *= 255.0
                cv2.imwrite("{}/{:06d}_{:06d}_origin.png".format(args.output_folder, i, j),
                            img.astype(np.uint8))
def test(args):
    kwargs = {'num_workers': 1, 'pin_memory': True}
    _, val_loader, _, nclass = make_data_loader(args, **kwargs)
    checkpoint = torch.load(args.ckpt)
    if checkpoint is None:
        raise ValueError
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
    model = DeepLab(num_classes=nclass,
                    backbone='resnet',
                    output_stride=16,
                    sync_bn=True,
                    freeze_bn=False)
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()
    model.to(device)
    torch.set_grad_enabled(False)

    tbar = tqdm(val_loader)
    num_img_tr = len(val_loader)
    for i, sample in enumerate(tbar):
        # bbox coord
        x1, x2, y1, y2 = [int(item) for item in sample['img_meta']['bbox_coord']]
        w, h = x2 - x1, y2 - y1
        img = sample['img_meta']['image'].squeeze().cpu().numpy()
        img_w, img_h = img.shape[:2]
        inputs = sample['image'].cuda()
        output = model(inputs).squeeze().cpu().numpy()
        pred = np.argmax(output, axis=0)
        result = decode_segmap(pred, dataset=args.dataset, plot=False)
        result = imresize(result, (w, h))
        result_padding = np.zeros(img.shape, dtype=np.uint8)
        result_padding[y1:y2, x1:x2] = result
        result = img // 2 + result_padding * 127
        result[result > 255] = 255
        plt.imsave(os.path.join('run', args.dataset, 'deeplab-resnet', 'output', str(i)),
                   result)
def __init__(self, config: BaseConfig):
    self._config = config
    self._model = DeepLab(num_classes=9,
                          output_stride=8,
                          sync_bn=False).to(self._config.device)
    self._border_loss = TotalLoss(self._config)
    self._direction_loss = CrossEntropyLoss()
    self._loaders = get_data_loaders(config)
    self._writer = SummaryWriter()
    self._optimizer = torch.optim.SGD(self._model.parameters(),
                                      lr=self._config.lr,
                                      weight_decay=1e-4,
                                      nesterov=True,
                                      momentum=0.9)
    self._scheduler = torch.optim.lr_scheduler.ExponentialLR(self._optimizer, gamma=0.97)
    if self._config.parallel:
        self._model = DistributedDataParallel(self._model,
                                              device_ids=[self._config.device])
def test(args):
    kwargs = {'num_workers': 1, 'pin_memory': True}
    train_loader, val_loader, test_loader, nclass = make_data_loader(args, **kwargs)
    # this snippet runs on CPU (mixup_data is called with use_cuda=False)
    device = torch.device('cpu')
    model = DeepLab(num_classes=nclass,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=False)
    model.load_state_dict(torch.load(args.pretrained, map_location=device)['state_dict'])
    model.eval()

    tbar = tqdm(test_loader)  # train / test / dev
    for i, sample in enumerate(tbar):
        image, target = sample['image'], sample['label']
        # original_image = image
        if args.use_mixup:
            image, targets_a, targets_b, lam = mixup_data(image, target,
                                                          args.mixup_alpha,
                                                          use_cuda=False)
            # mixed_image = image
        # image = norm(image.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
        output = model(image)
def main():
    """Create the model and start the evaluation process."""
    args = get_arguments()
    if not os.path.exists(args.save):
        os.makedirs(args.save)

    if args.model == 'DeeplabMulti':
        model = DeeplabMulti(num_classes=args.num_classes)
    elif args.model == 'Oracle':
        # model = Res_Deeplab(num_classes=args.num_classes)
        model = DeepLab(backbone='resnet', output_stride=8)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_ORC
    elif args.model == 'DeeplabVGG':
        model = DeeplabVGG(num_classes=args.num_classes)
        if args.restore_from == RESTORE_FROM:
            args.restore_from = RESTORE_FROM_VGG

    if args.restore_from[:4] == 'http':
        saved_state_dict = model_zoo.load_url(args.restore_from)
    else:
        saved_state_dict = torch.load(args.restore_from)
    # for running different versions of pytorch: keep only matching keys,
    # merge them into the model's state dict, and load the merged dict
    model_dict = model.state_dict()
    saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_dict}
    model_dict.update(saved_state_dict)
    model.load_state_dict(model_dict)

    device = torch.device("cuda" if not args.cpu else "cpu")
    model = model.to(device)
    model.eval()

    num_classes = 20
    tp_list = [0] * num_classes
    fp_list = [0] * num_classes
    fn_list = [0] * num_classes
    iou_list = [0] * num_classes
    hist = np.zeros((21, 21))
    group = 1
    scorer = SegScorer(num_classes=21)
    datalayer = SSDatalayer(group)
    cos_similarity_func = nn.CosineSimilarity()

    for count in tqdm(range(1000)):
        dat = datalayer.dequeue()
        ref_img = dat['second_img'][0]        # (3, 457, 500)
        query_img = dat['first_img'][0]       # (3, 375, 500)
        query_label = dat['second_label'][0]  # (1, 375, 500)
        ref_label = dat['first_label'][0]     # (1, 457, 500)
        # query_img = dat['second_img'][0]
        # ref_img = dat['first_img'][0]
        # ref_label = dat['second_label'][0]
        # query_label = dat['first_label'][0]
        deploy_info = dat['deploy_info']
        semantic_label = deploy_info['first_semantic_labels'][0][0] - 1

        ref_img, ref_label = torch.Tensor(ref_img).cuda(), torch.Tensor(ref_label).cuda()
        query_img, query_label = (torch.Tensor(query_img).cuda(),
                                  torch.Tensor(query_label[0, :, :]).cuda())
        # ref_img, ref_label = torch.Tensor(ref_img), torch.Tensor(ref_label)
        # query_img, query_label = torch.Tensor(query_img), torch.Tensor(query_label[0, :, :])
        # ref_img = ref_img * ref_label

        ref_img_var, query_img_var = Variable(ref_img), Variable(query_img)
        query_label_var, ref_label_var = Variable(query_label), Variable(ref_label)

        ref_img_var = torch.unsqueeze(ref_img_var, dim=0)      # [1, 3, 457, 500]
        ref_label_var = torch.unsqueeze(ref_label_var, dim=1)  # [1, 1, 457, 500]
        query_img_var = torch.unsqueeze(query_img_var, dim=0)  # [1, 3, 375, 500]
        query_label_var = torch.unsqueeze(query_label_var, dim=0)  # [1, 375, 500]

        samples = torch.cat([ref_img_var, query_img_var], 0)
        pred = model(samples, ref_label_var)
        w, h = query_label.size()
        pred = F.upsample(pred, size=(w, h), mode='bilinear')  # [2, 416, 416]
        pred = F.softmax(pred, dim=1).squeeze()
        values, pred = torch.max(pred, dim=0)
        pred = pred.data.cpu().numpy().astype(np.int32)  # (333, 500)

        # query-set image (375, 500, 3)
        org_img = get_org_img(query_img.squeeze().cpu().data.numpy())
        # color image: the mask blended onto the original (375, 500, 3)
        img = mask_to_img(pred, org_img)
        cv2.imwrite('save_bins/que_pred/query_set_1_%d.png' % count, img)

        query_label = query_label.cpu().numpy().astype(np.int32)  # (333, 500)
        # class indices start from 1 in the data layer
        class_ind = int(deploy_info['first_semantic_labels'][0][0]) - 1
        scorer.update(pred, query_label, class_ind + 1)
        tp, tn, fp, fn = measure(query_label, pred)
        # iou_img = tp / float(max(tn + fp + fn, 1))
        tp_list[class_ind] += tp
        fp_list[class_ind] += fp
        fn_list[class_ind] += fn
        # max() in case both pred and label are zero
        iou_list = [tp_list[ic] / float(max(tp_list[ic] + fp_list[ic] + fn_list[ic], 1))
                    for ic in range(num_classes)]

        tmp_pred = pred
        tmp_pred[tmp_pred > 0.5] = class_ind + 1
        tmp_gt_label = query_label
        tmp_gt_label[tmp_gt_label > 0.5] = class_ind + 1
        hist += Metrics.fast_hist(tmp_pred, query_label, 21)

    print("-------------GROUP %d-------------" % group)
    print(iou_list)
    class_indexes = range(group * 5, (group + 1) * 5)
    print('Mean:', np.mean(np.take(iou_list, class_indexes)))

    '''
    for group in range(2):
        datalayer = SSDatalayer(group + 1)
        restore(args, model, group + 1)
        for count in tqdm(range(1000)):
            dat = datalayer.dequeue()
            ref_img = dat['second_img'][0]        # (3, 457, 500)
            query_img = dat['first_img'][0]       # (3, 375, 500)
            query_label = dat['second_label'][0]  # (1, 375, 500)
            ref_label = dat['first_label'][0]     # (1, 457, 500)
            deploy_info = dat['deploy_info']
            semantic_label = deploy_info['first_semantic_labels'][0][0] - 1

            ref_img, ref_label = torch.Tensor(ref_img).cuda(), torch.Tensor(ref_label).cuda()
            query_img, query_label = torch.Tensor(query_img).cuda(), torch.Tensor(query_label[0, :, :]).cuda()

            ref_img_var, query_img_var = Variable(ref_img), Variable(query_img)
            query_label_var, ref_label_var = Variable(query_label), Variable(ref_label)
            ref_img_var = torch.unsqueeze(ref_img_var, dim=0)      # [1, 3, 457, 500]
            ref_label_var = torch.unsqueeze(ref_label_var, dim=1)  # [1, 1, 457, 500]
            query_img_var = torch.unsqueeze(query_img_var, dim=0)  # [1, 3, 375, 500]
            query_label_var = torch.unsqueeze(query_label_var, dim=0)  # [1, 375, 500]

            logits = model(query_img_var, ref_img_var, ref_label_var, ref_label_var)
            # w, h = query_label.size()
            # outB_side = F.upsample(outB_side, size=(w, h), mode='bilinear')
            # out_side = F.softmax(outB_side, dim=1).squeeze()
            # values, pred = torch.max(out_side, dim=0)
            values, pred = model.get_pred(logits, query_img_var)       # values [2, 333, 500]
            pred = pred.data.cpu().numpy().astype(np.int32)            # (333, 500)
            query_label = query_label.cpu().numpy().astype(np.int32)  # (333, 500)
            # class indices start from 1 in the data layer
            class_ind = int(deploy_info['first_semantic_labels'][0][0]) - 1
            scorer.update(pred, query_label, class_ind + 1)
            tp, tn, fp, fn = measure(query_label, pred)
            # iou_img = tp / float(max(tn + fp + fn, 1))
            tp_list[class_ind] += tp
            fp_list[class_ind] += fp
            fn_list[class_ind] += fn
            # max() in case both pred and label are zero
            iou_list = [tp_list[ic] / float(max(tp_list[ic] + fp_list[ic] + fn_list[ic], 1))
                        for ic in range(num_classes)]

            tmp_pred = pred
            tmp_pred[tmp_pred > 0.5] = class_ind + 1
            tmp_gt_label = query_label
            tmp_gt_label[tmp_gt_label > 0.5] = class_ind + 1
            hist += Metrics.fast_hist(tmp_pred, query_label, 21)

        print("-------------GROUP %d-------------" % group)
        print(iou_list)
        class_indexes = range(group * 5, (group + 1) * 5)
        print('Mean:', np.mean(np.take(iou_list, class_indexes)))

    print('BMVC IOU', np.mean(np.take(iou_list, range(0, 20))))
    miou = Metrics.get_voc_iou(hist)
    print('IOU:', miou, np.mean(miou))
    '''

    binary_hist = np.array((hist[0, 0], hist[0, 1:].sum(),
                            hist[1:, 0].sum(), hist[1:, 1:].sum())).reshape((2, 2))
    bin_iu = np.diag(binary_hist) / (binary_hist.sum(1) + binary_hist.sum(0) - np.diag(binary_hist))
    print('Bin_iu:', bin_iu)

    scores = scorer.score()
    for k in scores.keys():
        print(k, np.mean(scores[k]), scores[k])
from PIL import Image
import numpy as np
import torch
import torchvision.transforms as tr

from modeling.deeplab import DeepLab
from dataloaders.utils import decode_segmap

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

checkpoint = torch.load("./run/pascal/deeplab-resnet/model_best.pth")
model = DeepLab(num_classes=21,
                backbone='resnet',
                output_stride=16,
                sync_bn=True,
                freeze_bn=False)
model.load_state_dict(checkpoint['state_dict_G'])
model.eval()
model.to(device)


def transform(image):
    return tr.Compose([
        tr.Resize(513),
        tr.CenterCrop(513),
        tr.ToTensor(),
        tr.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])(image)
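# Usage sketch (assumptions: 'example.jpg' is an illustrative path, and
# decode_segmap takes the prediction plus a dataset name, as in the
# pytorch-deeplab-xception repo this snippet imports from).
image = Image.open('example.jpg').convert('RGB')
inputs = transform(image).unsqueeze(0).to(device)
with torch.no_grad():
    output = model(inputs)
pred = output.argmax(dim=1).squeeze(0).cpu().numpy()
rgb = decode_segmap(pred, dataset='pascal')  # (513, 513, 3) color mask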
class Trainer:
    def __init__(self, data_train, data_valid, image_base_dir, instructions):
        """
        :param data_train:
        :param data_valid:
        :param image_base_dir:
        :param instructions:
        """
        self.image_base_dir = image_base_dir
        self.data_valid = data_valid
        self.instructions = instructions

        # specify model save dir
        self.model_name = instructions[STR.MODEL_NAME]
        # now = time.localtime()
        # start_time = "{}-{}-{}T{}:{}:{}".format(now.tm_year, now.tm_mon, now.tm_mday,
        #                                         now.tm_hour, now.tm_min, now.tm_sec)
        experiment_folder_path = os.path.join(paths.MODELS_FOLDER_PATH, self.model_name)
        if os.path.exists(experiment_folder_path):
            warnings.warn("Experiment folder exists already. Files might be overwritten")
        os.makedirs(experiment_folder_path, exist_ok=True)

        # define saver and save instructions
        self.saver = Saver(folder_path=experiment_folder_path, instructions=instructions)
        self.saver.save_instructions()

        # define Tensorboard Summary
        self.writer = SummaryWriter(log_dir=experiment_folder_path)

        nn_input_size = instructions[STR.NN_INPUT_SIZE]
        state_dict_file_path = instructions.get(STR.STATE_DICT_FILE_PATH, None)
        self.colour_mapping = mapping.get_colour_mapping()

        # define transformers for training
        crops_per_image = instructions.get(STR.CROPS_PER_IMAGE, 10)
        apply_random_cropping = (STR.CROPS_PER_IMAGE in instructions.keys()) and \
                                (STR.IMAGES_PER_BATCH in instructions.keys())
        print("{}applying random cropping".format("" if apply_random_cropping else "_NOT_ "))

        t = [Normalize()]
        if apply_random_cropping:
            t.append(RandomCrop(min_size=instructions.get(STR.CROP_SIZE_MIN, 400),
                                max_size=instructions.get(STR.CROP_SIZE_MAX, 1000),
                                crop_count=crops_per_image))
        t += [
            Resize(nn_input_size),
            Flip(p_vertical=0.2, p_horizontal=0.5),
            ToTensor()
        ]
        transformations_train = transforms.Compose(t)

        # define transformers for validation
        transformations_valid = transforms.Compose([Normalize(), Resize(nn_input_size), ToTensor()])

        # set up data loaders
        dataset_train = DictArrayDataSet(image_base_dir=image_base_dir,
                                         data=data_train,
                                         num_classes=len(self.colour_mapping.keys()),
                                         transformation=transformations_train)

        # define batch sizes
        self.batch_size = instructions[STR.BATCH_SIZE]
        if apply_random_cropping:
            self.data_loader_train = DataLoader(dataset=dataset_train,
                                                batch_size=instructions[STR.IMAGES_PER_BATCH],
                                                shuffle=True,
                                                collate_fn=custom_collate)
        else:
            self.data_loader_train = DataLoader(dataset=dataset_train,
                                                batch_size=self.batch_size,
                                                shuffle=True,
                                                collate_fn=custom_collate)

        dataset_valid = DictArrayDataSet(image_base_dir=image_base_dir,
                                         data=data_valid,
                                         num_classes=len(self.colour_mapping.keys()),
                                         transformation=transformations_valid)
        self.data_loader_valid = DataLoader(dataset=dataset_valid,
                                            batch_size=self.batch_size,
                                            shuffle=False,
                                            collate_fn=custom_collate)
        self.num_classes = dataset_train.num_classes()

        # define model
        print("Building model")
        self.model = DeepLab(num_classes=self.num_classes,
                             backbone=instructions.get(STR.BACKBONE, "resnet"),
                             output_stride=instructions.get(STR.DEEPLAB_OUTPUT_STRIDE, 16))

        # load weights
        if state_dict_file_path is not None:
            print("loading state_dict from:")
            print(state_dict_file_path)
            load_state_dict(self.model, state_dict_file_path)

        learning_rate = instructions.get(STR.LEARNING_RATE, 1e-5)
        train_params = [{'params': self.model.get_1x_lr_params(), 'lr': learning_rate},
                        {'params': self.model.get_10x_lr_params(), 'lr': learning_rate}]

        # choose gpu or cpu
        self.device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        if instructions.get(STR.MULTI_GPU, False):
            if torch.cuda.device_count() > 1:
                print("Using ", torch.cuda.device_count(), " GPUs!")
                self.model = nn.DataParallel(self.model)
        self.model.to(self.device)

        # Define Optimizer
        self.optimizer = torch.optim.SGD(train_params,
                                         momentum=0.9,
                                         weight_decay=5e-4,
                                         nesterov=False)

        # calculate class weights
        if instructions.get(STR.CLASS_STATS_FILE_PATH, None):
            class_weights = calculate_class_weights(
                instructions[STR.CLASS_STATS_FILE_PATH],
                self.colour_mapping,
                modifier=instructions.get(STR.LOSS_WEIGHT_MODIFIER, 1.01))
            class_weights = torch.from_numpy(class_weights.astype(np.float32))
        else:
            class_weights = None
        self.criterion = SegmentationLosses(weight=class_weights,
                                            cuda=self.device.type != "cpu").build_loss()

        # Define Evaluator
        self.evaluator = Evaluator(self.num_classes)

        # Define lr scheduler
        self.scheduler = None
        if instructions.get(STR.USE_LR_SCHEDULER, True):
            self.scheduler = LR_Scheduler(mode="cos",
                                          base_lr=learning_rate,
                                          num_epochs=instructions[STR.EPOCHS],
                                          iters_per_epoch=len(self.data_loader_train))

        # print information before training start
        print("-" * 60)
        print("instructions")
        pprint(instructions)
        model_parameters = sum([p.nelement() for p in self.model.parameters()])
        print("Model parameters: {:.2E}".format(model_parameters))

        self.best_prediction = 0.0

    def train(self, epoch):
        self.model.train()
        train_loss = 0.0

        # create a progress bar
        pbar = tqdm(self.data_loader_train)
        num_batches_train = len(self.data_loader_train)

        # go through each item in the training data
        for i, sample in enumerate(pbar):
            # set input and target
            nn_input = sample[STR.NN_INPUT].to(self.device)
            nn_target = sample[STR.NN_TARGET].to(self.device, dtype=torch.float)

            if self.scheduler:
                self.scheduler(self.optimizer, i, epoch, self.best_prediction)

            # run model
            output = self.model(nn_input)

            # calc losses
            loss = self.criterion(output, nn_target)
            # # save step losses
            # combined_loss_steps.append(float(loss))
            # regression_loss_steps.append(float(regression_loss))
            # classification_loss_steps.append(float(classification_loss))
            train_loss += loss.item()
            pbar.set_description('Train loss: %.3f' % (train_loss / (i + 1)))
            self.writer.add_scalar('train/total_loss_iter', loss.item(),
                                   i + num_batches_train * epoch)

            # calculate gradient and update model weights
            loss.backward()
            # torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
            self.optimizer.step()
            self.optimizer.zero_grad()

        self.writer.add_scalar('train/total_loss_epoch', train_loss, epoch)
        print("[Epoch: {}, num images/crops: {}]".format(epoch, num_batches_train * self.batch_size))
        print("Loss: {:.2f}".format(train_loss))

    def validation(self, epoch):
        self.model.eval()
        self.evaluator.reset()
        test_loss = 0.0

        pbar = tqdm(self.data_loader_valid, desc='\r')
        num_batches_val = len(self.data_loader_valid)
        for i, sample in enumerate(pbar):
            # set input and target
            nn_input = sample[STR.NN_INPUT].to(self.device)
            nn_target = sample[STR.NN_TARGET].to(self.device, dtype=torch.float)

            with torch.no_grad():
                output = self.model(nn_input)

            loss = self.criterion(output, nn_target)
            test_loss += loss.item()
            pbar.set_description('Test loss: %.3f' % (test_loss / (i + 1)))

            pred = output.data.cpu().numpy()
            pred = np.argmax(pred, axis=1)
            nn_target = nn_target.cpu().numpy()

            # Add batch sample into evaluator
            self.evaluator.add_batch(nn_target, pred)

        # Fast test during the training
        Acc = self.evaluator.Pixel_Accuracy()
        Acc_class = self.evaluator.Pixel_Accuracy_Class()
        mIoU = self.evaluator.Mean_Intersection_over_Union()
        FWIoU = self.evaluator.Frequency_Weighted_Intersection_over_Union()
        self.writer.add_scalar('val/total_loss_epoch', test_loss, epoch)
        self.writer.add_scalar('val/mIoU', mIoU, epoch)
        self.writer.add_scalar('val/Acc', Acc, epoch)
        self.writer.add_scalar('val/Acc_class', Acc_class, epoch)
        self.writer.add_scalar('val/fwIoU', FWIoU, epoch)
        print('Validation:')
        print("[Epoch: {}, num crops: {}]".format(epoch, num_batches_val * self.batch_size))
        print("Acc:{:.2f}, Acc_class:{:.2f}, mIoU:{:.2f}, fwIoU: {:.2f}".format(
            Acc, Acc_class, mIoU, FWIoU))
        print("Loss: {:.2f}".format(test_loss))

        new_pred = mIoU
        is_best = new_pred > self.best_prediction
        if is_best:
            self.best_prediction = new_pred
        self.saver.save_checkpoint(self.model, is_best, epoch)
def main():
    here = osp.dirname(osp.abspath(__file__))
    trainOpts = TrainOptions()
    args = trainOpts.get_arguments()

    now = datetime.datetime.now()
    args.out = osp.join(here, 'results',
                        args.model + '_' + args.dataset + '_' + now.strftime('%Y%m%d__%H%M%S'))
    if not osp.isdir(args.out):
        os.makedirs(args.out)
    log_file = osp.join(args.out, args.model + '_' + args.dataset + '.log')
    mylog = open(log_file, 'w')
    checkpoint_dir = osp.join(args.out, 'checkpoints')
    os.makedirs(checkpoint_dir)

    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()
    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset
    # (Alternative Vaihingen / Potsdam configurations, kept for reference:)
    # MAIN_FOLDER = args.folder + 'Vaihingen/'
    # DATA_FOLDER = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif'
    # LABEL_FOLDER = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif'
    # train_ids = ['1', '3', '5', '21', '23', '26', '7', '13', '17', '32', '37']
    # val_ids = ['11', '15', '28', '30', '34']
    # train_set = ISPRS_dataset(train_ids, DATA_FOLDER, LABEL_FOLDER, cache=args.cache)
    # train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size)
    #
    # MAIN_FOLDER = args.folder + 'Potsdam_multiscale/'
    # DATA_FOLDER1 = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif'
    # LABEL_FOLDER1 = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif'
    # DATA_FOLDER2 = MAIN_FOLDER + '3_Ortho_IRRG_2/top_potsdam_{}_IRRG.tif'
    # LABEL_FOLDER2 = MAIN_FOLDER + '5_Labels_for_participants_2/top_potsdam_{}_label.tif'
    # train_ids = ['2_10', '3_10', '3_11', '3_12', '4_11', '4_12', '5_10', '5_12',
    #              '6_8', '6_9', '6_10', '6_11', '6_12', '7_7', '7_9', '7_11', '7_12']
    # val_ids = ['2_11', '2_12', '4_10', '5_11', '6_7', '7_8', '7_10']
    # target_set = ISPRS_dataset_multi(2, train_ids, DATA_FOLDER1, LABEL_FOLDER1,
    #                                  DATA_FOLDER2, LABEL_FOLDER2, cache=args.cache)
    # target_loader = torch.utils.data.DataLoader(target_set, batch_size=args.batch_size)
    #
    # MAIN_FOLDER = args.folder + 'Potsdam/'
    # DATA_FOLDER = MAIN_FOLDER + '3_Ortho_IRRG/top_potsdam_{}_IRRG.tif'
    # LABEL_FOLDER = MAIN_FOLDER + '5_Labels_for_participants/top_potsdam_{}_label.tif'
    # ERODED_FOLDER = MAIN_FOLDER + '5_Labels_for_participants_no_Boundary/top_potsdam_{}_label_noBoundary.tif'
    # (same Potsdam train/val ids as above)

    MAIN_FOLDER = args.folder + 'DeepGlobe/land-train_crop/'
    DATA_FOLDER = MAIN_FOLDER + '{}_sat.jpg'
    LABEL_FOLDER = MAIN_FOLDER + '{}_mask.png'
    all_files = sorted(glob(DATA_FOLDER.replace('{}', '*')))
    all_ids = [f.split('/')[-1].split('_')[0] for f in all_files]
    train_ids = all_ids[:int(len(all_ids) / 3 * 2)]
    val_ids = all_ids[int(len(all_ids) / 3 * 2):]
    train_set = DeepGlobe_dataset(train_ids, DATA_FOLDER, LABEL_FOLDER)
    train_loader = torch.utils.data.DataLoader(train_set, batch_size=args.batch_size)

    MAIN_FOLDER = args.folder + 'ISPRS_dataset/Vaihingen/'
    DATA_FOLDER1 = MAIN_FOLDER + 'top/top_mosaic_09cm_area{}.tif'
    LABEL_FOLDER1 = MAIN_FOLDER + 'gts_for_participants/top_mosaic_09cm_area{}.tif'
    DATA_FOLDER2 = MAIN_FOLDER + 'resized_resolution5/top_mosaic_09cm_area{}.tif'
    LABEL_FOLDER2 = MAIN_FOLDER + 'gts_for_participants5/top_mosaic_09cm_area{}.tif'
    train_ids = ['1', '3', '5', '21', '23', '26', '7', '13', '17', '32', '37']
    val_ids = ['11', '15', '28', '30', '34']
    target_set = ISPRS_dataset_multi(5, train_ids, DATA_FOLDER1, LABEL_FOLDER1,
                                     DATA_FOLDER2, LABEL_FOLDER2, cache=args.cache)
    target_loader = torch.utils.data.DataLoader(target_set, batch_size=args.batch_size)
    # val_set = ISPRS_dataset(val_ids, DATA_FOLDER1, LABEL_FOLDER1, cache=args.cache)
    # val_loader = torch.utils.data.DataLoader(val_set, batch_size=args.batch_size)

    # Label names
    LABELS = ["roads", "buildings", "low veg.", "trees", "cars", "clutter", "unknown"]
    N_CLASS = len(LABELS)  # number of classes

    # 2. model
    if args.backbone == 'resnet':
        model = DeepLab(num_classes=N_CLASS,
                        backbone=args.backbone,
                        output_stride=args.out_stride,
                        sync_bn=args.sync_bn,
                        freeze_bn=args.freeze_bn)
    elif args.backbone == 'resnet_multiscale':
        model = DeepLabCA(num_classes=N_CLASS,
                          backbone=args.backbone,
                          output_stride=args.out_stride,
                          sync_bn=args.sync_bn,
                          freeze_bn=args.freeze_bn)
    else:
        raise ValueError('backbone does not exist!')

    train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                    {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]
    start_epoch = 0
    start_iteration = 0

    # 3. optimizer
    lr = args.lr
    # optimizer = torch.optim.SGD(model.parameters(), lr=args.lr,
    #                             momentum=args.momentum, weight_decay=args.weight_decay)
    netD_domain = FCDiscriminator(num_classes=N_CLASS)
    netD_scale = FCDiscriminator(num_classes=N_CLASS)
    optim_netG = torch.optim.SGD(train_params,
                                 momentum=args.momentum,
                                 weight_decay=args.weight_decay,
                                 nesterov=args.nesterov)
    optim_netD_domain = optim.Adam(netD_domain.parameters(), lr=args.lr_D, betas=(0.9, 0.99))
    optim_netD_scale = optim.Adam(netD_scale.parameters(), lr=args.lr_D, betas=(0.9, 0.99))

    if cuda:
        model, netD_domain, netD_scale = model.cuda(), netD_domain.cuda(), netD_scale.cuda()
    if args.resume:
        checkpoint = torch.load(args.resume)
        model.load_state_dict(checkpoint)

    bce_loss = torch.nn.BCEWithLogitsLoss()

    # 4. training
    iter_ = 0
    no_optim = 0
    val_best_loss = float('Inf')
    factor = 10
    max_iter = 50000
    trainloader_iter = enumerate(train_loader)
    targetloader_iter = enumerate(target_loader)

    source_label = 0
    target_label = 1
    source_scale_label = 0
    target_scale_label = 1

    train_loss = []
    train_acc = []
    target_acc_s1 = []
    target_acc_s2 = []

    while iter_ < max_iter:
        optim_netG.zero_grad()
        adjust_learning_rate(optim_netG, iter_, args)
        optim_netD_domain.zero_grad()
        optim_netD_scale.zero_grad()
        adjust_learning_rate_D(optim_netD_domain, iter_, args)
        adjust_learning_rate_D(optim_netD_scale, iter_, args)

        if iter_ % 1000 == 0:
            train_loss = []
            train_acc = []
            target_acc_s1 = []
            target_acc_s2 = []

        for param in netD_domain.parameters():
            param.requires_grad = False
        for param in netD_scale.parameters():
            param.requires_grad = False

        _, batch = trainloader_iter.__next__()
        im_s, label_s = batch
        _, batch = targetloader_iter.__next__()
        im_t_s1, label_t_s1, im_t_s2, label_t_s2 = batch

        if cuda:
            im_s, label_s = im_s.cuda(), label_s.cuda()
            im_t_s1, label_t_s1, im_t_s2, label_t_s2 = (im_t_s1.cuda(), label_t_s1.cuda(),
                                                        im_t_s2.cuda(), label_t_s2.cuda())

        ############
        # TRAIN NETG
        ############
        # train with source:
        # optimize the segmentation network with source data
        pred_seg = model(im_s)
        seg_loss = cross_entropy2d(pred_seg, label_s)
        seg_loss /= len(im_s)
        loss_data = seg_loss.data.item()
        if np.isnan(loss_data):
            raise ValueError('loss is nan while training')
        seg_loss.backward()

        pred = np.argmax(pred_seg.data.cpu().numpy()[0], axis=0)
        gt = label_s.data.cpu().numpy()[0]
        train_acc.append(accuracy(pred, gt))
        train_loss.append(loss_data)

        # train with target
        pred_s1 = model(im_t_s1)
        pred = np.argmax(pred_s1.data.cpu().numpy()[0], axis=0)
        gt = label_t_s1.data.cpu().numpy()[0]
        target_acc_s1.append(accuracy(pred, gt))

        pred_s2 = model(im_t_s2)
        pred = np.argmax(pred_s2.data.cpu().numpy()[0], axis=0)
        gt = label_t_s2.data.cpu().numpy()[0]
        target_acc_s2.append(accuracy(pred, gt))

        pred_d = netD_domain(F.softmax(pred_s1))
        pred_s = netD_scale(F.softmax(pred_s2))
        loss_adv_domain = bce_loss(
            pred_d,
            Variable(torch.FloatTensor(pred_d.data.size()).fill_(source_label)).cuda())
        loss_adv_scale = bce_loss(
            pred_s,
            Variable(torch.FloatTensor(pred_s.data.size()).fill_(source_scale_label)).cuda())
        loss = args.lambda_adv_domain * loss_adv_domain + args.lambda_adv_scale * loss_adv_scale
        loss /= len(im_t_s1)
        loss.backward()

        ############
        # TRAIN NETD
        ############
        for param in netD_domain.parameters():
            param.requires_grad = True
        for param in netD_scale.parameters():
            param.requires_grad = True

        # train with source domain and source scale
        pred_seg, pred_s1 = pred_seg.detach(), pred_s1.detach()
        pred_d = netD_domain(F.softmax(pred_seg))
        # pred_s = netD_scale(F.softmax(pred_seg))
        pred_s = netD_scale(F.softmax(pred_s1))
        loss_D_domain = bce_loss(
            pred_d,
            Variable(torch.FloatTensor(pred_d.data.size()).fill_(source_label)).cuda())
        loss_D_scale = bce_loss(
            pred_s,
            Variable(torch.FloatTensor(pred_s.data.size()).fill_(source_scale_label)).cuda())
        loss_D_domain = loss_D_domain / len(im_s) / 2
        loss_D_scale = loss_D_scale / len(im_s) / 2
        loss_D_domain.backward()
        loss_D_scale.backward()

        # train with target domain and target scale
        pred_s1, pred_s2 = pred_s1.detach(), pred_s2.detach()
        pred_d = netD_domain(F.softmax(pred_s1))
        pred_s = netD_scale(F.softmax(pred_s2))
        loss_D_domain = bce_loss(
            pred_d,
            Variable(torch.FloatTensor(pred_d.data.size()).fill_(target_label)).cuda())
        loss_D_scale = bce_loss(
            pred_s,
            Variable(torch.FloatTensor(pred_s.data.size()).fill_(target_scale_label)).cuda())
        loss_D_domain = loss_D_domain / len(im_s) / 2
        loss_D_scale = loss_D_scale / len(im_s) / 2
        loss_D_domain.backward()
        loss_D_scale.backward()

        optim_netG.step()
        optim_netD_domain.step()
        optim_netD_scale.step()

        if iter_ % 100 == 0:
            msg = ('Train [{}/{} Source loss:{:.6f} acc:{:.4f} % '
                   'Target s1 acc:{:4f}% Target s2 acc:{:4f}%]').format(
                iter_, max_iter,
                sum(train_loss) / len(train_loss),
                sum(train_acc) / len(train_acc),
                sum(target_acc_s1) / len(target_acc_s1),
                sum(target_acc_s2) / len(target_acc_s2))
            print(msg)
            print(msg, file=mylog)

        if iter_ % 1000 == 0:
            print('saving checkpoint.....')
            torch.save(model.state_dict(), osp.join(checkpoint_dir, 'iter{}.pth'.format(iter_)))

        iter_ += 1
                    choices=['resnet', 'xception', 'drn', 'mobilenet'],
                    help='backbone name (default: resnet)')
parser.add_argument('--pretrained',
                    type=str,
                    default='/Users/yulian/Downloads/mixup_model_best.pth.tar',
                    help='pretrained model')
parser.add_argument('--color',
                    type=str,
                    default='purple',
                    choices=['purple', 'green', 'blue', 'red'],
                    help='Color your hair (default: purple)')
args = parser.parse_args()

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
net = DeepLab(backbone=args.backbone,
              output_stride=16,
              num_classes=2,
              sync_bn=False).to(device)
net.load_state_dict(torch.load(args.pretrained, map_location=device)['state_dict'])
net.eval()

cam = cv2.VideoCapture(0)
if not cam.isOpened():
    raise Exception("webcam is not detected")

while True:
    # ret: frame capture succeeded (boolean)
    # image: captured frame
    ret, image = cam.read()
    if ret:
        image, mask = get_image_mask(image, net)
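# get_image_mask() is not shown in this snippet; a minimal sketch of what it
# might do (assumptions: resize to a fixed network input, ImageNet
# normalization, argmax over the two classes to get a binary hair mask;
# the 512x512 input size is illustrative).
def get_image_mask(image, net):
    import torchvision.transforms as T
    tfm = T.Compose([T.ToTensor(),
                     T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    x = tfm(cv2.resize(image, (512, 512))).unsqueeze(0)
    with torch.no_grad():
        pred = net(x.to(next(net.parameters()).device)).argmax(dim=1)
    mask = pred.squeeze(0).byte().cpu().numpy() * 255
    return image, cv2.resize(mask, (image.shape[1], image.shape[0]))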
cp_seg_name = '/_checkpoint/cp_seg.pth'
cp_act_name = '/_checkpoint/cp_act.pth'
save_seg_name = '/model.pth'
save_act_name = '/action.pth'

# segmentation
import torch
from modeling.deeplab import DeepLab

n_class = 10
try:
    model = DeepLab(num_classes=n_class,
                    backbone='xception',
                    output_stride=16,
                    sync_bn=False,
                    freeze_bn=False)
    model = model.cuda()
    checkpoint = torch.load(now_dir + cp_seg_name)
    model.load_state_dict(checkpoint['state_dict'])
    torch.save(model, now_dir + target_dir + save_seg_name)
    print('segmentation model - OK!')
except Exception:
    print('segmentation model - Failed!')

# action
import torchvision.models as models
import torch.nn as nn
def main(args):
    config = ConfigParser(args)
    cfg = config.config
    logger = get_logger(config.log_dir, "train")

    train_dataset = MRIBrainSegmentation(root_folder=cfg['root_folder'],
                                         image_label=cfg['train_data'],
                                         is_train=True,
                                         ignore_label=0,
                                         input_size=cfg['input_size'])
    vali_dataset = MRIBrainSegmentation(root_folder=cfg['root_folder'],
                                        image_label=cfg['validation_data'],
                                        is_train=False,
                                        ignore_label=0,
                                        input_size=cfg['input_size'])
    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               batch_size=cfg["train_batch_size"],
                                               shuffle=True,
                                               num_workers=cfg["workers"],
                                               drop_last=True)
    vali_loader = torch.utils.data.DataLoader(vali_dataset,
                                              batch_size=cfg["vali_batch_size"],
                                              shuffle=False,
                                              num_workers=cfg["workers"],
                                              drop_last=False)

    if cfg['net_name'] == "deeplab":
        model = DeepLab(num_classes=1,
                        backbone=cfg['backbone'],
                        output_stride=cfg['output_stride'],
                        sync_bn=cfg['sync_bn'],
                        freeze_bn=cfg['freeze_bn'])
    else:
        model = Unet(in_channels=3, out_channels=1, init_features=32)

    criterion = getattr(loss, 'dice_loss')
    optimizer = optim.SGD(model.parameters(),
                          lr=cfg["lr"],
                          momentum=0.9,
                          weight_decay=cfg["weight_decay"])
    metrics_name = []
    scheduler = Poly_Scheduler(base_lr=cfg['lr'],
                               num_epochs=config['epoch'],
                               iters_each_epoch=len(train_loader))
    trainer = Trainer(model=model,
                      criterion=criterion,
                      optimizer=optimizer,
                      train_loader=train_loader,
                      nb_epochs=config['epoch'],
                      valid_loader=vali_loader,
                      lr_scheduler=scheduler,
                      logger=logger,
                      log_dir=config.save_dir,
                      metrics_name=metrics_name,
                      resume=config['resume'],
                      save_dir=config.save_dir,
                      device="cuda:0",
                      monitor="max iou_class_1",
                      early_stop=-1)
    trainer.train()
def __init__(self, args):
    self.args = args

    # Define Saver
    self.saver = Saver(args)
    self.saver.save_experiment_config()

    # Define Tensorboard Summary
    self.summary = TensorboardSummary(self.saver.experiment_dir)
    self.writer = self.summary.create_summary()

    # Define Dataloader
    kwargs = {'num_workers': args.workers, 'pin_memory': True}
    self.train_loader, self.val_loader, self.test_loader, self.nclass = make_data_loader(args, **kwargs)

    # Define network
    model = DeepLab(num_classes=self.nclass,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn)
    # init D
    model_D = FCDiscriminator(num_classes=19)

    train_params = [{'params': model.get_1x_lr_params(), 'lr': args.lr},
                    {'params': model.get_10x_lr_params(), 'lr': args.lr * 10}]

    # Define Optimizer
    optimizer = torch.optim.SGD(train_params,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay,
                                nesterov=args.nesterov)
    optimizer_D = torch.optim.Adam(model_D.parameters(), lr=1e-4, betas=(0.9, 0.99))

    # Define Criterion: whether to use class-balanced weights
    if args.use_balanced_weights:
        classes_weights_path = 'dataloders\\datasets\\' + args.dataset + '_classes_weights.npy'
        if os.path.isfile(classes_weights_path):
            weight = np.load(classes_weights_path)
        else:
            weight = calculate_weigths_labels(args.dataset, self.train_loader, self.nclass)
        weight = torch.from_numpy(weight.astype(np.float32))
    else:
        weight = None
    self.criterion = SegmentationLosses(weight=weight, cuda=args.cuda).build_loss(mode=args.loss_type)
    self.bce_loss = torch.nn.BCEWithLogitsLoss()
    self.model, self.optimizer = model, optimizer
    self.model_D, self.optimizer_D = model_D, optimizer_D

    # Define Evaluator
    self.evaluator = Evaluator(self.nclass)
    # Define lr scheduler
    self.scheduler = LR_Scheduler(args.lr_scheduler, args.lr, args.epochs, len(self.train_loader))

    # Using cuda
    if args.cuda:
        self.model = torch.nn.DataParallel(self.model, device_ids=self.args.gpu_ids)
        self.model_D = torch.nn.DataParallel(self.model_D, device_ids=self.args.gpu_ids)
        patch_replication_callback(self.model)
        patch_replication_callback(self.model_D)
        self.model = self.model.cuda()
        self.model_D = self.model_D.cuda()

    # Resuming checkpoint
    self.best_pred = 0.0
    if args.resume is not None:
        if not os.path.isfile(args.resume):
            raise RuntimeError("=> no checkpoint found at '{}'".format(args.resume))
        checkpoint = torch.load(args.resume)
        args.start_epoch = checkpoint['epoch']
        if args.cuda:
            self.model.module.load_state_dict(checkpoint['state_dict'])
        else:
            self.model.load_state_dict(checkpoint['state_dict'])
        if not args.ft:
            self.optimizer.load_state_dict(checkpoint['optimizer'])
        self.best_pred = checkpoint['best_pred']
        print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))

    # Clear start epoch if fine-tuning
    if args.ft:
        args.start_epoch = 0
class DeeplabRos:
    def __init__(self):
        # GPU assignment
        os.environ["CUDA_VISIBLE_DEVICES"] = "0"
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

        # Load checkpoint
        self.checkpoint = torch.load(os.path.join("./src/deeplab_ros/data/model_best.pth.tar"))

        # Load model
        self.model = DeepLab(num_classes=4,
                             backbone='mobilenet',
                             output_stride=16,
                             sync_bn=True,
                             freeze_bn=False)
        self.model.load_state_dict(self.checkpoint['state_dict'])
        self.model = self.model.to(self.device)

        # ROS init
        self.bridge = CvBridge()
        self.image_sub = rospy.Subscriber("/cam2/pylon_camera_node/image_raw",
                                          ImageMsg,
                                          self.callback,
                                          queue_size=1,
                                          buff_size=2**24)
        self.image_pub = rospy.Publisher("segmentation_image", ImageMsg, queue_size=1)

    def callback(self, data):
        cv_image = self.bridge.imgmsg_to_cv2(data, "bgr8")
        start_time = time.time()

        self.model.eval()
        torch.set_grad_enabled(False)
        tfms = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        inputs = tfms(cv_image).to(self.device)
        output = self.model(inputs.unsqueeze(0)).squeeze().cpu().numpy()
        pred = np.argmax(output, axis=0)
        pred_img = self.label_to_color_image(pred)
        msg = self.bridge.cv2_to_imgmsg(pred_img, "bgr8")

        inference_time = time.time() - start_time
        print("inference time: ", inference_time)
        self.image_pub.publish(msg)

    def label_to_color_image(self, pred, class_num=4):
        # BGR: Unlabeled, Building, Lane-marking, Fence
        label_colors = np.array([(0, 0, 0), (0, 0, 128), (0, 128, 0), (128, 0, 0)])
        r = np.zeros_like(pred).astype(np.uint8)
        g = np.zeros_like(pred).astype(np.uint8)
        b = np.zeros_like(pred).astype(np.uint8)
        for i in range(0, class_num):
            idx = pred == i
            r[idx] = label_colors[i, 0]
            g[idx] = label_colors[i, 1]
            b[idx] = label_colors[i, 2]
        rgb = np.stack([r, g, b], axis=2)
        return rgb
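# Entry-point sketch for the node above (assumption: the standard rospy
# init_node/spin pattern; the node name 'deeplab_ros' is illustrative).
if __name__ == '__main__':
    rospy.init_node('deeplab_ros', anonymous=True)
    node = DeeplabRos()
    rospy.spin()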
def main():
    """Create the model and start the training."""
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)
    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        # model = DeeplabMulti(num_classes=args.num_classes)
        # model = Res_Deeplab(num_classes=args.num_classes)
        model = DeepLab(backbone='resnet', output_stride=16)
        '''
        if args.restore_from[:4] == 'http':
            saved_state_dict = model_zoo.load_url(args.restore_from)
        else:
            saved_state_dict = torch.load(args.restore_from)
        # restore(model, saved_state_dict)
        new_params = model.state_dict().copy()
        for i in saved_state_dict:
            # Scale.layer5.conv2d_list.3.weight
            i_parts = i.split('.')
            if not i_parts[0] == 'layer4' and not i_parts[0] == 'fc':
                # new_params['.'.join(i_parts[1:])] = saved_state_dict[i]
                new_params[i] = saved_state_dict[i]
        model.load_state_dict(new_params)
        '''
    else:
        raise NotImplementedError

    model.train()
    model.to(device)
    cudnn.benchmark = True

    # init D
    model_D1 = FCDiscriminator(num_classes=args.num_classes).to(device)
    # if args.restore_from_D[:4] == 'http':
    #     saved_state_dict = model_zoo.load_url(args.restore_from_D)
    # else:
    #     saved_state_dict = torch.load(args.restore_from_D)
    # # for running different versions of pytorch
    # model_dict = model_D1.state_dict()
    # saved_state_dict = {k: v for k, v in saved_state_dict.items() if k in model_dict}
    # model_dict.update(saved_state_dict)
    # model_D1.load_state_dict(saved_state_dict)
    model_D1.train()
    model_D1.to(device)

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_loader = data_loader(args)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()
    optimizer_D1 = optim.Adam(model_D1.parameters(), lr=args.learning_rate_D, betas=(0.9, 0.99))
    optimizer_D1.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    seg_loss = torch.nn.CrossEntropyLoss()

    interp = nn.Upsample(size=(416, 416), mode='bilinear', align_corners=True)
    # interp_target = nn.Upsample(size=(input_size_target[1], input_size_target[0]),
    #                             mode='bilinear', align_corners=True)

    # labels for adversarial training
    # set up tensorboard
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    count = args.start_count  # iteration count
    for dat in train_loader:
        if count > args.num_steps:
            break

        loss_seg_value1_anchor = 0
        loss_adv_target_value1 = 0
        loss_D_value1 = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, count)
        optimizer_D1.zero_grad()
        adjust_learning_rate_D(optimizer_D1, count)

        for sub_i in range(args.iter_size):
            # train G: don't accumulate grads in D
            for param in model_D1.parameters():
                param.requires_grad = False

            # For group=0 the training split covers 15 classes [0, 1, 2, ...] and the
            # validation split has 5 classes. Two classes are sampled at random from
            # the training split: two images from one class serve as the anchor and
            # the positive sample (same class), and one image from the other class is
            # the negative sample (different class). The anchor corresponds to the
            # query-set image.
            #############################
            # returns the anchor image and mask, the positive sample and mask (same
            # class as the anchor), and the negative sample and mask (different class)
            anchor_img, anchor_mask, pos_img, pos_mask, neg_img, neg_mask = dat
            # [1, 3, 386, 500], [1, 386, 500], [1, 3, 374, 500], [1, 374, 500]
            anchor_img, anchor_mask, pos_img, pos_mask = (
                anchor_img.cuda(), anchor_mask.cuda(), pos_img.cuda(), pos_mask.cuda())

            anchor_mask = torch.unsqueeze(anchor_mask, dim=1)  # [1, 1, 386, 500]
            pos_mask = torch.unsqueeze(pos_mask, dim=1)        # [1, 1, 374, 500]

            samples = torch.cat([pos_img, anchor_img], 0)
            pred = model(samples, pos_mask)  # [2, 2, 53, 53]
            pred = interp(pred)
            loss_seg1_anchor = seg_loss(pred, anchor_mask.squeeze().unsqueeze(0).long())

            D_out1 = model_D1(F.softmax(pred))
            # treat the source-domain label as 1, then compute the adversarial loss
            # between the discriminator's target prediction and the source label
            loss_adv_target1 = bce_loss(
                D_out1,
                torch.FloatTensor(D_out1.data.size()).fill_(1).to(device))
            # s = torch.stack([s, 1 - s])
            # loss_s = seg_loss()

            loss = loss_seg1_anchor + args.lambda_adv_target1 * loss_adv_target1
            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1_anchor += loss_seg1_anchor.item() / args.iter_size
            loss_adv_target_value1 += loss_adv_target1.item() / args.iter_size

            # train D: bring back requires_grad
            for param in model_D1.parameters():
                param.requires_grad = True

            # train with anchor
            pred_target1 = pred.detach()
            D_out1 = model_D1(F.softmax(pred_target1))
            loss_D1 = bce_loss(D_out1,
                               torch.FloatTensor(D_out1.data.size()).fill_(0).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

            # train with GT
            anchor_gt = Variable(one_hot(anchor_mask)).cuda()
            D_out1 = model_D1(anchor_gt)
            loss_D1 = bce_loss(D_out1,
                               torch.FloatTensor(D_out1.data.size()).fill_(1).to(device))
            loss_D1 = loss_D1 / args.iter_size / 2
            loss_D1.backward()
            loss_D_value1 += loss_D1.item()

        optimizer.step()
        optimizer_D1.step()
        count = count + 1

        if args.tensorboard:
            scalar_info = {
                'loss_seg1_anchor': loss_seg_value1_anchor,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_D1': loss_D_value1,
            }
            if count % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, count)

        # print('exp = {}'.format(args.snapshot_dir))
        print('iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_adv1 = {3:.3f}, loss_D1 = {4:.3f}'
              .format(count, args.num_steps, loss_seg_value1_anchor,
                      loss_adv_target_value1, loss_D_value1))

        if count >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'voc2012_' + str(args.num_steps_stop) + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, 'voc2012_' + str(args.num_steps_stop) + '_D1.pth'))
            break

        if count % args.save_pred_every == 0 and count != 0:
            print('taking snapshot ...')
            torch.save(model.state_dict(),
                       osp.join(args.snapshot_dir, 'voc2012_' + str(count) + '.pth'))
            torch.save(model_D1.state_dict(),
                       osp.join(args.snapshot_dir, 'voc2012_' + str(count) + '_D1.pth'))

    if args.tensorboard:
        writer.close()
from modeling.deeplab import DeepLab
import kornia
from PIL import Image
import torch
import torchvision.transforms.functional as TF
import numpy as np

# this example only uses 1 image, so cpu is fine
device = torch.device("cpu")

# load pre-trained weights, set network to inference mode
network = DeepLab(num_classes=18)
network.load_state_dict(
    torch.load("segmentation-model/epoch-14", map_location="cpu"))
network.eval()
network.to(device)

# load example image. the image is resized because DeepLab uses
# a lot of dilated convolutions and doesn't work very well for
# low resolution images.
image = Image.open("nate.jpg")
scaled_image = image.resize((418, 512), resample=Image.LANCZOS)
image_tensor = TF.to_tensor(scaled_image)

# send the input through the network. unsqueeze is used to
# add a batch dimension, because torch always expects a batch
# but in this case it's just one image
# I then use Kornia to resize the mask back to 218x178 then
# squeeze to remove the batch channel again (kornia also
# always expects a batch dimension)
with torch.no_grad():
    # the original snippet is truncated here; the two lines below are a
    # plausible reconstruction that follows the comment above
    prediction = network(image_tensor.unsqueeze(0).to(device))
    mask = kornia.geometry.transform.resize(prediction, (218, 178)).squeeze(0)
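# Not part of the original snippet: once mask holds the resized logits, a
# common follow-up is to collapse the channel dimension with an argmax and
# save the result for inspection (the output file name is illustrative).
label_map = mask.argmax(dim=0).to(torch.uint8)  # [218, 178], values 0..17
Image.fromarray(label_map.cpu().numpy()).save("nate_labels.png")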
def main():
    """Create the model and start the training."""
    device = torch.device("cuda" if not args.cpu else "cpu")

    w, h = map(int, args.input_size.split(','))
    input_size = (w, h)

    w, h = map(int, args.input_size_target.split(','))
    input_size_target = (w, h)

    cudnn.enabled = True

    # Create network
    if args.model == 'DeepLab':
        model = DeepLab(backbone='resnet', output_stride=8)
    else:
        raise NotImplementedError

    model.train()
    model.to(device)

    cudnn.benchmark = True

    # The discriminator (model_D1) and its optimizer are disabled in this
    # variant of the script.

    if not os.path.exists(args.snapshot_dir):
        os.makedirs(args.snapshot_dir)

    train_loader = data_loader(args)

    # implement model.optim_parameters(args) to handle different models' lr setting
    optimizer = optim.SGD(model.optim_parameters(args),
                          lr=args.learning_rate,
                          momentum=args.momentum,
                          weight_decay=args.weight_decay)
    optimizer.zero_grad()

    if args.gan == 'Vanilla':
        bce_loss = torch.nn.BCEWithLogitsLoss()
    elif args.gan == 'LS':
        bce_loss = torch.nn.MSELoss()
    else:
        raise NotImplementedError("unsupported GAN loss: '{}'".format(args.gan))

    # Alternatives tried here and left disabled: FocalLoss2d(gamma=2.0,
    # weight=0.75) and FocalLoss(class_num=2), where the weight balances
    # positive and negative samples.
    seg_loss = torch.nn.CrossEntropyLoss()
    affinity_loss = AffinityFieldLoss(kl_margin=3.)
    R_loss = torch.nn.MSELoss()

    # set up tensorboard
    if args.tensorboard:
        if not os.path.exists(args.log_dir):
            os.makedirs(args.log_dir)
        writer = SummaryWriter(args.log_dir)

    count = args.start_count  # iteration counter
    for dat in train_loader:
        if count > args.num_steps:
            break

        loss_seg_value1_anchor = 0
        loss_adv_target_value1 = 0
        loss_affinity_value1_anchor = 0
        loss_D_value1 = 0
        loss_R_values = 0
        loss_A_values = 0

        optimizer.zero_grad()
        adjust_learning_rate(optimizer, count)

        for sub_i in range(args.iter_size):
            # train G (the discriminator branch is disabled in this variant)

            # With group=0, the training split holds 15 classes [0, 1, ..., 14]
            # and the validation split holds the remaining 5. Two training
            # classes are sampled: one class provides two images, the anchor
            # and a positive (same class); the other class provides one image,
            # the negative (different class). The anchor is the query image.
            anchor_img, anchor_mask, pos_img, pos_mask, neg_img, neg_mask = dat
            anchor_img, anchor_mask, pos_img, pos_mask = \
                anchor_img.cuda(), anchor_mask.cuda(), pos_img.cuda(), pos_mask.cuda()
            # shapes: [1, 3, 386, 500], [1, 386, 500], [1, 3, 374, 500], [1, 374, 500]

            anchor_mask = torch.unsqueeze(anchor_mask, dim=1)  # [1, 1, 386, 500]
            pos_mask = torch.unsqueeze(pos_mask, dim=1)        # [1, 1, 374, 500]

            samples = torch.cat([pos_img, anchor_img], 0)

            # leftover debugging visualization for one specific iteration
            if count == 5134:
                import matplotlib.pyplot as plt
                plt.imshow(pos_img[0][0].cpu().detach().numpy())
                plt.show()
                plt.imshow(pos_mask[0][0].cpu().detach().numpy())
                plt.show()

            pred = model(samples, pos_mask)  # [2, 2, 53, 53]
            _, _, w1, h1 = pred.size()
            _, _, mask_w, mask_h = anchor_mask.size()

            # segmentation loss and affinity loss
            pred = F.interpolate(pred, [mask_w, mask_h], mode='bilinear',
                                 align_corners=False)
            # for BCE loss this would be seg_loss(pred.squeeze(), anchor_mask.squeeze())
            loss_seg1_anchor = seg_loss(
                pred, anchor_mask.squeeze().unsqueeze(0).long())  # softmax CE
            loss_affinity = affinity_loss(
                pred, anchor_mask.squeeze().unsqueeze(0).long())

            # The relation-matrix loss (loss_R1/loss_R2 against
            # R_gt = G_q.reshape(w1*h1, -1) * G_s.reshape(-1, w1*h1)) and the
            # attention-matrix loss (loss_A1/loss_A2) from earlier experiments
            # are disabled; the full objective used to be
            #   loss_seg1_anchor + lambda_adv_target1 * loss_adv_target1
            #   + 0.3 * (loss_R1 + loss_R2) + 0.2 * (loss_A1 + loss_A2)
            loss = loss_seg1_anchor + loss_affinity
            # proper normalization
            loss = loss / args.iter_size
            loss.backward()
            loss_seg_value1_anchor += loss_seg1_anchor.item() / args.iter_size
            loss_affinity_value1_anchor += loss_affinity.item() / args.iter_size

        optimizer.step()
        count = count + 1

        if args.tensorboard:
            scalar_info = {
                'loss_seg1_anchor': loss_seg_value1_anchor,
                'loss_affinity_anchor': loss_affinity_value1_anchor,
                'loss_adv_target1': loss_adv_target_value1,
                'loss_D1': loss_D_value1,
                'loss_R': loss_R_values,
                'loss_A': loss_A_values,
            }
            if count % 10 == 0:
                for key, val in scalar_info.items():
                    writer.add_scalar(key, val, count)

        print('exp = {}'.format(args.snapshot_dir))
        print(
            'iter = {0:8d}/{1:8d}, loss_seg1 = {2:.3f}, loss_affinity = {3:.3f}, '
            'loss_D1 = {4:.3f}, loss_R = {5:.3f}, loss_A = {6:.3f}'.format(
                count, args.num_steps, loss_seg_value1_anchor,
                loss_affinity_value1_anchor, loss_D_value1, loss_R_values,
                loss_A_values))

        if count >= args.num_steps_stop - 1:
            print('save model ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir,
                         'voc2012_1_' + str(args.num_steps_stop) + '.pth'))
            break

        if count % args.save_pred_every == 0 and count != 0:
            print('taking snapshot ...')
            torch.save(
                model.state_dict(),
                osp.join(args.snapshot_dir, 'voc2012_1_' + str(count) + '.pth'))

    if args.tensorboard:
        writer.close()
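# Both training scripts call adjust_learning_rate (and the first one also
# adjust_learning_rate_D) without defining them here. In this family of
# adversarial-adaptation training scripts they are typically a polynomial
# decay; a minimal sketch, assuming args.power exists and that a second
# optimizer parameter group (if any) runs at 10x the base learning rate.
def lr_poly(base_lr, i_iter, max_iter, power):
    # standard poly schedule: lr * (1 - t/T)^power
    return base_lr * ((1 - float(i_iter) / max_iter) ** power)

def adjust_learning_rate(optimizer, i_iter):
    lr = lr_poly(args.learning_rate, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr
    if len(optimizer.param_groups) > 1:
        optimizer.param_groups[1]['lr'] = lr * 10

def adjust_learning_rate_D(optimizer, i_iter):
    lr = lr_poly(args.learning_rate_D, i_iter, args.num_steps, args.power)
    optimizer.param_groups[0]['lr'] = lr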
def main():
    parser = argparse.ArgumentParser(
        description="PyTorch DeeplabV3Plus Training")
    parser.add_argument('--backbone', type=str, default='resnet',
                        choices=['resnet', 'xception', 'drn', 'mobilenet'],
                        help='backbone name (default: resnet)')
    parser.add_argument('--out_stride', type=int, default=16,
                        help='network output stride (default: 16)')
    parser.add_argument('--rip_mode', type=str, default='patches-level2')
    parser.add_argument('--use_sbd', action='store_true', default=True,
                        help='whether to use SBD dataset (default: True)')
    parser.add_argument('--workers', type=int, default=8, metavar='N',
                        help='dataloader threads')
    parser.add_argument('--base_size', type=int, default=800,
                        help='base image size')
    parser.add_argument('--crop_size', type=int, default=800,
                        help='crop image size')
    parser.add_argument('--sync_bn', type=bool, default=None,
                        help='whether to use sync bn (default: auto)')
    parser.add_argument('--freeze_bn', type=bool, default=False,
                        help='whether to freeze bn parameters (default: False)')
    # cuda, seed and logging
    parser.add_argument('--gpus', type=int, default=1,
                        help='how many gpus to use (default: 1)')
    parser.add_argument('--seed', type=int, default=123, metavar='S',
                        help='random seed (default: 123)')
    # checkpoint
    parser.add_argument('--resume', type=str, default=None,
                        help='put the path to resuming file if needed')
    parser.add_argument('--checkname', type=str, default=None,
                        help='set the checkpoint name')
    parser.add_argument('--exp_root', type=str, default='')
    args = parser.parse_args()

    args.device, args.cuda = get_available_device(args.gpus)
    nclass = 3
    model = DeepLab(num_classes=nclass,
                    backbone=args.backbone,
                    output_stride=args.out_stride,
                    sync_bn=args.sync_bn,
                    freeze_bn=args.freeze_bn)

    args.checkname = '/data2/data2/zewei/exp/RipData/DeepLabV3/patches/level2/CV5-1/model_best.pth.tar'
    ckpt = torch.load(args.checkname)
    model.load_state_dict(ckpt['state_dict'])
    model.eval()
    model = model.to(args.device)

    img_files = ['doc/tests/img_cv.png']
    out_file = 'doc/tests/img_seg.png'
    transforms = Compose([
        ToTensor(),
        Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225))
    ])
    color_map = get_rip_labels()

    img_cv = cv2.imread(img_files[0])
    pred = process_single_large_image(model, img_cv)
    mask = gen_mask(pred, nclass, color_map)
    out_img = composite_image(img_cv, mask, alpha=0.2)
    save_image(mask, out_file.split('.')[0] + '_mask.png')
    save_image(out_img, out_file.split('.')[0] + '_com.png')
    print(f'saved image {out_file}')

    with torch.no_grad():
        for img_file in img_files:
            name, ext = img_file.split('.')
            img_cv = cv2.imread(img_file)
            patches = decompose_image(img_cv, None, (800, 800), (300, 700))
            print(f'Decompose input image into {len(patches)} patches.')
            for i, patch in patches.items():
                img = transforms(patch.image)
                img = torch.stack([img], dim=0).cuda()
                output = model(img)
                output = output.data.cpu().numpy()
                pred = np.argmax(output, axis=1)
                # expanded_pred = torch.zeros()  # incomplete placeholder in
                # the original; the target size was never specified
                mask = gen_mask(pred[0], nclass, color_map)
                out_img = composite_image(patch.image, mask, alpha=0.2)
                save_image(mask, name + f'_patch{i:02d}_seg.' + ext)
                save_image(out_img, name + f'_patch{i:02d}_seg_img.' + ext)
                print(f'saved image {name}_patch{i:02d}_seg.{ext}')
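# The patch loop above predicts each 800x800 crop independently but never
# reassembles them (the expanded_pred placeholder was left unfinished). A
# hedged sketch of one way to stitch per-patch label maps back into a
# full-resolution mask, assuming each patch records its top-left offset as
# hypothetical .x/.y attributes; the real decompose_image output may differ.
import numpy as np

def stitch_patches(patches, preds, full_h, full_w, patch_size=800):
    """Overwrite-stitch per-patch class-index maps into one array.
    Later patches win in overlap regions; averaging logits before the
    argmax would blend overlaps more smoothly."""
    full = np.zeros((full_h, full_w), dtype=np.int64)
    for (idx, patch), pred in zip(patches.items(), preds):
        y, x = patch.y, patch.x  # hypothetical offsets
        full[y:y + patch_size, x:x + patch_size] = pred
    return full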
                        required=True)
    parser.add_argument('--output', '-o',
                        metavar='output_path',
                        help='Output image',
                        required=True)
    args = parser.parse_args()

    dataset = "fashion_clothes"
    path = "./bestmodels/deep_clothes/checkpoint.pth.tar"
    nclass = 7

    # Initialize the DeeplabV3+ model
    model = DeepLab(num_classes=nclass, output_stride=8)

    # run model on CPU
    model.cpu()
    torch.set_num_threads(8)

    # error checking
    if not os.path.isfile(path):
        raise RuntimeError("no model found at '{}'".format(path))
    if not os.path.isfile(args.input):
        raise RuntimeError("no image found at '{}'".format(args.input))
    if os.path.exists(args.output):
        raise RuntimeError("existing file or dir found at '{}'".format(args.output))
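    # The snippet is truncated after the error checks. A hedged sketch of the
    # steps that would typically follow, reusing the ImageNet normalization
    # seen elsewhere in this document; the original script's exact
    # preprocessing and output format are not shown.
    from PIL import Image
    import torchvision.transforms.functional as TF

    checkpoint = torch.load(path, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    model.eval()

    image = Image.open(args.input).convert('RGB')
    tensor = TF.to_tensor(image)
    tensor = TF.normalize(tensor, mean=[0.485, 0.456, 0.406],
                          std=[0.229, 0.224, 0.225]).unsqueeze(0)

    with torch.no_grad():
        output = model(tensor)

    # per-pixel class indices (0..6) saved as an 8-bit mask
    pred = output.argmax(dim=1).squeeze(0).to(torch.uint8).cpu().numpy()
    Image.fromarray(pred).save(args.output)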