Exemple #1
0
 def init_model(self):
     """Instantiate ``self.model`` according to ``self.model_name`` and try
     to restore the newest checkpoint for ``self.train_id``.

     Supported names: ``"PSPNet"`` and ``"Bisenet_V2"``.  A missing or
     unreadable checkpoint is reported but not fatal -- the model simply
     keeps its freshly initialized weights.
     """
     checkpoint_dir = ('/home/essys/projects/segmentation/checkpoints/' +
                       self.train_id + "/")
     if self.model_name == "PSPNet":
         self.model, self.model_out_layer_names = PSPNet(
             backbone_name=self.backbone,
             input_shape=self.input_shape,
             classes=self.n_classes,
             encoder_weights=self.pretrained_encoder_weights,
             encoder_freeze=self.transfer_learning,
             training=self.training)
         try:
             # The model weights (that are considered the best) are loaded.
             self.model.load_weights(checkpoint_dir)
         except Exception as exc:  # was a bare except: keep the error visible
             print('could not find saved model: {}'.format(exc))
     if self.model_name == "Bisenet_V2":
         self.model = BisenetV2Model(
             train_op=optimizers.Adam(self.lr),
             input_shape=self.input_shape,
             classes=self.n_classes,
             batch_size=self.batch_size,
             class_weights=self.class_weights)
         try:
             latest = tf.train.latest_checkpoint(checkpoint_dir)
             if latest is None:
                 # latest_checkpoint returns None when nothing is saved;
                 # the old code crashed here concatenating None to a str.
                 print('could not find saved model')
             else:
                 print(latest + " is found model\n\n")
                 self.model.model.load_weights(latest)
         except Exception as exc:  # was a bare except: keep the error visible
             print('could not find saved model: {}'.format(exc))
Exemple #2
0
def test_one_image(args, dt_config, dataset_class):
    """Segment a single image with a trained PSPNet and write the blended
    overlay to ``result.jpg``.

    Args:
        args: parsed CLI namespace; uses ``snapshot`` (checkpoint path),
            ``image_path`` and ``alpha`` (overlay blend factor in [0, 1]).
        dt_config: config object providing ``DATA_PATH``.
        dataset_class: dataset type used only for ``num_classes``/``colors``.
    """
    input_size = (475, 475)
    model_path = args.snapshot
    dataset_instance = dataset_class(data_path=dt_config.DATA_PATH)
    num_classes = dataset_instance.num_classes
    model = PSPNet(num_classes=num_classes)
    # map_location lets a GPU-trained snapshot load on a CPU-only machine;
    # the model is moved back to CUDA below when available.
    model.load_state_dict(
        torch.load(model_path, map_location="cpu")["state_dict"])
    model.eval()

    img = cv2.imread(args.image_path)
    processed_img = cv2.resize(img, input_size)
    overlay = np.copy(processed_img)
    processed_img = processed_img / 255.0
    # HWC -> CHW, add batch dimension.
    processed_img = torch.tensor(
        processed_img.transpose(2, 0, 1)[np.newaxis, :]).float()
    if torch.cuda.is_available():
        model = model.cuda()
        processed_img = processed_img.cuda()
    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        output = model(processed_img)[0]
    # Reshape to input_size instead of the previously hard-coded 475x475.
    mask = output.data.max(1)[1].cpu().numpy().reshape(input_size)
    color_mask = np.array(dataset_instance.colors)[mask]
    alpha = args.alpha
    overlay = (((1 - alpha) * overlay) + (alpha * color_mask)).astype("uint8")
    # Resize the overlay back to the original image resolution.
    overlay = cv2.resize(overlay, (img.shape[1], img.shape[0]))
    cv2.imwrite("result.jpg", overlay)
Exemple #3
0
def get_model(criterion=None, auxiliary_loss=False, auxloss_weight=0):
    """Build a 19-class PSPNet with a dilated ResNet-50 encoder pretrained
    on ImageNet, forwarding the loss configuration to the model."""
    model_kwargs = dict(
        encoder_name='dilated_resnet50',
        encoder_weights='imagenet',
        classes=19,
        auxiliary_loss=auxiliary_loss,
        auxloss_weight=auxloss_weight,
        criterion=criterion,
    )
    return PSPNet(**model_kwargs)
