def init_model(self):
    """Instantiate the segmentation network selected by ``self.model_name``
    and attempt to restore previously saved weights for ``self.train_id``.

    Side effects:
        - Sets ``self.model`` (and ``self.model_out_layer_names`` for PSPNet).
        - Loads checkpoint weights from the hard-coded checkpoint root when a
          checkpoint exists; otherwise prints a notice and keeps fresh weights.
    """
    # NOTE(review): checkpoint root is hard-coded — consider making it configurable.
    checkpoint_dir = '/home/essys/projects/segmentation/checkpoints/' + self.train_id + "/"

    if self.model_name == "PSPNet":
        self.model, self.model_out_layer_names = PSPNet(
            backbone_name=self.backbone,
            input_shape=self.input_shape,
            classes=self.n_classes,
            encoder_weights=self.pretrained_encoder_weights,
            encoder_freeze=self.transfer_learning,
            training=self.training)
        try:
            # The model weights (that are considered the best) are loaded into the model.
            self.model.load_weights(checkpoint_dir)
        except Exception:  # FIX: was a bare except — don't swallow KeyboardInterrupt/SystemExit
            print('could not find saved model')

    if self.model_name == "Bisenet_V2":
        self.model = BisenetV2Model(
            train_op=optimizers.Adam(self.lr),
            input_shape=self.input_shape,
            classes=self.n_classes,
            batch_size=self.batch_size,
            class_weights=self.class_weights
        )  # BisenetV2(input_shape=self.input_shape, classes=self.n_classes)
        try:
            latest = tf.train.latest_checkpoint(checkpoint_dir)
            # ``latest`` is None when no checkpoint exists; the concatenation
            # below then raises TypeError and we fall through to the message.
            print(latest + " is found model\n\n")
            # The model weights (that are considered the best) are loaded into the model.
            self.model.model.load_weights(latest)
        except Exception:  # FIX: was a bare except
            print('could not find saved model')
def test_one_image(args, dt_config, dataset_class):
    """Segment a single image with a trained PSPNet and save a colour overlay.

    Loads the snapshot given by ``args.snapshot``, runs the network on
    ``args.image_path`` and writes the alpha-blended visualization to
    ``result.jpg`` in the current working directory.
    """
    net_input_size = (475, 475)
    dataset_instance = dataset_class(data_path=dt_config.DATA_PATH)

    # Restore the trained network from the snapshot checkpoint.
    model = PSPNet(num_classes=dataset_instance.num_classes)
    model.load_state_dict(torch.load(args.snapshot)["state_dict"])
    model.eval()

    # Prepare the image: resize to the network input size, scale to [0, 1],
    # and reorder HWC -> NCHW with a leading batch axis.
    img = cv2.imread(args.image_path)
    resized = cv2.resize(img, net_input_size)
    overlay = np.copy(resized)
    tensor_img = torch.tensor(
        (resized / 255.0).transpose(2, 0, 1)[np.newaxis, :]).float()

    if torch.cuda.is_available():
        model = model.cuda()
        tensor_img = tensor_img.cuda()

    # Forward pass; per-pixel argmax gives the predicted class mask.
    output = model(tensor_img)[0]
    mask = output.data.max(1)[1].cpu().numpy().reshape(475, 475)
    color_mask = np.array(dataset_instance.colors)[mask]

    # Alpha-blend the colour mask over the resized input, then restore the
    # original resolution before writing the result.
    alpha = args.alpha
    blended = (((1 - alpha) * overlay) + (alpha * color_mask)).astype("uint8")
    blended = cv2.resize(blended, (img.shape[1], img.shape[0]))
    cv2.imwrite("result.jpg", blended)
def get_model(criterion=None, auxiliary_loss=False, auxloss_weight=0):
    """Build a dilated-ResNet50 PSPNet for 19-class segmentation.

    Args:
        criterion: loss module handed to the network (or None).
        auxiliary_loss: whether the network adds its auxiliary branch loss.
        auxloss_weight: weight applied to the auxiliary loss term.
    """
    model_kwargs = {
        'encoder_name': 'dilated_resnet50',
        'encoder_weights': 'imagenet',
        'classes': 19,
        'auxiliary_loss': auxiliary_loss,
        'auxloss_weight': auxloss_weight,
        'criterion': criterion,
    }
    return PSPNet(**model_kwargs)