Exemple #4
0
 def __init__(self,
              n_classes,
              psp_size=2048,
              psp_bins=(1, 2, 3, 6),
              dropout=0.1,
              backbone='resnet50',
              **kwargs):
     """Wrap a PSPNet built on a pretrained backbone.

     Args:
         n_classes: number of segmentation classes.
         psp_size: channel count fed into the pyramid pooling module.
         psp_bins: pooling bin sizes of the pyramid module.
         dropout: dropout probability.
         backbone: backbone architecture name passed to ``Backbone``.
     """
     super().__init__()
     self.save_hyperparameters()
     # Build the encoder first, then hand it to PSPNet.
     encoder = Backbone(backbone, pretrained=True)
     self.pspnet = PSPNet(n_classes=n_classes,
                          psp_size=psp_size,
                          psp_bins=psp_bins,
                          dropout=dropout,
                          backbone=encoder)
     # Counter for numbering saved checkpoints.
     self.ckpts_index = 0
Exemple #5
0
def main():
    """Segment every image in ``args.images`` with a ResNet18 PSPNet and
    save the colorized predictions into ``args.output``."""
    args = parse_arguments()

    # Normalization statistics of the dataset the model was trained on.
    MEAN = [0.45734706, 0.43338275, 0.40058118]
    STD = [0.23965294, 0.23532275, 0.2398498]

    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(MEAN, STD)
    num_classes = 2
    # RGB triplets per class: background black, foreground purple.
    palette = [0, 0, 0, 128, 0, 128]

    # Model
    model = PSPNet(num_classes=num_classes, backbone='resnet18')
    available_gpus = list(range(torch.cuda.device_count()))
    device = torch.device('cuda:0' if len(available_gpus) > 0 else 'cpu')

    # map_location makes GPU-saved checkpoints loadable on CPU-only hosts.
    checkpoint = torch.load(args.model, map_location=device)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint.keys():
        checkpoint = checkpoint['state_dict']
    # Checkpoints saved from a DataParallel model carry a 'module.' prefix;
    # wrap the model so the state-dict keys line up.
    if 'module' in list(checkpoint.keys())[0] and not isinstance(
            model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()

    # Fix: create the directory predictions are actually written to
    # (the old code created a hard-coded 'outputs' folder instead).
    os.makedirs(args.output, exist_ok=True)

    image_files = sorted(glob(os.path.join(args.images,
                                           f'*.{args.extension}')))
    with torch.no_grad():
        tbar = tqdm(image_files, ncols=100)
        for img_file in tbar:
            image = Image.open(img_file).convert('RGB')
            image = image.resize((480, 320))
            # 'batch' avoids shadowing the builtin 'input'.
            batch = normalize(to_tensor(image)).unsqueeze(0)
            t1 = time.time()
            prediction = model(batch.to(device))
            prediction = prediction.squeeze(0).cpu().numpy()
            print(time.time() - t1)
            prediction = F.softmax(torch.from_numpy(prediction),
                                   dim=0).argmax(0).cpu().numpy()
            save_images(image, prediction, args.output, img_file, palette)
Exemple #6
0
def main():
    """Run a trained Cityscapes PSPNet over LSUN validation scenes and dump
    input/prediction image pairs into ``test_results_path``."""
    batch_size = 8

    net = PSPNet(pretrained=False, num_classes=num_classes,
                 input_size=(512, 1024)).cuda()
    snapshot = 'epoch_48_validation_loss_5.1326_mean_iu_0.3172_lr_0.00001000.pth'
    net.load_state_dict(torch.load(os.path.join(ckpt_path, snapshot)))
    net.eval()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        expanded_transform.FreeScale((512, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    # Inverse of the input pipeline, for saving viewable input images.
    restore = transforms.Compose([
        expanded_transform.DeNormalize(*mean_std),
        transforms.ToPILImage()
    ])

    lsun_path = '/home/b3-542/LSUN'

    dataset = LSUN(lsun_path, ['tower_val', 'church_outdoor_val', 'bridge_val'],
                   transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=16,
                            shuffle=True)

    if not os.path.exists(test_results_path):
        os.mkdir(test_results_path)

    for vi, data in enumerate(dataloader, 0):
        inputs, labels = data
        # torch.no_grad replaces the long-deprecated Variable(volatile=True).
        with torch.no_grad():
            inputs = inputs.cuda()
            outputs = net(inputs)

        # Per-pixel argmax over the class dimension.
        prediction = outputs.cpu().data.max(1)[1].squeeze_(1).numpy()

        for idx, tensor in enumerate(zip(inputs.cpu().data, prediction)):
            pil_input = restore(tensor[0])
            pil_output = colorize_mask(tensor[1])
            pil_input.save(os.path.join(
                test_results_path, '%d_img.png' % (vi * batch_size + idx)))
            pil_output.save(os.path.join(
                test_results_path, '%d_out.png' % (vi * batch_size + idx)))
            # Fixed: Python 2 print statement -> print() function.
            print('save the #%d batch, %d images' % (vi + 1, idx + 1))
Exemple #7
0
def test():
    """Evaluate the network selected by ``args.choose_net`` on the CamVid
    validation set.

    Reports parameter count (M) / FLOPs (G), per-image inference time and,
    on the multi-class path, per-class pixel accuracy and mean IoU, while
    showing prediction/label pairs with matplotlib.
    """
    # ---- model selection (fixed: first two branches were 'if'/'if',
    # breaking the intended elif chain) --------------------------------
    if args.choose_net == "Unet":
        model = my_unet.UNet(3, 1).to(device)
    elif args.choose_net == "My_Unet":
        model = my_unet.My_Unet2(3, 1).to(device)
    elif args.choose_net == "Enet":
        model = enet.ENet(num_classes=13).to(device)
    elif args.choose_net == "Segnet":
        model = segnet.SegNet(3, 1).to(device)
    elif args.choose_net == "CascadNet":
        model = my_cascadenet.CascadeNet(3, 1).to(device)
    elif args.choose_net == "my_drsnet_A":
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_B":
        model = my_drsnet.MultiscaleSENetB(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_C":
        model = my_drsnet.MultiscaleSENetC(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_A_direct_skip":
        model = my_drsnet.MultiscaleSENetA_direct_skip(in_ch=3,
                                                       out_ch=1).to(device)
    elif args.choose_net == "SEResNet":
        model = my_drsnet.SEResNet18(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnext_unet":
        model = resnext_unet.resnext50(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnet50_unet":
        model = resnet50_unet.UNetWithResnet50Encoder(in_ch=3,
                                                      out_ch=1).to(device)
    elif args.choose_net == "unet_res34":
        model = unet_res34.Resnet_Unet(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "dfanet":
        ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
        model = dfanet.DFANet(ch_cfg, 3, 1).to(device)
    elif args.choose_net == "cgnet":
        model = cgnet.Context_Guided_Network(1).to(device)
    elif args.choose_net == "lednet":
        model = lednet.Net(num_classes=1).to(device)
    elif args.choose_net == "bisenet":
        model = bisenet.BiSeNet(1, 'resnet18').to(device)
    elif args.choose_net == "espnet":
        model = espnet.ESPNet(classes=1).to(device)
    elif args.choose_net == "pspnet":
        model = pspnet.PSPNet(1).to(device)
    elif args.choose_net == "fddwnet":
        model = fddwnet.Net(classes=1).to(device)
    elif args.choose_net == "contextnet":
        model = contextnet.ContextNet(classes=1).to(device)
    elif args.choose_net == "linknet":
        model = linknet.LinkNet(classes=1).to(device)
    elif args.choose_net == "edanet":
        model = edanet.EDANet(classes=1).to(device)
    elif args.choose_net == "erfnet":
        model = erfnet.ERFNet(classes=1).to(device)

    # Parameter count (millions) and FLOPs (GMac) for a 128x192 RGB input.
    dsize = (1, 3, 128, 192)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs, ), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000**2), total_ops / (1000**3)))

    model.load_state_dict(torch.load(args.weight))
    liver_dataset = LiverDataset("data/val_camvid",
                                 transform=x_transform,
                                 target_transform=y_transform)
    dataloaders = DataLoader(liver_dataset)  # batch_size defaults to 1
    model.eval()

    metric = SegmentationMetric(13)
    multiclass = 1
    mean_acc, mean_miou = [], []

    # Binarization thresholds for the single-class path, replacing the old
    # 40-line if/elif chain.  All listed nets use 0.5 except CascadNet
    # (0.8); names absent from the table (e.g. "espnet") are intentionally
    # left un-thresholded, exactly as before.
    binarize_thresholds = {name: 0.5 for name in (
        "Unet", "My_Unet", "Enet", "Segnet", "Scnn",
        "my_drsnet_A", "my_drsnet_B", "my_drsnet_C",
        "my_drsnet_A_direct_skip", "SEResNet",
        "resnext_unet", "resnet50_unet", "unet_res34", "dfanet",
        "cgnet", "lednet", "bisenet", "pspnet", "fddwnet",
        "contextnet", "linknet", "edanet", "erfnet")}
    binarize_thresholds["CascadNet"] = 0.8

    alltime = 0.0
    with torch.no_grad():
        for x, y_label in dataloaders:
            x = x.to(device)
            start = time.time()
            y = model(x)
            usingtime = time.time() - start
            alltime = alltime + usingtime

            if multiclass == 1:
                # Multi-class: sigmoid then per-pixel argmax over classes.
                y = F.sigmoid(y)
                y = y.cpu()
                y = torch.argmax(y.squeeze(0), dim=0).data.numpy()
                print(y.max(), y.min())
                y_label = torch.squeeze(y_label).numpy()
            else:
                y = y.cpu()
                y = torch.squeeze(y).numpy()
                y_label = torch.squeeze(y_label).numpy()
                if args.choose_net in binarize_thresholds:
                    y = (y > binarize_thresholds[args.choose_net])

            img_y = y.astype(int).squeeze()
            print(y_label.shape, img_y.shape)
            # Stack prediction over label for side-by-side display.
            image = np.concatenate((img_y, y_label))

            y_label = y_label.astype(int)
            metric.addBatch(img_y, y_label)
            acc = metric.classPixelAccuracy()
            mIoU = metric.meanIntersectionOverUnion()
            mean_acc.append(acc[1])
            mean_miou.append(mIoU)
            print(acc, mIoU)
            plt.imshow(image * 5)
            plt.pause(0.1)
            plt.show()
    # NOTE(review): for pure speed benchmarks, disable the metric/plot code
    # above -- it dominates the measured wall time.

    print("Took ", alltime, "seconds")
    # 638.0 is presumably the validation-set size -- TODO confirm.
    print("Took", alltime / 638.0, "s/perimage")
    print("FPS", 1 / (alltime / 638.0))
    print("average acc:%0.6f  average miou:%0.6f" %
          (np.mean(mean_acc), np.mean(mean_miou)))
Exemple #8
0
def train():
    """Train the network selected by ``args.choose_net`` on the CamVid set
    with SGD + cosine-annealing LR, optionally restoring pretrained weights.
    """
    # ---- model selection (fixed: first two branches were 'if'/'if',
    # breaking the intended elif chain) --------------------------------
    if args.choose_net == "Unet":
        model = my_unet.UNet(3, 1).to(device)
    elif args.choose_net == "My_Unet":
        model = my_unet.My_Unet2(3, 1).to(device)
    elif args.choose_net == "Enet":
        model = enet.ENet(num_classes=13).to(device)
    elif args.choose_net == "Segnet":
        model = segnet.SegNet(3, 13).to(device)
    elif args.choose_net == "CascadNet":
        model = my_cascadenet.CascadeNet(3, 1).to(device)
    elif args.choose_net == "my_drsnet_A":
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_B":
        model = my_drsnet.MultiscaleSENetB(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_C":
        model = my_drsnet.MultiscaleSENetC(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_A_direct_skip":
        model = my_drsnet.MultiscaleSENetA_direct_skip(in_ch=3,
                                                       out_ch=1).to(device)
    elif args.choose_net == "SEResNet":
        model = my_drsnet.SEResNet18(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnext_unet":
        model = resnext_unet.resnext50(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnet50_unet":
        model = resnet50_unet.UNetWithResnet50Encoder(in_ch=3,
                                                      out_ch=1).to(device)
    elif args.choose_net == "unet_nest":
        model = unet_nest.UNet_Nested(3, 2).to(device)
    elif args.choose_net == "unet_res34":
        model = unet_res34.Resnet_Unet(3, 1).to(device)
    elif args.choose_net == "trangle_net":
        model = mytrangle_net.trangle_net(3, 1).to(device)
    elif args.choose_net == "dfanet":
        ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
        model = dfanet.DFANet(ch_cfg, 3, 1).to(device)
    elif args.choose_net == "lednet":
        model = lednet.Net(num_classes=1).to(device)
    elif args.choose_net == "cgnet":
        model = cgnet.Context_Guided_Network(classes=1).to(device)
    elif args.choose_net == "pspnet":
        model = pspnet.PSPNet(1).to(device)
    elif args.choose_net == "bisenet":
        model = bisenet.BiSeNet(1, 'resnet18').to(device)
    elif args.choose_net == "espnet":
        model = espnet.ESPNet(classes=1).to(device)
    elif args.choose_net == "fddwnet":
        model = fddwnet.Net(classes=1).to(device)
    elif args.choose_net == "contextnet":
        model = contextnet.ContextNet(classes=1).to(device)
    elif args.choose_net == "linknet":
        model = linknet.LinkNet(classes=1).to(device)
    elif args.choose_net == "edanet":
        model = edanet.EDANet(classes=1).to(device)
    elif args.choose_net == "erfnet":
        model = erfnet.ERFNet(classes=1).to(device)

    # Pretrained-weight handling:
    #   0: train from scratch
    #   1: load pretrained weights into the selected network
    #   2: load matching pretrained weights into a fresh MultiscaleSENetA
    loadpretrained = 0
    if loadpretrained == 1:
        model.load_state_dict(torch.load(args.weight))
    elif loadpretrained == 2:
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
        model_dict = model.state_dict()
        pretrained_dict = torch.load(args.weight)
        # Keep only weights whose names also exist in the new model.
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

    # Parameter count (millions) and FLOPs (GMac) for a 128x192 RGB input.
    dsize = (1, 3, 128, 192)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs, ), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000**2), total_ops / (1000**3)))
    batch_size = args.batch_size

    # Dataset / loader.  num_workers=24 parallelizes decoding; shuffle for
    # training as usual.
    liver_dataset = LiverDataset("data/train_camvid/",
                                 transform=x_transform,
                                 target_transform=y_transform)
    len_img = len(liver_dataset)
    dataloader = DataLoader(liver_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=24)

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=0.0001)

    # Cosine annealing with a period of ~10 epochs (in steps).
    cosine_lr_scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=10 * int(len_img / batch_size), eta_min=0.00001)

    multiclass = 1
    if multiclass == 1:
        # Per-class weights for CamVid (class 0 = void, weight 0).
        class_weights = np.array([
            0., 6.3005947, 4.31063664, 34.09234699, 50.49834979, 3.88280945,
            50.49834979, 8.91626081, 47.58477105, 29.41289083, 18.95706775,
            37.84558871, 39.3477858
        ])
        class_weights = torch.from_numpy(class_weights).float().to(device)
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
        train_modelmulticlasses(model, criterion, optimizer, dataloader,
                                cosine_lr_scheduler)
    else:
        criterion = torch.nn.BCELoss()
        train_model(model, criterion, optimizer, dataloader,
                    cosine_lr_scheduler)
Exemple #9
0
def main():
    """Train PSPNet on retina images, optionally resuming from a snapshot
    whose filename encodes the best-record metrics."""
    net = PSPNet(num_classes=num_classes)

    if len(args['snapshot']) == 0:
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'iter': 0, 'val_loss': 1e10, 'acc': 0,
                               'acc_cls': 0, 'mean_iu': 0, 'fwavacc': 0}
    else:
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        # NOTE(review): metric values are recovered from fixed underscore-
        # separated positions in the snapshot filename -- keep the saving
        # code's naming scheme in sync with these indices.
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {'epoch': int(split_snapshot[1]), 'iter': int(split_snapshot[3]),
                               'val_loss': float(split_snapshot[5]), 'acc': float(split_snapshot[7]),
                               'acc_cls': float(split_snapshot[9]), 'mean_iu': float(split_snapshot[11]),
                               'fwavacc': float(split_snapshot[13])}
    net.cuda().train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    # Geometric augmentations applied jointly to image and mask.
    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'], args['stride_rate'], ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(args['val_img_display_size']),
        standard_transforms.ToTensor()
    ])

    train_set = Retinaimages('training', joint_transform=train_joint_transform, sliding_crop=sliding_crop,
                                      transform=train_input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=2, shuffle=True)
    val_set = Retinaimages('validate', transform=val_input_transform, sliding_crop=sliding_crop,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=2, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True).cuda()

    # Biases get double LR and no weight decay.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'], nesterov=True)

    if len(args['snapshot']) > 0:
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Fixed: close the log file instead of leaking the handle.
    with open(os.path.join(ckpt_path, exp_name, "_1" + '.txt'), 'w') as log_file:
        log_file.write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, args, val_loader, visualize, val_set)