def __init__(self, n_classes, psp_size=2048, psp_bins=(1, 2, 3, 6), dropout=0.1, backbone='resnet50', **kwargs):
    """Wrap a PSPNet built on a pretrained backbone.

    Args:
        n_classes: number of output segmentation classes.
        psp_size: channel count fed into the pyramid pooling module.
        psp_bins: pyramid pooling bin sizes.
        dropout: dropout probability inside the PSPNet head.
        backbone: backbone architecture name (resolved via ``Backbone``).
        **kwargs: extra arguments recorded by ``save_hyperparameters``.
    """
    super().__init__()
    self.save_hyperparameters()
    # The backbone is always created with pretrained weights here.
    encoder = Backbone(backbone, pretrained=True)
    self.pspnet = PSPNet(n_classes=n_classes,
                         psp_size=psp_size,
                         psp_bins=psp_bins,
                         dropout=dropout,
                         backbone=encoder)
    # Counter used for checkpoint bookkeeping elsewhere in the class.
    self.ckpts_index = 0
def main():
    """Run PSPNet inference over a folder of images and save colourized masks."""
    args = parse_arguments()

    # Normalization statistics of the dataset used for training the model.
    MEAN = [0.45734706, 0.43338275, 0.40058118]
    STD = [0.23965294, 0.23532275, 0.2398498]
    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(MEAN, STD)
    num_classes = 2
    palette = [0, 0, 0, 128, 0, 128]

    # Model construction and checkpoint restore.
    model = PSPNet(num_classes=num_classes, backbone='resnet18')
    availble_gpus = list(range(torch.cuda.device_count()))
    device = torch.device('cuda:0' if len(availble_gpus) > 0 else 'cpu')

    checkpoint = torch.load(args.model)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint.keys():
        checkpoint = checkpoint['state_dict']
    # Checkpoints saved from DataParallel carry a 'module.' prefix;
    # wrap the model so the state-dict keys line up.
    if 'module' in list(checkpoint.keys())[0] and not isinstance(
            model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()

    if not os.path.exists('outputs'):
        os.makedirs('outputs')

    image_files = sorted(glob(os.path.join(args.images, f'*.{args.extension}')))
    with torch.no_grad():
        progress = tqdm(image_files, ncols=100)
        for img_file in progress:
            image = Image.open(img_file).convert('RGB')
            image = image.resize((480, 320))
            batch = normalize(to_tensor(image)).unsqueeze(0)
            print(batch.size())
            t1 = time.time()
            prediction = model(batch.to(device))
            prediction = prediction.squeeze(0).cpu().numpy()
            print(time.time() - t1)
            prediction = F.softmax(torch.from_numpy(prediction),
                                   dim=0).argmax(0).cpu().numpy()
            save_images(image, prediction, args.output, img_file, palette)
def main():
    """Run PSPNet inference on LSUN validation scenes and save input/output
    image pairs into ``test_results_path``.

    Relies on module-level ``num_classes``, ``ckpt_path`` and
    ``test_results_path`` being defined elsewhere in the file.
    """
    batch_size = 8
    net = PSPNet(pretrained=False, num_classes=num_classes, input_size=(512, 1024)).cuda()
    snapshot = 'epoch_48_validation_loss_5.1326_mean_iu_0.3172_lr_0.00001000.pth'
    net.load_state_dict(torch.load(os.path.join(ckpt_path, snapshot)))
    net.eval()
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        expanded_transform.FreeScale((512, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    # Inverse of the input normalization, used to save viewable inputs.
    restore = transforms.Compose([
        expanded_transform.DeNormalize(*mean_std),
        transforms.ToPILImage()
    ])
    lsun_path = '/home/b3-542/LSUN'
    dataset = LSUN(lsun_path, ['tower_val', 'church_outdoor_val', 'bridge_val'],
                   transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=16, shuffle=True)
    if not os.path.exists(test_results_path):
        os.mkdir(test_results_path)
    for vi, data in enumerate(dataloader, 0):
        inputs, labels = data
        # NOTE(review): ``volatile`` is the legacy (pre-0.4) PyTorch inference
        # flag; kept for behavioural parity with the original code.
        inputs = Variable(inputs, volatile=True).cuda()
        outputs = net(inputs)
        prediction = outputs.cpu().data.max(1)[1].squeeze_(1).numpy()
        for idx, tensor in enumerate(zip(inputs.cpu().data, prediction)):
            pil_input = restore(tensor[0])
            pil_output = colorize_mask(tensor[1])
            pil_input.save(os.path.join(test_results_path, '%d_img.png' % (vi * batch_size + idx)))
            pil_output.save(os.path.join(test_results_path, '%d_out.png' % (vi * batch_size + idx)))
        # FIX: was a Python 2 ``print`` statement (SyntaxError under Python 3).
        print('save the #%d batch, %d images' % (vi + 1, idx + 1))
def test():
    """Evaluate the network selected by ``args.choose_net`` on the validation
    set: report parameter count / FLOPs, per-image inference time, pixel
    accuracy and mean IoU, and display prediction vs. label side by side.
    """
    # --- model selection (each branch builds the requested architecture) ---
    # NOTE(review): the first two checks are ``if``/``if`` rather than
    # ``if``/``elif``; behaviour is unchanged because the names are disjoint,
    # but the chain style is inconsistent.
    if args.choose_net == "Unet": model = my_unet.UNet(3, 1).to(device)
    if args.choose_net == "My_Unet": model = my_unet.My_Unet2(3, 1).to(device)
    elif args.choose_net == "Enet": model = enet.ENet(num_classes=13).to(device)
    elif args.choose_net == "Segnet": model = segnet.SegNet(3, 1).to(device)
    elif args.choose_net == "CascadNet": model = my_cascadenet.CascadeNet(3, 1).to(device)
    elif args.choose_net == "my_drsnet_A": model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_B": model = my_drsnet.MultiscaleSENetB(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_C": model = my_drsnet.MultiscaleSENetC(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_A_direct_skip": model = my_drsnet.MultiscaleSENetA_direct_skip(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "SEResNet": model = my_drsnet.SEResNet18(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnext_unet": model = resnext_unet.resnext50(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnet50_unet": model = resnet50_unet.UNetWithResnet50Encoder(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "unet_res34": model = unet_res34.Resnet_Unet(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "dfanet":
        ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
        model = dfanet.DFANet(ch_cfg, 3, 1).to(device)
    elif args.choose_net == "cgnet": model = cgnet.Context_Guided_Network(1).to(device)
    elif args.choose_net == "lednet": model = lednet.Net(num_classes=1).to(device)
    elif args.choose_net == "bisenet": model = bisenet.BiSeNet(1, 'resnet18').to(device)
    elif args.choose_net == "espnet": model = espnet.ESPNet(classes=1).to(device)
    elif args.choose_net == "pspnet": model = pspnet.PSPNet(1).to(device)
    elif args.choose_net == "fddwnet": model = fddwnet.Net(classes=1).to(device)
    elif args.choose_net == "contextnet": model = contextnet.ContextNet(classes=1).to(device)
    elif args.choose_net == "linknet": model = linknet.LinkNet(classes=1).to(device)
    elif args.choose_net == "edanet": model = edanet.EDANet(classes=1).to(device)
    elif args.choose_net == "erfnet": model = erfnet.ERFNet(classes=1).to(device)

    # Report parameter count (M) and FLOPs (G) via thop.profile.
    dsize = (1, 3, 128, 192)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs, ), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000**2), total_ops / (1000**3)))

    model.load_state_dict(torch.load(args.weight))
    liver_dataset = LiverDataset("data/val_camvid", transform=x_transform, target_transform=y_transform)
    dataloaders = DataLoader(liver_dataset)  # batch_size defaults to 1
    model.eval()
    metric = SegmentationMetric(13)
    # import matplotlib.pyplot as plt
    # plt.ion()
    multiclass = 1
    mean_acc, mean_miou = [], []
    alltime = 0.0
    with torch.no_grad():
        for x, y_label in dataloaders:
            x = x.to(device)
            start = time.time()
            y = model(x)
            usingtime = time.time() - start
            alltime = alltime + usingtime
            if multiclass == 1:
                # Post-processing of the prediction output, see:
                # https://www.cnblogs.com/ljwgis/p/12313047.html
                y = F.sigmoid(y)
                y = y.cpu()
                # y = torch.squeeze(y).numpy()
                y = torch.argmax(y.squeeze(0), dim=0).data.numpy()
                print(y.max(), y.min())
                # y_label = y_label[0]
                y_label = torch.squeeze(y_label).numpy()
            else:
                y = y.cpu()
                y = torch.squeeze(y).numpy()
                y_label = torch.squeeze(y_label).numpy()
            # img_y = y*127.5
            # --- per-network binarization threshold (all 0.5 except CascadNet) ---
            if args.choose_net == "Unet": y = (y > 0.5)
            elif args.choose_net == "My_Unet": y = (y > 0.5)
            elif args.choose_net == "Enet": y = (y > 0.5)
            elif args.choose_net == "Segnet": y = (y > 0.5)
            elif args.choose_net == "Scnn": y = (y > 0.5)
            elif args.choose_net == "CascadNet": y = (y > 0.8)
            elif args.choose_net == "my_drsnet_A": y = (y > 0.5)
            elif args.choose_net == "my_drsnet_B": y = (y > 0.5)
            elif args.choose_net == "my_drsnet_C": y = (y > 0.5)
            elif args.choose_net == "my_drsnet_A_direct_skip": y = (y > 0.5)
            elif args.choose_net == "SEResNet": y = (y > 0.5)
            elif args.choose_net == "resnext_unet": y = (y > 0.5)
            elif args.choose_net == "resnet50_unet": y = (y > 0.5)
            elif args.choose_net == "unet_res34": y = (y > 0.5)
            elif args.choose_net == "dfanet": y = (y > 0.5)
            elif args.choose_net == "cgnet": y = (y > 0.5)
            elif args.choose_net == "lednet": y = (y > 0.5)
            elif args.choose_net == "bisenet": y = (y > 0.5)
            elif args.choose_net == "pspnet": y = (y > 0.5)
            elif args.choose_net == "fddwnet": y = (y > 0.5)
            elif args.choose_net == "contextnet": y = (y > 0.5)
            elif args.choose_net == "linknet": y = (y > 0.5)
            elif args.choose_net == "edanet": y = (y > 0.5)
            elif args.choose_net == "erfnet": y = (y > 0.5)
            img_y = y.astype(int).squeeze()
            print(y_label.shape, img_y.shape)
            # Stack prediction above label for visual comparison.
            image = np.concatenate((img_y, y_label))
            y_label = y_label.astype(int)
            metric.addBatch(img_y, y_label)
            acc = metric.classPixelAccuracy()
            mIoU = metric.meanIntersectionOverUnion()
            # confusionMatrix=metric.genConfusionMatrix(img_y, y_label)
            mean_acc.append(acc[1])
            mean_miou.append(mIoU)
            # print(acc, mIoU,confusionMatrix)
            print(acc, mIoU)
            plt.imshow(image * 5)
            plt.pause(0.1)
    plt.show()
    # NOTE: when benchmarking speed, disable the acc/mIoU computation above.
    # The 638.0 divisor is presumably the validation image count — TODO confirm.
    print("Took ", alltime, "seconds")
    print("Took", alltime / 638.0, "s/perimage")
    print("FPS", 1 / (alltime / 638.0))
    print("average acc:%0.6f average miou:%0.6f" % (np.mean(mean_acc), np.mean(mean_miou)))
def train():
    """Build the network selected by ``args.choose_net``, optionally load
    pretrained weights, then train it on the CamVid liver-style dataset with
    SGD + cosine-annealing learning rate.
    """
    # --- model selection ---
    # NOTE(review): first two checks are ``if``/``if`` instead of ``elif``;
    # harmless here because the names are disjoint.
    if args.choose_net == "Unet": model = my_unet.UNet(3, 1).to(device)
    if args.choose_net == "My_Unet": model = my_unet.My_Unet2(3, 1).to(device)
    elif args.choose_net == "Enet": model = enet.ENet(num_classes=13).to(device)
    elif args.choose_net == "Segnet": model = segnet.SegNet(3, 13).to(device)
    elif args.choose_net == "CascadNet": model = my_cascadenet.CascadeNet(3, 1).to(device)
    elif args.choose_net == "my_drsnet_A": model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_B": model = my_drsnet.MultiscaleSENetB(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_C": model = my_drsnet.MultiscaleSENetC(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_A_direct_skip": model = my_drsnet.MultiscaleSENetA_direct_skip(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "SEResNet": model = my_drsnet.SEResNet18(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnext_unet": model = resnext_unet.resnext50(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnet50_unet": model = resnet50_unet.UNetWithResnet50Encoder(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "unet_nest": model = unet_nest.UNet_Nested(3, 2).to(device)
    elif args.choose_net == "unet_res34": model = unet_res34.Resnet_Unet(3, 1).to(device)
    elif args.choose_net == "trangle_net": model = mytrangle_net.trangle_net(3, 1).to(device)
    elif args.choose_net == "dfanet":
        ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
        model = dfanet.DFANet(ch_cfg, 3, 1).to(device)
    elif args.choose_net == "lednet": model = lednet.Net(num_classes=1).to(device)
    elif args.choose_net == "cgnet": model = cgnet.Context_Guided_Network(classes=1).to(device)
    elif args.choose_net == "pspnet": model = pspnet.PSPNet(1).to(device)
    elif args.choose_net == "bisenet": model = bisenet.BiSeNet(1, 'resnet18').to(device)
    elif args.choose_net == "espnet": model = espnet.ESPNet(classes=1).to(device)
    elif args.choose_net == "fddwnet": model = fddwnet.Net(classes=1).to(device)
    elif args.choose_net == "contextnet": model = contextnet.ContextNet(classes=1).to(device)
    elif args.choose_net == "linknet": model = linknet.LinkNet(classes=1).to(device)
    elif args.choose_net == "edanet": model = edanet.EDANet(classes=1).to(device)
    elif args.choose_net == "erfnet": model = erfnet.ERFNet(classes=1).to(device)

    from collections import OrderedDict
    # Pretrained-weight loading mode:
    # 0: no pretrained model
    # 1: load pretrained weights into the original network
    # 2: load pretrained weights into a new network (partial, key-matched)
    loadpretrained = 0
    if loadpretrained == 1:
        model.load_state_dict(torch.load(args.weight))
    elif loadpretrained == 2:
        # NOTE(review): this overrides the selection above with MultiscaleSENetA.
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
        model_dict = model.state_dict()
        pretrained_dict = torch.load(args.weight)
        # Keep only the checkpoint entries whose keys exist in the new model.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        # model.load_state_dict(torch.load(args.weight))
        # pretrained_dict = {k: v for k, v in model.items() if k in model}  # filter out unnecessary keys
        # model.update(pretrained_dict)
        # model.load_state_dict(model)

    # Compute the model's parameter count and FLOPs.
    dsize = (1, 3, 128, 192)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs, ), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000**2), total_ops / (1000**3)))

    batch_size = args.batch_size
    # Load the dataset.
    liver_dataset = LiverDataset("data/train_camvid/", transform=x_transform, target_transform=y_transform)
    len_img = liver_dataset.__len__()
    dataloader = DataLoader(liver_dataset, batch_size=batch_size, shuffle=True, num_workers=24)
    # DataLoader: wraps a custom or built-in PyTorch dataset and batches it.
    # batch_size: samples per minibatch.
    # shuffle: reshuffle the data every epoch (used during training).
    # num_workers: number of worker processes loading data in parallel.

    # Gradient descent.
    # optimizer = optim.Adam(model.parameters())  # model.parameters(): returns an iterator over module parameters
    # Observe that all parameters are being optimized.
    optimizer = optim.SGD(model.parameters(), lr=args.lr, momentum=0.9, weight_decay=0.0001)
    # Cosine-annealing restart every n epochs.
    cosine_lr_scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=10 * int(len_img / batch_size), eta_min=0.00001)

    multiclass = 1
    if multiclass == 1:
        # Loss function with per-class weights (CamVid, 13 classes).
        class_weights = np.array([
            0., 6.3005947, 4.31063664, 34.09234699, 50.49834979, 3.88280945,
            50.49834979, 8.91626081, 47.58477105, 29.41289083, 18.95706775,
            37.84558871, 39.3477858
        ])  #camvid
        # class_weights = weighing(dataloader, 13, c=1.02)
        class_weights = torch.from_numpy(class_weights).float().to(device)
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
        # criterion = LovaszLossSoftmax()
        # criterion = torch.nn.MSELoss()
        train_modelmulticlasses(model, criterion, optimizer, dataloader, cosine_lr_scheduler)
    else:
        # Loss function (binary case).
        # criterion = LovaszLossHinge()
        # weights=[0.2]
        # weights=torch.Tensor(weights).to(device)
        # # criterion = torch.nn.CrossEntropyLoss(weight=weights)
        criterion = torch.nn.BCELoss()
        # criterion =focal_loss.FocalLoss(1)
        train_model(model, criterion, optimizer, dataloader, cosine_lr_scheduler)
def main():
    """Train PSPNet on the retina dataset, resuming from a snapshot when
    ``args['snapshot']`` is set.

    Snapshot filenames encode training state as underscore-separated fields
    (epoch, iter, val_loss, acc, acc_cls, mean_iu, fwavacc), which are parsed
    back into ``args['best_record']`` on resume.
    """
    net = PSPNet(num_classes=num_classes)
    if len(args['snapshot']) == 0:
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
                               'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # Recover the best-record metrics from the snapshot filename fields.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {'epoch': int(split_snapshot[1]), 'iter': int(split_snapshot[3]),
                               'val_loss': float(split_snapshot[5]), 'acc': float(split_snapshot[7]),
                               'acc_cls': float(split_snapshot[9]), 'mean_iu': float(split_snapshot[11]),
                               'fwavacc': float(split_snapshot[13])}
    net.cuda().train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Joint transforms act on image and mask together; sliding crop windows
    # the (possibly large) inputs for training and validation.
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'], args['stride_rate'], ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(args['val_img_display_size']),
        standard_transforms.ToTensor()
    ])

    train_set = Retinaimages('training', joint_transform=train_joint_transform,
                             sliding_crop=sliding_crop, transform=train_input_transform,
                             target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=2, shuffle=True)
    val_set = Retinaimages('validate', transform=val_input_transform, sliding_crop=sliding_crop,
                           target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=2, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True).cuda()

    # Biases get double learning rate and no weight decay (common PSPNet recipe).
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'], nesterov=True)

    if len(args['snapshot']) > 0:
        # Restore optimizer state, then force the learning rates back to the
        # configured values (the checkpointed ones may be stale).
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the full training configuration next to the checkpoints.
    open(os.path.join(ckpt_path, exp_name, "_1" + '.txt'), 'w').write(str(args) + '\n\n')
    train(train_loader, net, criterion, optimizer, curr_epoch, args, val_loader, visualize, val_set)