Exemple #10
0
#     output_stride=16,
#     num_classes=2,
#     pretrained_backbone=None,
# )

# # BiSeNet
# model = BiSeNet(
#     backbone='resnet18',
#     num_classes=2,
#     pretrained_backbone=None,
# )

# PSPNet
# NOTE(review): PSPNet is the active architecture here; the commented
# blocks above/below appear to be interchangeable alternatives with the
# same constructor signature -- uncomment one (and comment this) to
# switch models.
model = PSPNet(
    backbone='resnet18',
    num_classes=2,
    pretrained_backbone=None,
)

# # ICNet
# model = ICNet(
#     backbone='resnet18',
#     num_classes=2,
#     pretrained_backbone=None,
# )

#------------------------------------------------------------------------------
#   Summary network
#------------------------------------------------------------------------------
# Switch to train mode, then print a layer summary on CPU for a square
# RGB input of side args.input_sz.
model.train()
model.summary(input_shape=(3, args.input_sz, args.input_sz), device='cpu')
Exemple #11
0
def train_process(args, dt_config, dataset_class, data_transform_class):
    """Train a 20-class PSPNet with per-module learning rates, poly LR decay
    and TensorBoard logging, optionally resuming from ``args.snapshot``.

    Args:
        args: CLI namespace (batch_size, num_epoch, save_period,
            batch_multiplier, snapshot).
        dt_config: config object with ``DATA_PATH`` and ``LOG_PATH``.
        dataset_class: dataset type taking (data_path, phase, transform).
        data_transform_class: transform factory taking (num_classes,
            input_size).
    """
    import math  # hoisted out of the per-epoch LR lambda

    input_size = (475, 475)
    num_classes = 20

    data_transform = data_transform_class(num_classes=num_classes,
                                          input_size=input_size)
    train_dataset = dataset_class(
        data_path=dt_config.DATA_PATH,
        phase="train",
        transform=data_transform,
    )
    val_dataset = dataset_class(
        data_path=dt_config.DATA_PATH,
        phase="val",
        transform=data_transform,
    )

    train_data_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )
    val_data_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        drop_last=True,
    )
    data_loaders_dict = {"train": train_data_loader, "val": val_data_loader}
    tblogger = SummaryWriter(dt_config.LOG_PATH)

    model = PSPNet(num_classes=num_classes)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    criterion = PSPLoss()

    # Feature-extractor modules train with a low LR; the decode and
    # auxiliary heads get a 10x higher LR.  This replaces eight
    # near-identical hand-written param-group dicts.
    low_lr_modules = (
        model.feature_conv,
        model.feature_res_1,
        model.feature_res_2,
        model.feature_dilated_res_1,
        model.feature_dilated_res_2,
        model.pyramid_pooling,
    )
    high_lr_modules = (model.decode_feature, model.aux)
    param_groups = [{"params": m.parameters(), "lr": 1e-3}
                    for m in low_lr_modules]
    param_groups += [{"params": m.parameters(), "lr": 1e-2}
                     for m in high_lr_modules]
    optimizer = torch.optim.SGD(
        param_groups,
        momentum=0.9,
        weight_decay=0.0001,
    )

    def _lambda_epoch(epoch):
        # Polynomial ("poly") LR decay with power 0.9 over num_epoch epochs.
        return math.pow((1 - epoch / args.num_epoch), 0.9)

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=_lambda_epoch)

    trainer = Trainer(
        model=model,
        criterion=criterion,
        metric_func=None,
        optimizer=optimizer,
        num_epochs=args.num_epoch,
        save_period=args.save_period,
        config=dt_config,
        data_loaders_dict=data_loaders_dict,
        scheduler=scheduler,
        device=device,
        dataset_name_base=train_dataset.__name__,
        batch_multiplier=args.batch_multiplier,
        logger=tblogger,
    )

    if args.snapshot and os.path.isfile(args.snapshot):
        trainer.resume_checkpoint(args.snapshot)

    # NOTE(review): anomaly detection adds substantial overhead; consider
    # disabling it for production training runs.
    with torch.autograd.set_detect_anomaly(True):
        trainer.train()

    tblogger.close()
def main():
    """Train PSPNet on Cityscapes with four LR groups (bias/non-bias x
    new-head/pretrained), optionally resuming from a metric-encoded snapshot.
    """
    net = PSPNet(num_classes=num_classes,
                 input_size=train_args['input_size']).cuda()
    if len(train_args['snapshot']) == 0:
        curr_epoch = 0
    else:
        # Fixed: Python 2 print statement -> print() function.
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # NOTE(review): metrics come from fixed underscore-separated
        # positions in the snapshot filename; keep in sync with the saver.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1])
        train_record['best_val_loss'] = float(split_snapshot[3])
        train_record['corr_mean_iu'] = float(split_snapshot[6])
        train_record['corr_epoch'] = curr_epoch

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train',
                           simul_transform=train_simul_transform,
                           transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['batch_size'],
                              num_workers=16,
                              shuffle=True)
    val_set = CityScapes('val',
                         simul_transform=val_simul_transform,
                         transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=val_args['batch_size'],
                            num_workers=16,
                            shuffle=False)

    # Zero weight silences the last (remapped ignore) class in the loss.
    weight = torch.ones(num_classes)
    weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(weight).cuda()

    def _select_params(want_bias, want_new_layer):
        """Pick parameters by bias-ness and by whether they belong to the
        newly added head layers (ppm / final / aux_logits)."""
        picked = []
        for name, param in net.named_parameters():
            is_bias = name[-4:] == 'bias'
            is_new = ('ppm' in name or 'final' in name
                      or 'aux_logits' in name)
            if is_bias == want_bias and is_new == want_new_layer:
                picked.append(param)
        return picked

    # Group order matters: the resume code below indexes param_groups.
    # Biases get double LR and no weight decay; new head layers use
    # 'new_lr', the pretrained backbone uses the smaller 'pretrained_lr'.
    optimizer = optim.SGD([
        {'params': _select_params(True, True),
         'lr': 2 * train_args['new_lr']},
        {'params': _select_params(False, True),
         'lr': train_args['new_lr'],
         'weight_decay': train_args['weight_decay']},
        {'params': _select_params(True, False),
         'lr': 2 * train_args['pretrained_lr']},
        {'params': _select_params(False, False),
         'lr': train_args['pretrained_lr'],
         'weight_decay': train_args['weight_decay']},
    ],
                          momentum=0.9,
                          nesterov=True)

    if len(train_args['snapshot']) > 0:
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path,
                                    'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['new_lr']
        optimizer.param_groups[1]['lr'] = train_args['new_lr']
        optimizer.param_groups[2]['lr'] = 2 * train_args['pretrained_lr']
        optimizer.param_groups[3]['lr'] = train_args['pretrained_lr']

    if not os.path.exists(ckpt_path):
        os.mkdir(ckpt_path)
    if not os.path.exists(os.path.join(ckpt_path, exp_name)):
        os.mkdir(os.path.join(ckpt_path, exp_name))

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch,
                 restore_transform)