# output_stride=16, # num_classes=2, # pretrained_backbone=None, # ) # # BiSeNet # model = BiSeNet( # backbone='resnet18', # num_classes=2, # pretrained_backbone=None, # ) # PSPNet model = PSPNet( backbone='resnet18', num_classes=2, pretrained_backbone=None, ) # # ICNet # model = ICNet( # backbone='resnet18', # num_classes=2, # pretrained_backbone=None, # ) #------------------------------------------------------------------------------ # Summary network #------------------------------------------------------------------------------ model.train() model.summary(input_shape=(3, args.input_sz, args.input_sz), device='cpu')
def train_process(args, dt_config, dataset_class, data_transform_class):
    """Wire up datasets, PSPNet, loss, per-module SGD learning rates and a
    polynomial LR schedule, then run training through ``Trainer``.

    Args:
        args: parsed CLI namespace (batch_size, num_epoch, save_period,
              batch_multiplier, snapshot).
        dt_config: project config object providing DATA_PATH and LOG_PATH.
        dataset_class: dataset constructor taking (data_path, phase, transform).
        data_transform_class: transform factory taking (num_classes, input_size).
    """
    # input_size = (params["img_h"], params["img_w"])
    input_size = (475, 475)
    num_classes = 20
    # transforms = [
    #     OneOf([IAAAdditiveGaussianNoise(), GaussNoise()], p=0.5),
    #     # OneOf(
    #     #     [
    #     #         MedianBlur(blur_limit=3),
    #     #         GaussianBlur(blur_limit=3),
    #     #         MotionBlur(blur_limit=3),
    #     #     ],
    #     #     p=0.1,
    #     # ),
    #     RandomGamma(gamma_limit=(80, 120), p=0.5),
    #     RandomBrightnessContrast(p=0.5),
    #     HueSaturationValue(
    #         hue_shift_limit=5, sat_shift_limit=20, val_shift_limit=10, p=0.5
    #     ),
    #     ChannelShuffle(p=0.5),
    #     HorizontalFlip(p=0.5),
    #     Cutout(num_holes=2, max_w_size=40, max_h_size=40, p=0.5),
    #     Rotate(limit=20, p=0.5, border_mode=0),
    # ]
    data_transform = data_transform_class(num_classes=num_classes, input_size=input_size)
    train_dataset = dataset_class(
        data_path=dt_config.DATA_PATH,
        phase="train",
        transform=data_transform,
    )
    val_dataset = dataset_class(
        data_path=dt_config.DATA_PATH,
        phase="val",
        transform=data_transform,
    )
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )
    val_data_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        drop_last=True,
    )
    data_loaders_dict = {"train": train_data_loader, "val": val_data_loader}

    tblogger = SummaryWriter(dt_config.LOG_PATH)
    model = PSPNet(num_classes=num_classes)
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    criterion = PSPLoss()
    # Backbone/feature modules train at 1e-3; the decoder and auxiliary head
    # get a 10x higher learning rate (1e-2).
    optimizer = torch.optim.SGD(
        [
            {
                "params": model.feature_conv.parameters(),
                "lr": 1e-3
            },
            {
                "params": model.feature_res_1.parameters(),
                "lr": 1e-3
            },
            {
                "params": model.feature_res_2.parameters(),
                "lr": 1e-3
            },
            {
                "params": model.feature_dilated_res_1.parameters(),
                "lr": 1e-3
            },
            {
                "params": model.feature_dilated_res_2.parameters(),
                "lr": 1e-3
            },
            {
                "params": model.pyramid_pooling.parameters(),
                "lr": 1e-3
            },
            {
                "params": model.decode_feature.parameters(),
                "lr": 1e-2
            },
            {
                "params": model.aux.parameters(),
                "lr": 1e-2
            },
        ],
        momentum=0.9,
        weight_decay=0.0001,
    )

    def _lambda_epoch(epoch):
        # Polynomial ("poly") learning-rate decay with power 0.9.
        import math
        max_epoch = args.num_epoch
        return math.pow((1 - epoch / max_epoch), 0.9)

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=_lambda_epoch)
    trainer = Trainer(
        model=model,
        criterion=criterion,
        metric_func=None,
        optimizer=optimizer,
        num_epochs=args.num_epoch,
        save_period=args.save_period,
        config=dt_config,
        data_loaders_dict=data_loaders_dict,
        scheduler=scheduler,
        device=device,
        dataset_name_base=train_dataset.__name__,
        batch_multiplier=args.batch_multiplier,
        logger=tblogger,
    )
    if args.snapshot and os.path.isfile(args.snapshot):
        trainer.resume_checkpoint(args.snapshot)
    # Anomaly detection slows training but pinpoints NaN/inf autograd sources.
    with torch.autograd.set_detect_anomaly(True):
        trainer.train()
    tblogger.close()