Exemple #13
0
def main():
    """Smoke-test PSPNet: run one random 475x475 RGB batch through the
    12-class model and print the shapes of the main and auxiliary outputs."""
    net = PSPNet(num_classes=12)
    dummy_batch = torch.randn(1, 3, 475, 475)
    main_out, aux_out = net(dummy_batch)
    print(main_out.shape)
    print(aux_out.shape)
Exemple #14
0
class_name = class_list[0]

print('Predict {0} with epoch {1}'.format(model_name, step_num))
result_dir = './result/{0}/{1}/'.format(model_name, class_name)
os.makedirs(result_dir, exist_ok=True)

# Dispatch table replaces the old run of independent 'if' statements.
# Every constructor builds a 2-class model; an unknown model_name now
# fails here with a clear KeyError instead of a NameError further down.
_MODEL_FACTORIES = {
    'UNet': lambda: UNet(2, in_channels=3),
    'FCN8': lambda: FCN8(2),
    'PSPNet': lambda: PSPNet(2),
    'UperNet': lambda: UperNet(2),
    'CC_UNet': lambda: CC_UNet(2),
    'A_UNet': lambda: A_UNet(2),
}
model = _MODEL_FACTORIES[model_name]()

device = torch.device('cuda:0')

state_dict = torch.load('./checkpoints/{0}_{1}.pth'.format(
    model_name, step_num))