def main():
    """Train PSPNet on CityScapes, optionally resuming from a snapshot.

    Snapshot filenames encode epoch / best val loss / mean IoU as
    underscore-separated fields, parsed back into ``train_record`` on resume.
    Uses the common PSPNet recipe: biases get 2x learning rate and no weight
    decay, and the new (head) layers get a higher LR than pretrained ones.
    """
    net = PSPNet(num_classes=num_classes, input_size=train_args['input_size']).cuda()
    if len(train_args['snapshot']) == 0:
        curr_epoch = 0
    else:
        # FIX: was a Python 2 ``print`` statement (SyntaxError under Python 3).
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # Recover training progress from the snapshot filename fields.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1])
        train_record['best_val_loss'] = float(split_snapshot[3])
        train_record['corr_mean_iu'] = float(split_snapshot[6])
        train_record['corr_epoch'] = curr_epoch
    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Scale so that a crop of input_size covers 87.5% of the shorter side.
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    # Remap the ignored label onto the last class index.
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train', simul_transform=train_simul_transform,
                           transform=img_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=train_args['batch_size'],
                              num_workers=16, shuffle=True)
    val_set = CityScapes('val', simul_transform=val_simul_transform,
                         transform=img_transform, target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=val_args['batch_size'],
                            num_workers=16, shuffle=False)

    # Zero weight on the last class (the remapped ignore label).
    weight = torch.ones(num_classes)
    weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(weight).cuda()

    # don't use weight_decay for bias; new layers (ppm/final/aux_logits) get
    # 'new_lr', pretrained backbone layers get 'pretrained_lr'.
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and (
                'ppm' in name or 'final' in name or 'aux_logits' in name)
        ],
        'lr': 2 * train_args['new_lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and (
                'ppm' in name or 'final' in name or 'aux_logits' in name)
        ],
        'lr': train_args['new_lr'],
        'weight_decay': train_args['weight_decay']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and not ('ppm' in name or 'final' in name
                                            or 'aux_logits' in name)
        ],
        'lr': 2 * train_args['pretrained_lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and not ('ppm' in name or 'final' in name
                                            or 'aux_logits' in name)
        ],
        'lr': train_args['pretrained_lr'],
        'weight_decay': train_args['weight_decay']
    }], momentum=0.9, nesterov=True)

    if len(train_args['snapshot']) > 0:
        # Restore optimizer state, then force LRs back to the configured values.
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path, 'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['new_lr']
        optimizer.param_groups[1]['lr'] = train_args['new_lr']
        optimizer.param_groups[2]['lr'] = 2 * train_args['pretrained_lr']
        optimizer.param_groups[3]['lr'] = train_args['pretrained_lr']

    if not os.path.exists(ckpt_path):
        os.mkdir(ckpt_path)
    if not os.path.exists(os.path.join(ckpt_path, exp_name)):
        os.mkdir(os.path.join(ckpt_path, exp_name))

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch, restore_transform)
def main():
    """Smoke-test a PSPNet forward pass: print main and auxiliary output shapes."""
    model = PSPNet(num_classes=12)
    # FIX: renamed from ``input`` to avoid shadowing the builtin.
    dummy_input = torch.randn(1, 3, 475, 475)
    output, output_aux = model(dummy_input)
    print(output.shape)
    print(output_aux.shape)
# Top-level prediction setup: pick the model class by name, create the result
# directory, and load the checkpoint for the requested training step.
# (The script continues beyond this chunk.)
class_name = class_list[0]
print('Predict {0} with epoch {1}'.format(model_name, step_num))
result_dir = './result/{0}/{1}/'.format(model_name, class_name)
if not os.path.exists(result_dir):
    os.makedirs(result_dir)
# Model selection — all constructors take 2 output classes; UNet also takes
# the input channel count.
if model_name == 'UNet':
    model = UNet(2, in_channels=3)
if model_name == 'FCN8':
    model = FCN8(2)
if model_name == 'PSPNet':
    model = PSPNet(2)
if model_name == 'UperNet':
    model = UperNet(2)
if model_name == 'CC_UNet':
    model = CC_UNet(2)
if model_name == 'A_UNet':
    model = A_UNet(2)
device = torch.device('cuda:0')
# Checkpoint filename pattern: <model_name>_<step_num>.pth
state_dict = torch.load('./checkpoints/{0}_{1}.pth'.format(
    model_name, step_num))