コード例 #1
0
 def init_model(self, ):
     """Instantiate the model selected by ``self.model_name`` ("PSPNet" or
     "Bisenet_V2") and try to restore its checkpoint for ``self.train_id``.

     Falls back to the freshly initialised weights (with a console message)
     when no checkpoint can be loaded; never raises for a missing checkpoint.
     """
     # Single source of truth for the checkpoint location.
     checkpoint_dir = ('/home/essys/projects/segmentation/checkpoints/' +
                       self.train_id + "/")
     if (self.model_name == "PSPNet"):
         # NOTE(review): this PSPNet builder returns (model, out_layer_names),
         # unlike the usual single-model constructors — confirm against its def.
         self.model, self.model_out_layer_names = PSPNet(
             backbone_name=self.backbone,
             input_shape=self.input_shape,
             classes=self.n_classes,
             encoder_weights=self.pretrained_encoder_weights,
             encoder_freeze=self.transfer_learning,
             training=self.training)
         try:
             # The model weights (that are considered the best) are loaded into the model.
             self.model.load_weights(checkpoint_dir)
         except Exception as exc:  # narrowed from bare `except:` so SIGINT still propagates
             print('could not find saved model: %s' % exc)
     if (self.model_name == "Bisenet_V2"):
         self.model = BisenetV2Model(
             train_op=optimizers.Adam(self.lr),
             input_shape=self.input_shape,
             classes=self.n_classes,
             batch_size=self.batch_size,
             class_weights=self.class_weights
         )  #BisenetV2(input_shape=self.input_shape,classes=self.n_classes)
         try:
             latest = tf.train.latest_checkpoint(checkpoint_dir)
             if latest is None:
                 # latest_checkpoint returns None for an empty/missing dir;
                 # the original then crashed on `latest + "..."` (TypeError).
                 raise FileNotFoundError(checkpoint_dir)
             print(latest + " is found model\n\n")
             # The model weights (that are considered the best) are loaded into the model.
             self.model.model.load_weights(latest)
         except Exception as exc:
             print('could not find saved model: %s' % exc)
コード例 #2
0
def get_model(criterion=None, auxiliary_loss=False, auxloss_weight=0):
  """Build a PSPNet (dilated ResNet-50 encoder, ImageNet weights, 19 classes).

  Args:
      criterion: loss passed through to the network (or None).
      auxiliary_loss: whether to enable the auxiliary segmentation head.
      auxloss_weight: weight of the auxiliary loss term.
  """
  config = dict(
      encoder_name='dilated_resnet50',
      encoder_weights='imagenet',
      classes=19,
      auxiliary_loss=auxiliary_loss,
      auxloss_weight=auxloss_weight,
      criterion=criterion,
  )
  return PSPNet(**config)
コード例 #3
0
ファイル: test.py プロジェクト: xmba15/human_parsing_pytorch
def test_one_image(args, dt_config, dataset_class):
    """Run a trained PSPNet on one image and write a blended overlay.

    Loads the checkpoint at ``args.snapshot``, predicts a class mask for
    ``args.image_path`` at 475x475, colourises it with the dataset palette,
    alpha-blends it over the (resized) input and saves ``result.jpg`` at the
    original image resolution.
    """
    input_size = (475, 475)
    model_path = args.snapshot
    dataset_instance = dataset_class(data_path=dt_config.DATA_PATH)
    num_classes = dataset_instance.num_classes
    model = PSPNet(num_classes=num_classes)
    # map_location lets a GPU-saved checkpoint load on a CPU-only host;
    # the model is moved to CUDA below when available.
    model.load_state_dict(
        torch.load(model_path, map_location="cpu")["state_dict"])
    model.eval()

    img = cv2.imread(args.image_path)
    processed_img = cv2.resize(img, input_size)
    overlay = np.copy(processed_img)
    processed_img = processed_img / 255.0
    # HWC -> NCHW float tensor in [0, 1].
    processed_img = torch.tensor(
        processed_img.transpose(2, 0, 1)[np.newaxis, :]).float()
    if torch.cuda.is_available():
        model = model.cuda()
        processed_img = processed_img.cuda()
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        output = model(processed_img)[0]
    # Per-pixel argmax; reshape to input_size (was hard-coded 475, 475).
    mask = output.data.max(1)[1].cpu().numpy().reshape(input_size)
    color_mask = np.array(dataset_instance.colors)[mask]
    alpha = args.alpha
    overlay = (((1 - alpha) * overlay) + (alpha * color_mask)).astype("uint8")
    overlay = cv2.resize(overlay, (img.shape[1], img.shape[0]))
    cv2.imwrite("result.jpg", overlay)
コード例 #4
0
ファイル: hand_seg.py プロジェクト: MLsmaller/cpp_piano
def main():
    """Segment every image in ``args.images`` with a ResNet-18 PSPNet and
    save the colourised two-class predictions via ``save_images``."""
    args = parse_arguments()

    # Normalisation statistics of the dataset used for training the model.
    MEAN = [0.45734706, 0.43338275, 0.40058118]
    STD = [0.23965294, 0.23532275, 0.2398498]

    to_tensor = transforms.ToTensor()
    normalize = transforms.Normalize(MEAN, STD)
    num_classes = 2
    # Palette: class 0 -> black, class 1 -> purple (RGB triplets).
    palette = [0, 0, 0, 128, 0, 128]

    # Model
    model = PSPNet(num_classes=num_classes, backbone='resnet18')
    availble_gpus = list(range(torch.cuda.device_count()))
    device = torch.device('cuda:0' if len(availble_gpus) > 0 else 'cpu')

    checkpoint = torch.load(args.model)
    if isinstance(checkpoint, dict) and 'state_dict' in checkpoint.keys():
        checkpoint = checkpoint['state_dict']
    # Checkpoints saved from a DataParallel model prefix keys with "module.";
    # wrap the model the same way so the state-dict keys line up.
    if 'module' in list(checkpoint.keys())[0] and not isinstance(
            model, torch.nn.DataParallel):
        model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint)
    model.to(device)
    model.eval()

    # NOTE(review): results are written through save_images(..., args.output,
    # ...) below, yet the directory created here is the literal 'outputs' —
    # confirm args.output defaults to 'outputs', else create args.output.
    if not os.path.exists('outputs'):
        os.makedirs('outputs')

    image_files = sorted(glob(os.path.join(args.images,
                                           f'*.{args.extension}')))
    with torch.no_grad():
        tbar = tqdm(image_files, ncols=100)
        for img_file in tbar:
            image = Image.open(img_file).convert('RGB')
            image = image.resize((480, 320))
            # Renamed from `input`, which shadowed the builtin.
            input_tensor = normalize(to_tensor(image)).unsqueeze(0)
            print(input_tensor.size())
            t1 = time.time()
            prediction = model(input_tensor.to(device)).squeeze(0)
            print(time.time() - t1)
            # Argmax over the class dimension, done on-device instead of the
            # old GPU -> numpy -> torch round-trip (softmax kept for parity;
            # it does not change the argmax).
            prediction = F.softmax(prediction, dim=0).argmax(0).cpu().numpy()
            save_images(image, prediction, args.output, img_file, palette)
コード例 #5
0
ファイル: train.py プロジェクト: YOUSIKI/PyTorch-PSPNet
 def __init__(self,
              n_classes,
              psp_size=2048,
              psp_bins=(1, 2, 3, 6),
              dropout=0.1,
              backbone='resnet50',
              **kwargs):
     """Wrap a PSPNet (with a pretrained backbone) as a Lightning module.

     Args:
         n_classes: number of output segmentation classes.
         psp_size: channel count fed into the pyramid pooling module.
         psp_bins: pooling bin sizes of the pyramid pooling module.
         dropout: dropout rate before the classifier.
         backbone: backbone architecture name (pretrained weights are used).
     """
     super().__init__()
     # Record constructor args so checkpoints can rebuild this module.
     self.save_hyperparameters()
     backbone_module = Backbone(backbone, pretrained=True)
     self.pspnet = PSPNet(n_classes=n_classes,
                          psp_size=psp_size,
                          psp_bins=psp_bins,
                          dropout=dropout,
                          backbone=backbone_module)
     # Running counter used when numbering saved checkpoints.
     self.ckpts_index = 0
コード例 #6
0
def main():
    """Run a trained PSPNet over three LSUN validation splits and save each
    restored input image next to its colourised prediction mask."""
    batch_size = 8

    net = PSPNet(pretrained=False, num_classes=num_classes, input_size=(512, 1024)).cuda()
    snapshot = 'epoch_48_validation_loss_5.1326_mean_iu_0.3172_lr_0.00001000.pth'
    net.load_state_dict(torch.load(os.path.join(ckpt_path, snapshot)))
    net.eval()

    # ImageNet mean/std used both to normalise inputs and to de-normalise
    # them back into viewable PIL images.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    transform = transforms.Compose([
        expanded_transform.FreeScale((512, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(*mean_std)
    ])
    restore = transforms.Compose([
        expanded_transform.DeNormalize(*mean_std),
        transforms.ToPILImage()
    ])

    lsun_path = '/home/b3-542/LSUN'

    dataset = LSUN(lsun_path, ['tower_val', 'church_outdoor_val', 'bridge_val'], transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, num_workers=16, shuffle=True)

    if not os.path.exists(test_results_path):
        os.mkdir(test_results_path)

    for vi, data in enumerate(dataloader, 0):
        inputs, labels = data
        # NOTE(review): Variable(volatile=True) is the legacy (pre-0.4)
        # inference mode; on modern PyTorch wrap this loop in torch.no_grad().
        inputs = Variable(inputs, volatile=True).cuda()
        outputs = net(inputs)

        # Per-pixel argmax over the class dimension -> HxW label maps.
        prediction = outputs.cpu().data.max(1)[1].squeeze_(1).numpy()

        for idx, tensor in enumerate(zip(inputs.cpu().data, prediction)):
            pil_input = restore(tensor[0])
            pil_output = colorize_mask(tensor[1])
            pil_input.save(os.path.join(test_results_path, '%d_img.png' % (vi * batch_size + idx)))
            pil_output.save(os.path.join(test_results_path, '%d_out.png' % (vi * batch_size + idx)))
            # Was a Python 2 print statement — a SyntaxError under Python 3.
            print('save the #%d batch, %d images' % (vi + 1, idx + 1))
コード例 #7
0
ファイル: main.py プロジェクト: sunjie155633/DRSNet-1
def test():
    """Evaluate the network chosen by ``args.choose_net`` on the CamVid
    validation set: report parameter count / FLOPs, per-image inference
    time, pixel accuracy and mIoU, and display prediction/label pairs.
    """
    # NOTE(review): the first two branches are `if`/`if` while the rest are
    # `elif`; harmless as written, but confirm this mirrors train().
    if args.choose_net == "Unet":
        model = my_unet.UNet(3, 1).to(device)
    if args.choose_net == "My_Unet":
        model = my_unet.My_Unet2(3, 1).to(device)
    elif args.choose_net == "Enet":
        model = enet.ENet(num_classes=13).to(device)
    elif args.choose_net == "Segnet":
        model = segnet.SegNet(3, 1).to(device)
    elif args.choose_net == "CascadNet":
        model = my_cascadenet.CascadeNet(3, 1).to(device)

    elif args.choose_net == "my_drsnet_A":
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_B":
        model = my_drsnet.MultiscaleSENetB(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_C":
        model = my_drsnet.MultiscaleSENetC(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_A_direct_skip":
        model = my_drsnet.MultiscaleSENetA_direct_skip(in_ch=3,
                                                       out_ch=1).to(device)
    elif args.choose_net == "SEResNet":
        model = my_drsnet.SEResNet18(in_ch=3, out_ch=1).to(device)

    elif args.choose_net == "resnext_unet":
        model = resnext_unet.resnext50(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnet50_unet":
        model = resnet50_unet.UNetWithResnet50Encoder(in_ch=3,
                                                      out_ch=1).to(device)
    elif args.choose_net == "unet_res34":
        model = unet_res34.Resnet_Unet(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "dfanet":
        ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
        model = dfanet.DFANet(ch_cfg, 3, 1).to(device)
    elif args.choose_net == "cgnet":
        model = cgnet.Context_Guided_Network(1).to(device)
    elif args.choose_net == "lednet":
        model = lednet.Net(num_classes=1).to(device)
    elif args.choose_net == "bisenet":
        model = bisenet.BiSeNet(1, 'resnet18').to(device)
    elif args.choose_net == "espnet":
        model = espnet.ESPNet(classes=1).to(device)
    elif args.choose_net == "pspnet":
        model = pspnet.PSPNet(1).to(device)
    elif args.choose_net == "fddwnet":
        model = fddwnet.Net(classes=1).to(device)
    elif args.choose_net == "contextnet":
        model = contextnet.ContextNet(classes=1).to(device)
    elif args.choose_net == "linknet":
        model = linknet.LinkNet(classes=1).to(device)
    elif args.choose_net == "edanet":
        model = edanet.EDANet(classes=1).to(device)
    elif args.choose_net == "erfnet":
        model = erfnet.ERFNet(classes=1).to(device)
    # Report parameter count (M) and FLOPs (G) for a 128x192 RGB input.
    dsize = (1, 3, 128, 192)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs, ), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000**2), total_ops / (1000**3)))

    model.load_state_dict(torch.load(args.weight))
    liver_dataset = LiverDataset("data/val_camvid",
                                 transform=x_transform,
                                 target_transform=y_transform)
    dataloaders = DataLoader(liver_dataset)  # batch_size defaults to 1
    model.eval()

    # 13-class confusion-matrix based metric accumulator.
    metric = SegmentationMetric(13)
    # import matplotlib.pyplot as plt
    # plt.ion()
    multiclass = 1  # 1: 13-class CamVid evaluation; 0: binary thresholding
    mean_acc, mean_miou = [], []

    alltime = 0.0
    with torch.no_grad():
        for x, y_label in dataloaders:
            x = x.to(device)
            start = time.time()
            y = model(x)
            usingtime = time.time() - start
            alltime = alltime + usingtime

            if multiclass == 1:
                # Post-process the multi-class prediction, see:
                # https://www.cnblogs.com/ljwgis/p/12313047.html
                y = F.sigmoid(y)
                y = y.cpu()
                # y = torch.squeeze(y).numpy()
                y = torch.argmax(y.squeeze(0), dim=0).data.numpy()
                print(y.max(), y.min())
                # y_label = y_label[0]
                y_label = torch.squeeze(y_label).numpy()
            else:
                y = y.cpu()
                y = torch.squeeze(y).numpy()
                y_label = torch.squeeze(y_label).numpy()

                # img_y = y*127.5

                # Per-network binarisation threshold (0.5 everywhere except
                # CascadNet's 0.8). NOTE(review): "espnet" has no branch here,
                # so its output stays un-thresholded — confirm intended.
                if args.choose_net == "Unet":
                    y = (y > 0.5)
                elif args.choose_net == "My_Unet":
                    y = (y > 0.5)
                elif args.choose_net == "Enet":
                    y = (y > 0.5)
                elif args.choose_net == "Segnet":
                    y = (y > 0.5)
                elif args.choose_net == "Scnn":
                    y = (y > 0.5)
                elif args.choose_net == "CascadNet":
                    y = (y > 0.8)

                elif args.choose_net == "my_drsnet_A":
                    y = (y > 0.5)
                elif args.choose_net == "my_drsnet_B":
                    y = (y > 0.5)
                elif args.choose_net == "my_drsnet_C":
                    y = (y > 0.5)
                elif args.choose_net == "my_drsnet_A_direct_skip":
                    y = (y > 0.5)
                elif args.choose_net == "SEResNet":
                    y = (y > 0.5)

                elif args.choose_net == "resnext_unet":
                    y = (y > 0.5)
                elif args.choose_net == "resnet50_unet":
                    y = (y > 0.5)
                elif args.choose_net == "unet_res34":
                    y = (y > 0.5)
                elif args.choose_net == "dfanet":
                    y = (y > 0.5)
                elif args.choose_net == "cgnet":
                    y = (y > 0.5)
                elif args.choose_net == "lednet":
                    y = (y > 0.5)
                elif args.choose_net == "bisenet":
                    y = (y > 0.5)
                elif args.choose_net == "pspnet":
                    y = (y > 0.5)
                elif args.choose_net == "fddwnet":
                    y = (y > 0.5)
                elif args.choose_net == "contextnet":
                    y = (y > 0.5)
                elif args.choose_net == "linknet":
                    y = (y > 0.5)
                elif args.choose_net == "edanet":
                    y = (y > 0.5)
                elif args.choose_net == "erfnet":
                    y = (y > 0.5)

            img_y = y.astype(int).squeeze()
            print(y_label.shape, img_y.shape)
            # Stack prediction above label for a side-by-side visual check.
            image = np.concatenate((img_y, y_label))

            y_label = y_label.astype(int)
            metric.addBatch(img_y, y_label)
            acc = metric.classPixelAccuracy()
            mIoU = metric.meanIntersectionOverUnion()
            # confusionMatrix=metric.genConfusionMatrix(img_y, y_label)
            mean_acc.append(acc[1])
            mean_miou.append(mIoU)
            # print(acc, mIoU,confusionMatrix)
            print(acc, mIoU)
            plt.imshow(image * 5)
            plt.pause(0.1)
            plt.show()
    # When benchmarking speed, disable the acc/miou computation above.

    # 638.0 is the hard-coded number of validation images.
    print("Took ", alltime, "seconds")
    print("Took", alltime / 638.0, "s/perimage")
    print("FPS", 1 / (alltime / 638.0))
    print("average acc:%0.6f  average miou:%0.6f" %
          (np.mean(mean_acc), np.mean(mean_miou)))
コード例 #8
0
ファイル: main.py プロジェクト: sunjie155633/DRSNet-1
def train():
    """Train the network chosen by ``args.choose_net`` on CamVid with SGD +
    cosine-annealing LR, optionally warm-starting from ``args.weight``.
    """
    # NOTE(review): the first two branches are `if`/`if` while the rest are
    # `elif`; harmless as written, but confirm this mirrors test().
    if args.choose_net == "Unet":
        model = my_unet.UNet(3, 1).to(device)
    if args.choose_net == "My_Unet":
        model = my_unet.My_Unet2(3, 1).to(device)
    elif args.choose_net == "Enet":
        model = enet.ENet(num_classes=13).to(device)
    elif args.choose_net == "Segnet":
        model = segnet.SegNet(3, 13).to(device)
    elif args.choose_net == "CascadNet":
        model = my_cascadenet.CascadeNet(3, 1).to(device)

    elif args.choose_net == "my_drsnet_A":
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_B":
        model = my_drsnet.MultiscaleSENetB(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_C":
        model = my_drsnet.MultiscaleSENetC(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "my_drsnet_A_direct_skip":
        model = my_drsnet.MultiscaleSENetA_direct_skip(in_ch=3,
                                                       out_ch=1).to(device)
    elif args.choose_net == "SEResNet":
        model = my_drsnet.SEResNet18(in_ch=3, out_ch=1).to(device)

    elif args.choose_net == "resnext_unet":
        model = resnext_unet.resnext50(in_ch=3, out_ch=1).to(device)
    elif args.choose_net == "resnet50_unet":
        model = resnet50_unet.UNetWithResnet50Encoder(in_ch=3,
                                                      out_ch=1).to(device)
    elif args.choose_net == "unet_nest":
        model = unet_nest.UNet_Nested(3, 2).to(device)
    elif args.choose_net == "unet_res34":
        model = unet_res34.Resnet_Unet(3, 1).to(device)
    elif args.choose_net == "trangle_net":
        model = mytrangle_net.trangle_net(3, 1).to(device)
    elif args.choose_net == "dfanet":
        ch_cfg = [[8, 48, 96], [240, 144, 288], [240, 144, 288]]
        model = dfanet.DFANet(ch_cfg, 3, 1).to(device)
    elif args.choose_net == "lednet":
        model = lednet.Net(num_classes=1).to(device)
    elif args.choose_net == "cgnet":
        model = cgnet.Context_Guided_Network(classes=1).to(device)
    elif args.choose_net == "pspnet":
        model = pspnet.PSPNet(1).to(device)
    elif args.choose_net == "bisenet":
        model = bisenet.BiSeNet(1, 'resnet18').to(device)
    elif args.choose_net == "espnet":
        model = espnet.ESPNet(classes=1).to(device)
    elif args.choose_net == "fddwnet":
        model = fddwnet.Net(classes=1).to(device)
    elif args.choose_net == "contextnet":
        model = contextnet.ContextNet(classes=1).to(device)
    elif args.choose_net == "linknet":
        model = linknet.LinkNet(classes=1).to(device)
    elif args.choose_net == "edanet":
        model = edanet.EDANet(classes=1).to(device)
    elif args.choose_net == "erfnet":
        model = erfnet.ERFNet(classes=1).to(device)

    from collections import OrderedDict

    loadpretrained = 0
    # 0:no loadpretrained model
    # 1:loadpretrained model to original network
    # 2:loadpretrained model to new network
    if loadpretrained == 1:
        model.load_state_dict(torch.load(args.weight))

    elif loadpretrained == 2:
        # Partial warm start: replace the model and copy over only the
        # checkpoint entries whose keys exist in the new state dict.
        model = my_drsnet.MultiscaleSENetA(in_ch=3, out_ch=1).to(device)
        model_dict = model.state_dict()
        pretrained_dict = torch.load(args.weight)
        pretrained_dict = {
            k: v
            for k, v in pretrained_dict.items() if k in model_dict
        }
        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)

        # model.load_state_dict(torch.load(args.weight))
        # pretrained_dict = {k: v for k, v in model.items() if k in model}  # filter out unnecessary keys
        # model.update(pretrained_dict)
        # model.load_state_dict(model)

    # Compute model parameter count and FLOPs.
    dsize = (1, 3, 128, 192)
    inputs = torch.randn(dsize).to(device)
    total_ops, total_params = profile(model, (inputs, ), verbose=False)
    print(" %.2f | %.2f" % (total_params / (1000**2), total_ops / (1000**3)))
    batch_size = args.batch_size

    # Load the dataset.
    liver_dataset = LiverDataset("data/train_camvid/",
                                 transform=x_transform,
                                 target_transform=y_transform)
    len_img = liver_dataset.__len__()
    dataloader = DataLoader(liver_dataset,
                            batch_size=batch_size,
                            shuffle=True,
                            num_workers=24)

    # DataLoader wraps the dataset's samples into batch-size tensors.
    # batch_size: how many samples per minibatch to load.
    # shuffle: reshuffle the data every epoch (usual for training).
    # num_workers: load data with multiple worker processes for speed.

    # Gradient-descent optimiser.
    # optimizer = optim.Adam(model.parameters())  # model.parameters():Returns an iterator over module parameters
    # # Observe that all parameters are being optimized

    optimizer = optim.SGD(model.parameters(),
                          lr=args.lr,
                          momentum=0.9,
                          weight_decay=0.0001)

    # Cosine-annealing restart every n epochs (n = 10 here).
    cosine_lr_scheduler = lr_scheduler.CosineAnnealingLR(
        optimizer, T_max=10 * int(len_img / batch_size), eta_min=0.00001)

    multiclass = 1  # 1: 13-class CamVid training; 0: binary segmentation
    if multiclass == 1:
        # Loss function with per-class weights (class 0 = void is ignored).
        class_weights = np.array([
            0., 6.3005947, 4.31063664, 34.09234699, 50.49834979, 3.88280945,
            50.49834979, 8.91626081, 47.58477105, 29.41289083, 18.95706775,
            37.84558871, 39.3477858
        ])  # CamVid class weights
        # class_weights = weighing(dataloader, 13, c=1.02)
        class_weights = torch.from_numpy(class_weights).float().to(device)
        criterion = torch.nn.CrossEntropyLoss(weight=class_weights)
        # criterion = LovaszLossSoftmax()
        # criterion = torch.nn.MSELoss()
        train_modelmulticlasses(model, criterion, optimizer, dataloader,
                                cosine_lr_scheduler)
    else:
        # Loss function for the binary case.
        # criterion = LovaszLossHinge()
        # weights=[0.2]
        # weights=torch.Tensor(weights).to(device)
        # # criterion = torch.nn.CrossEntropyLoss(weight=weights)
        criterion = torch.nn.BCELoss()
        # criterion =focal_loss.FocalLoss(1)
        train_model(model, criterion, optimizer, dataloader,
                    cosine_lr_scheduler)
コード例 #9
0
def main():
    """Train a PSPNet on retina images, optionally resuming from a snapshot.

    The snapshot filename encodes training state as underscore-separated
    fields: epoch_<E>_iter_<I>_val_loss_<L>_acc_<A>_acc-cls_<C>_mean-iu_<M>_
    fwavacc_<F>.pth — the parsing below depends on those exact positions.
    """
    net = PSPNet(num_classes=num_classes)

    if len(args['snapshot']) == 0:
        # Fresh run: start at epoch 1 with an empty best-record.
        # net.load_state_dict(torch.load(os.path.join(ckpt_path, 'cityscapes (coarse)-psp_net', 'xx.pth')))
        curr_epoch = 1
        args['best_record'] = {'epoch': 0, 'iter': 0, 'val_loss': 1e10, 'acc': 0, 'acc_cls': 0, 'mean_iu': 0,
                               'fwavacc': 0}
    else:
        # Resume: restore weights and recover metrics from the filename.
        print('training resumes from ' + args['snapshot'])
        net.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, args['snapshot'])))
        split_snapshot = args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1]) + 1
        args['best_record'] = {'epoch': int(split_snapshot[1]), 'iter': int(split_snapshot[3]),
                               'val_loss': float(split_snapshot[5]), 'acc': float(split_snapshot[7]),
                               'acc_cls': float(split_snapshot[9]),'mean_iu': float(split_snapshot[11]),
                               'fwavacc': float(split_snapshot[13])}
    net.cuda().train()

    # ImageNet normalisation statistics.
    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])

    train_joint_transform = joint_transforms.Compose([
        joint_transforms.Scale(args['longer_size']),
        joint_transforms.RandomRotate(10),
        joint_transforms.RandomHorizontallyFlip()
    ])
    sliding_crop = joint_transforms.SlidingCrop(args['crop_size'], args['stride_rate'], ignore_label)
    train_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    val_input_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    target_transform = extended_transforms.MaskToTensor()
    visualize = standard_transforms.Compose([
        standard_transforms.Scale(args['val_img_display_size']),
        standard_transforms.ToTensor()
    ])

    train_set = Retinaimages('training', joint_transform=train_joint_transform, sliding_crop=sliding_crop,
                                      transform=train_input_transform, target_transform=target_transform)
    train_loader = DataLoader(train_set, batch_size=args['train_batch_size'], num_workers=2, shuffle=True)
    val_set = Retinaimages('validate', transform=val_input_transform, sliding_crop=sliding_crop,
                                    target_transform=target_transform)
    val_loader = DataLoader(val_set, batch_size=1, num_workers=2, shuffle=False)

    criterion = CrossEntropyLoss2d(size_average=True).cuda()

    # Biases train at 2x the base LR and without weight decay — a common
    # segmentation-training convention.
    optimizer = optim.SGD([
        {'params': [param for name, param in net.named_parameters() if name[-4:] == 'bias'],
         'lr': 2 * args['lr']},
        {'params': [param for name, param in net.named_parameters() if name[-4:] != 'bias'],
         'lr': args['lr'], 'weight_decay': args['weight_decay']}
    ], momentum=args['momentum'], nesterov=True)

    if len(args['snapshot']) > 0:
        # Restore optimizer state, then re-apply the current LR settings
        # (the loaded state carries the old learning rates).
        optimizer.load_state_dict(torch.load(os.path.join(ckpt_path, exp_name, 'opt_' + args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * args['lr']
        optimizer.param_groups[1]['lr'] = args['lr']

    check_mkdir(ckpt_path)
    check_mkdir(os.path.join(ckpt_path, exp_name))
    # Record the full run configuration next to the checkpoints.
    open(os.path.join(ckpt_path, exp_name, "_1" + '.txt'), 'w').write(str(args) + '\n\n')

    train(train_loader, net, criterion, optimizer, curr_epoch, args, val_loader, visualize, val_set)
コード例 #10
0
#     output_stride=16,
#     num_classes=2,
#     pretrained_backbone=None,
# )

# # BiSeNet
# model = BiSeNet(
#     backbone='resnet18',
#     num_classes=2,
#     pretrained_backbone=None,
# )

# PSPNet
# Two-class PSPNet on a ResNet-18 backbone, randomly initialised
# (pretrained_backbone=None). The commented alternatives above/below share
# the same constructor shape for easy swapping.
model = PSPNet(
    backbone='resnet18',
    num_classes=2,
    pretrained_backbone=None,
)

# # ICNet
# model = ICNet(
#     backbone='resnet18',
#     num_classes=2,
#     pretrained_backbone=None,
# )

#------------------------------------------------------------------------------
#   Summary network
#------------------------------------------------------------------------------
# train() only sets the module mode flag; summary runs on CPU with a square
# RGB input of side args.input_sz.
model.train()
model.summary(input_shape=(3, args.input_sz, args.input_sz), device='cpu')
コード例 #11
0
def train_process(args, dt_config, dataset_class, data_transform_class):
    """Train a 20-class PSPNet at 475x475 with per-module SGD learning rates
    and poly LR decay, driven by the project ``Trainer``.

    Args:
        args: CLI namespace (batch_size, num_epoch, save_period,
            batch_multiplier, snapshot).
        dt_config: project config exposing DATA_PATH and LOG_PATH.
        dataset_class: dataset constructor taking data_path/phase/transform.
        data_transform_class: transform factory taking num_classes/input_size.
    """
    input_size = (475, 475)
    num_classes = 20

    data_transform = data_transform_class(num_classes=num_classes,
                                          input_size=input_size)
    train_dataset = dataset_class(
        data_path=dt_config.DATA_PATH,
        phase="train",
        transform=data_transform,
    )

    val_dataset = dataset_class(
        data_path=dt_config.DATA_PATH,
        phase="val",
        transform=data_transform,
    )

    # drop_last keeps batch shapes fixed, which the loss/aux heads expect.
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=args.batch_size,
        shuffle=True,
        num_workers=4,
        drop_last=True,
    )
    val_data_loader = DataLoader(
        val_dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=4,
        drop_last=True,
    )
    data_loaders_dict = {"train": train_data_loader, "val": val_data_loader}
    tblogger = SummaryWriter(dt_config.LOG_PATH)

    model = PSPNet(num_classes=num_classes)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    criterion = PSPLoss()
    # Backbone / pyramid stages train at 1e-3; the freshly initialised
    # decoder and auxiliary heads use a 10x higher rate. Built in a loop to
    # replace eight copy-pasted dict literals.
    encoder_lr, head_lr = 1e-3, 1e-2
    param_groups = [
        {"params": getattr(model, module_name).parameters(), "lr": lr}
        for module_name, lr in (
            ("feature_conv", encoder_lr),
            ("feature_res_1", encoder_lr),
            ("feature_res_2", encoder_lr),
            ("feature_dilated_res_1", encoder_lr),
            ("feature_dilated_res_2", encoder_lr),
            ("pyramid_pooling", encoder_lr),
            ("decode_feature", head_lr),
            ("aux", head_lr),
        )
    ]
    optimizer = torch.optim.SGD(
        param_groups,
        momentum=0.9,
        weight_decay=0.0001,
    )

    def _lambda_epoch(epoch):
        # Polynomial ("poly") LR decay with power 0.9, standard for PSPNet.
        import math  # local import kept: module-level imports are out of view

        max_epoch = args.num_epoch
        return math.pow((1 - epoch / max_epoch), 0.9)

    scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=_lambda_epoch)

    trainer = Trainer(
        model=model,
        criterion=criterion,
        metric_func=None,
        optimizer=optimizer,
        num_epochs=args.num_epoch,
        save_period=args.save_period,
        config=dt_config,
        data_loaders_dict=data_loaders_dict,
        scheduler=scheduler,
        device=device,
        dataset_name_base=train_dataset.__name__,
        batch_multiplier=args.batch_multiplier,
        logger=tblogger,
    )

    if args.snapshot and os.path.isfile(args.snapshot):
        trainer.resume_checkpoint(args.snapshot)

    # Anomaly detection pinpoints NaNs in backward at a performance cost.
    with torch.autograd.set_detect_anomaly(True):
        trainer.train()

    tblogger.close()
コード例 #12
0
class Builder():
    """Builds and trains a semantic-segmentation model end to end.

    Wires together model construction, optimizer, losses, metrics,
    callbacks and the CULane training data generator.  Two model
    families are supported:

    * ``"PSPNet"`` -- a regular Keras model compiled with
      loss/optimizer/metrics and trained via ``fit_generator``.
    * ``"Bisenet_V2"`` -- a ``BisenetV2Model`` wrapper that owns its own
      optimizer and loss and is trained with the custom loop in
      ``demo_train``.
    """

    def __init__(self,
                 model_name="Bisenet_V2",
                 optimizer_name='Adam',
                 loss_names=(
                     'iou_categorical_crossentropy',
                     'iou_categorical_crossentropy'
                 ),
                 metrics_names=('accuracy', ),
                 training=True,
                 transfer_learning=False,
                 input_shape=(288, 800, 3),
                 n_classes=2):
        """
        -----------------------------------------------------------------------
        1. losses
            1. categorical_crossentropy
            2. iou_categorical_crossentropy
            3. categorical_crossentropy_weighted
        2. models
            1. PSPNet
            2. Bisenet_V2
        3. optimizers
            1. Adam
            2. SGD
        4. metrics
            1. accuracy
        5. lr_decay_policy
            1. poly
        -----------------------------------------------------------------------
        """
        # NOTE: the loss/metric name defaults were mutable lists; they are
        # now tuples (never mutated here, only indexed/iterated) to avoid
        # the shared-mutable-default pitfall.  Callers may still pass lists.
        self.model_name = model_name
        self.optimizer_name = optimizer_name
        self.training = training
        self.n_classes = n_classes
        self.loss_names = loss_names
        self.metrics_names = metrics_names
        self.momentum = 0.9
        # NOTE(review): 0.9 is an unusually large l2 weight-decay factor
        # (typical values are ~1e-4); only used by add_weight_decay(),
        # which is currently not called -- confirm before enabling it.
        self.weight_decay = 0.9
        self.input_shape = input_shape
        # Candidate learning rates; index [4] selects 0.00006.
        self.lr = [0.00004, 0.0001, 0.0005, 0.001, 0.00006][4]
        self.epochs = 200
        self.lr_decay_policy = 'poly'
        self.lr_decay_power = 0.9
        self.batch_size = 32
        # Candidate per-class weight pairs; index [4] selects [0.2, 1.0]
        # (background heavily down-weighted relative to lane pixels).
        self.class_weights = [[0.0025, 1.0], [0.0025, 1.0], [0.0025, 1.0],
                              [0.0025, 1.0], [0.2, 1.0]][4]  #, 1.0,1.0,1.0]
        self.backbone = 'resnet50'
        self.model_out_layer_names = None
        self.pretrained_encoder_weights = 'imagenet'
        # learning decay will be applied in each #no epochs and the shape will look like staircase
        self.lr_decay_staircase = True
        self.transfer_learning = transfer_learning
        self.train_id = 'bisnet_v2_rgb'
        self.logdir = "/home/essys/projects/segmentation/logs/" + self.train_id + "/"

        self.init()

    def init(self, ):
        """Build the model, the training plumbing and the data pipeline."""
        #        strategy = tf.distribute.MirroredStrategy()
        #        BATCH_SIZE = self.batch_size * strategy.num_replicas_in_sync
        #        with strategy.scope():
        print('init model\n\n\n\n')
        self.init_model()
        print('init optimizer\n\n\n\n')

        # Bisenet_V2 carries its own optimizer/loss inside BisenetV2Model,
        # so the Keras-style setup below is skipped for it.
        if (not self.model_name == "Bisenet_V2"):
            self.init_optimizer()
            print('init losses\n\n\n\n')
            self.init_losses()
            print('init metrices\n\n\n\n')
            self.init_metrics()
            print('init weight decay\n\n\n\n')
            #        self.add_weight_decay()
            print('init callbacks\n\n\n\n')

        project_dir = "./"

        data_dir = project_dir + "data/"

        # Pickled lists of image / trainId-label file paths produced by the
        # dataset-preparation step.
        train_img_paths = np.array(
            pickle.load(open(data_dir + "train_img_paths.pkl", 'rb')))
        train_img_paths = train_img_paths
        self.n_samples = len(train_img_paths)
        train_trainId_label_paths = np.array(
            pickle.load(open(data_dir + "train_trainId_label_paths.pkl",
                             'rb')))
        train_trainId_label_paths = train_trainId_label_paths
        #        train_existance_labels = np.array(pickle.load(open(data_dir + "train_existance_label.pkl", 'rb')))
        # Per-channel training mean, used both for normalization and to
        # de-normalize images for TensorBoard previews in demo_train().
        self.train_mean_channels = pickle.load(
            open("data/mean_channels.pkl", 'rb'))
        input_mean = self.train_mean_channels  #[103.939, 116.779, 123.68] # [0, 0, 0]
        input_std = [1, 1, 1]
        ignore_label = 255
        augmenters = []
        augmenters_val = []
        #scaler = transforms.GroupRandomScale(size=(0.595, 0.621), interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST))
        cropper = transforms.GroupRandomCropKeepBottom(cropsize=(400, 250))
        rotater = transforms.GroupRandomRotation(
            degree=(-15, 15),
            interpolation=(cv2.INTER_LINEAR, cv2.INTER_NEAREST),
            padding=(input_mean, (0, )))
        normalizer = transforms.GroupNormalize(mean=(input_mean, (0, )),
                                               std=(input_std, (1, )))
        color_augmenter = transforms.GroupColorAugment()
        eraser = transforms.GroupRandomErase(mean=input_mean)

        # add augmenters by their order for train
        #    augmenters.append(scaler)
        #augmenters.append(color_augmenter)
        augmenters.append(eraser)
        augmenters.append(rotater)
        augmenters.append(cropper)
        augmenters.append(normalizer)

        self.training_generator = DataGeneratorCulane.DataIterators(
            train_img_paths,
            train_trainId_label_paths,
            batch_size=self.batch_size,
            dim=(self.input_shape[0], self.input_shape[1]),
            n_classes=self.n_classes,
            n_channels=1,
            output_of_models=(self.model_out_layer_names
                              if self.model_name == 'PSPNet' else None),
            transform=augmenters)
        if (not self.model_name == "Bisenet_V2"):
            self.init_callbacks()
            print('compile model\n\n\n\n')

            # try:

    #        self.model = multi_gpu_model(self.model, gpus=4)
    # print("Training using multiple GPUs..")
    # except:
    # print("Training using single GPU or CPU..")
        if (self.model_name == "Bisenet_V2"):
            # weighted_loss_fn = weighted_categorical_crossentropy(self.class_weights)
            # #y_true,  y_pred, y_stm, y_ge1,y_ge2, y_ge3, num_classes, batch_size, input_size, weights,weighted_loss_fn, usesoftmax=True):
            # self.loss_bisenetv2 = IOULOSS_BISENETV2_func(self.model_base.y_true, self.model_base.logits, self.model_base.logits_stm,
            #                               self.model_base.logits_ge1,self.model_base.logits_ge2,self.model_base.logits_ge3,
            #                               self.n_classes,self.batch_size,self.input_shape,self.class_weights,
            #                               weighted_loss_fn,True)
            # self.model.add_loss(self.loss_bisenetv2)
            # self.model.compile(optimizer=self.optimizer, metrics=self.metrics)
            pass
        else:
            self.model.compile(loss=self.losses,
                               optimizer=self.optimizer,
                               metrics=self.metrics,
                               run_eagerly=True)

    def demo_train(self):
        """Run training.

        For Bisenet_V2 this is a hand-rolled loop: per step it calls
        ``network_learn``, periodically logs images/metrics to
        TensorBoard, checkpoints every 500 steps and decays the learning
        rate once per epoch.  For every other model it defers to Keras
        ``fit_generator`` with the callbacks from ``init_callbacks``.
        """
        len_steps = self.n_samples / self.batch_size  #66785/self.batch_size
        if (self.model_name == "Bisenet_V2"):
            iterator = self.training_generator.get_data_iterator()
            tensorboard_writer = tf.summary.create_file_writer(
                self.logdir + "/diagnose/", flush_millis=10000)
            color_dict = {0: (0, 0, 0), 1: (0, 255, 0)}
            colors = np.array([[0, 0, 0], [0, 255, 0]])

            # NOTE(review): epoch range starts at 22, presumably to resume
            # a previous run -- confirm before a fresh training run.
            for epoch in range(22, self.epochs):
                for i in range(int(len_steps)):

                    step = int(epoch * len_steps) + i
                    x, y = next(iterator)
                    print(np.shape(x), np.shape(y))
                    # Skip ragged final batches.
                    if (len(x) < self.batch_size): continue
                    y_pred, loss = self.model.network_learn(x, y)

                    print("Step {}, Loss: {}".format(
                        self.model.train_op.iterations.numpy(), loss.numpy()))
                    if (i % 20 == 0):
                        m_iou = mean_iou(y, y_pred).numpy()

                        y_pred = np.argmax(y_pred, axis=-1)
                        _y = np.argmax(y, axis=-1)

                        # Pixel accuracy over the whole batch.
                        accuracy = backend.sum(
                            backend.cast(backend.equal(_y, y_pred),
                                         tf.float32)) / (self.batch_size *
                                                         self.input_shape[0] *
                                                         self.input_shape[1])

                        # y = y1['softmax_out']
                        y_imgs = []
                        y_imgs_pred = []
                        x_imgs = []
                        #        n_classes = np.shape(y)[-1]
                        # print(np.unique(np.argmax(y, -1)), np.unique(y_pred1),"\n\n\n\n\n")
                        y_pred = DataGeneratorCulane.gray_to_onehot_all(
                            y_pred, color_dict)

                        # Colorize ground truth and predictions for
                        # TensorBoard.  Loop variable renamed from ``i`` so
                        # it no longer shadows the outer step counter.
                        for j in range(len(y)):
                            y_img = np.resize(
                                np.dot(np.reshape(y[j], (-1, self.n_classes)),
                                       colors), self.input_shape)
                            y_img_pred = np.resize(
                                np.dot(
                                    np.reshape(y_pred[j],
                                               (-1, self.n_classes)), colors),
                                self.input_shape)
                            y_imgs.append(y_img)
                            y_imgs_pred.append(y_img_pred)
                            # De-normalize input back to displayable pixels.
                            x_imgs.append(
                                (x[j] +
                                 self.train_mean_channels).astype('uint8'))

                        y_imgs = np.array(y_imgs, dtype=np.uint8)
                        x_imgs = np.array(x_imgs)
                        y_imgs_pred = np.array(y_imgs_pred)

                        with tensorboard_writer.as_default():
                            is_written = tf.summary.image("img",
                                                          x_imgs,
                                                          step=step)
                            is_written = tf.summary.image("train/gts",
                                                          y_imgs,
                                                          step=step)
                            is_written = tf.summary.image("train/predictions1",
                                                          y_imgs_pred,
                                                          step=step)
                            tf.summary.scalar("miou", m_iou, step=step)
                            tf.summary.scalar(
                                "learning_rate",
                                self.model.train_op.learning_rate.numpy(),
                                step=step)
                            tf.summary.scalar("loss", loss.numpy(), step=step)
                            tf.summary.scalar("accuracy",
                                              accuracy.numpy(),
                                              step=step)
                            if (is_written):
                                print(' image has written to the tensorboard')
                        tensorboard_writer.flush()
                    if ((step + 1) % 500 == 0):
                        checkpoint_path = '/home/essys/projects/segmentation/checkpoints/' + self.train_id + "/" + "cp-epoch-{}-step-{}.ckpt".format(
                            epoch, step)
                        print(checkpoint_path)
                        self.model.model.save_weights(checkpoint_path)
                    else:
                        print(step)
                # NOTE(review): this decay uses learning_rate/epochs inside
                # the base term; a poly schedule would normally use
                # epoch/epochs (cf. ``schedule`` in init_callbacks) --
                # confirm the formula is intentional.
                new_learning_rate = self.model.train_op.learning_rate.numpy(
                ) * (1 - (self.model.train_op.learning_rate.numpy() /
                          self.epochs))**2800

                backend.set_value(self.model.train_op.learning_rate,
                                  new_learning_rate)

            tensorboard_writer.close()
        else:
            self.model.fit_generator(self.training_generator,
                                     steps_per_epoch=len_steps,
                                     epochs=self.epochs,
                                     callbacks=self.callbacks)

    def init_callbacks(self, ):
        """Assemble ``self.callbacks``: poly LR schedule, TensorBoard,
        periodic checkpointing and the image-diagnosis callback."""
        scheduler = None
        tensorboard_callback = None
        self.callbacks = []
        if (self.lr_decay_policy == 'poly'):
            scheduler = schedule(
                self.lr, self.lr_decay_power, self.epochs
            )  #PolynomialDecay(maxEpochs=100, initAlpha=0.01, power=self.lr_decay_power)
        # initializing tensorboard
#        shutil.rmtree(self.logdir+"/train/")
        os.makedirs(self.logdir, exist_ok=True)
        tensorboard_callback = TensorBoard(log_dir=self.logdir)

        if (scheduler):
            print("learning rate scheduler added\n\n")
            self.callbacks.append(LearningRateScheduler(scheduler))
        if (tensorboard_callback):
            print("tensor boarda callback added  with folder " + self.logdir +
                  "\n\n")
            self.callbacks.append(tensorboard_callback)
        # Create a callback that saves the model's weights every 5 epochs
        # NOTE(review): ``period`` is deprecated in newer tf.keras in favor
        # of ``save_freq``; kept for compatibility with this codebase.
        cp_callback = tf.keras.callbacks.ModelCheckpoint(
            filepath="/home/essys/projects/segmentation/checkpoints/" +
            self.train_id + "/cp-{epoch:04d}.ckpt",
            verbose=0,
            save_weights_only=True,
            period=50)
        self.callbacks.append(cp_callback)
        diagnoser = ModelDiagonoser(self.training_generator, self.batch_size,
                                    self.n_samples, self.logdir,
                                    self.input_shape, self.n_classes)
        self.callbacks.append(diagnoser)


#        image_callback = TensorBoardImage('tag', self.logdir, self.model.outp)

    def init_metrics(self):
        """Translate ``self.metrics_names`` into Keras metric objects."""
        assert not len(self.metrics_names) == 0
        self.metrics = []
        for i in range(len(self.metrics_names)):
            if (self.metrics_names[i] == 'accuracy'):
                self.metrics.append('accuracy')
                # m = tf.keras.metrics.MeanIoU(num_classes=5)
                # self.metrics.append(m)

    def init_losses(self, use_output_layer_names=False):
        """Build ``self.losses``.

        With ``use_output_layer_names`` a dict keyed by the model's output
        layer names is produced (one loss per output); otherwise a single
        loss from the first entry of ``self.loss_names`` is used.
        """
        # if(self.model_out_layer_names is None):
        #     assert not len(self.loss_names) == 0, " number of loss functions cannot be zero"
        # else:
        #     # print(self.model_out_layer_names, self.loss_names)
        #     assert len(self.loss_names) == len(self.model_out_layer_names), "in " + self.model_name + \
        #         " number of loss fs should be equal to number of output layers of the model"
        #        if(len(self.loss_names) > 1):
        if (use_output_layer_names):
            self.losses = {}
            for i in range(len(self.model_out_layer_names)):
                self.losses[self.model_out_layer_names[i]] = self.get_loss(
                    self.loss_names[i])
        else:
            self.losses = self.get_loss(self.loss_names[0])

    def get_loss(self, name):
        """Map a loss name to a Keras-compatible loss (string or callable).

        Returns None for an unrecognized name.
        """
        if (name == 'categorical_crossentropy'):
            return 'categorical_crossentropy'
        elif (name == 'iou_categorical_crossentropy'):
            return IOULOSS(self.n_classes, self.batch_size, self.input_shape,
                           self.class_weights)
        elif (name == 'categorical_crossentropy_weighted'):
            return categorical_crossentropy_weighted(self.class_weights)

    def init_model(self, ):
        """Instantiate ``self.model`` for the configured architecture and
        best-effort restore the latest saved weights."""
        if (self.model_name == "PSPNet"):
            self.model, self.model_out_layer_names = PSPNet(
                backbone_name=self.backbone,
                input_shape=self.input_shape,
                classes=self.n_classes,
                encoder_weights=self.pretrained_encoder_weights,
                encoder_freeze=self.transfer_learning,
                training=self.training)
            try:
                # The model weights (that are considered the best) are loaded into the model.
                self.model.load_weights(
                    '/home/essys/projects/segmentation/checkpoints/' +
                    self.train_id + "/")
            # Was a bare ``except:``; narrowed so KeyboardInterrupt /
            # SystemExit are no longer swallowed.
            except Exception:
                print('could not find saved model')
        if (self.model_name == "Bisenet_V2"):
            self.model = BisenetV2Model(
                train_op=optimizers.Adam(self.lr),
                input_shape=self.input_shape,
                classes=self.n_classes,
                batch_size=self.batch_size,
                class_weights=self.class_weights
            )  #BisenetV2(input_shape=self.input_shape,classes=self.n_classes)
            try:
                checkpoint_dir = '/home/essys/projects/segmentation/checkpoints/' + self.train_id + "/"
                latest = tf.train.latest_checkpoint(checkpoint_dir)
                # latest_checkpoint returns None when no checkpoint exists;
                # previously ``None + str`` raised a TypeError that the bare
                # except silently hid.  Handle the missing case explicitly.
                if latest is None:
                    print('could not find saved model')
                else:
                    print(latest + " is found model\n\n")
                    # The model weights (that are considered the best) are loaded into the model.
                    self.model.model.load_weights(latest)
            except Exception:
                print('could not find saved model')

    def add_weight_decay(self, ):
        """Attach l2 regularization losses to conv/dense kernels and biases.

        Currently unused (call site in ``init`` is commented out).
        """
        # https://jricheimer.github.io/keras/2019/02/06/keras-hack-1/
        for layer in self.model.layers:
            if isinstance(layer, keras.layers.DepthwiseConv2D):
                layer.add_loss(
                    keras.regularizers.l2(self.weight_decay)(
                        layer.depthwise_kernel))
            elif isinstance(layer, keras.layers.Conv2D) or isinstance(
                    layer, keras.layers.Dense):
                layer.add_loss(
                    keras.regularizers.l2(self.weight_decay)(layer.kernel))
            if hasattr(layer, 'bias_regularizer') and layer.use_bias:
                layer.add_loss(
                    keras.regularizers.l2(self.weight_decay)(layer.bias))

    def init_optimizer(self, ):
        """Create ``self.optimizer`` from ``self.optimizer_name`` (Adam or SGD)."""
        if (self.optimizer_name == "Adam"):
            self.optimizer = optimizers.Adam(lr=self.lr)
        if (self.optimizer_name == "SGD"):
            self.optimizer = optimizers.SGD(self.lr, momentum=self.momentum)
コード例 #13
0
def main():
    """Train PSPNet on CityScapes, optionally resuming from a snapshot.

    Builds the network, the train/val data pipelines and a four-group SGD
    optimizer (2x learning rate and no weight decay for biases; separate
    rates for the newly-added head vs. pretrained backbone parameters),
    then runs the epoch loop of train()/validate().
    """
    net = PSPNet(num_classes=num_classes,
                 input_size=train_args['input_size']).cuda()
    if len(train_args['snapshot']) == 0:
        curr_epoch = 0
    else:
        # Fixed: was a Python 2 ``print`` statement, a SyntaxError on
        # Python 3.
        print('training resumes from ' + train_args['snapshot'])
        net.load_state_dict(
            torch.load(
                os.path.join(ckpt_path, exp_name, train_args['snapshot'])))
        # Snapshot filenames encode training state as underscore-separated
        # fields: index 1 = epoch, 3 = best val loss, 6 = mean IU.
        split_snapshot = train_args['snapshot'].split('_')
        curr_epoch = int(split_snapshot[1])
        train_record['best_val_loss'] = float(split_snapshot[3])
        train_record['corr_mean_iu'] = float(split_snapshot[6])
        train_record['corr_epoch'] = curr_epoch

    net.train()

    mean_std = ([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
    # Scale so the random/center crop of input_size covers 87.5% of the image.
    train_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.RandomCrop(train_args['input_size']),
        simul_transforms.RandomHorizontallyFlip()
    ])
    val_simul_transform = simul_transforms.Compose([
        simul_transforms.Scale(int(train_args['input_size'][0] / 0.875)),
        simul_transforms.CenterCrop(train_args['input_size'])
    ])
    img_transform = standard_transforms.Compose([
        standard_transforms.ToTensor(),
        standard_transforms.Normalize(*mean_std)
    ])
    # Remap the ignored label onto the last class index.
    target_transform = standard_transforms.Compose([
        expanded_transforms.MaskToTensor(),
        expanded_transforms.ChangeLabel(ignored_label, num_classes - 1)
    ])
    restore_transform = standard_transforms.Compose([
        expanded_transforms.DeNormalize(*mean_std),
        standard_transforms.ToPILImage()
    ])

    train_set = CityScapes('train',
                           simul_transform=train_simul_transform,
                           transform=img_transform,
                           target_transform=target_transform)
    train_loader = DataLoader(train_set,
                              batch_size=train_args['batch_size'],
                              num_workers=16,
                              shuffle=True)
    val_set = CityScapes('val',
                         simul_transform=val_simul_transform,
                         transform=img_transform,
                         target_transform=target_transform)
    val_loader = DataLoader(val_set,
                            batch_size=val_args['batch_size'],
                            num_workers=16,
                            shuffle=False)

    # Zero weight on the last class so the remapped ignored label does not
    # contribute to the loss.
    weight = torch.ones(num_classes)
    weight[num_classes - 1] = 0
    criterion = CrossEntropyLoss2d(weight).cuda()

    # don't use weight_decay for bias
    optimizer = optim.SGD([{
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] == 'bias' and (
                'ppm' in name or 'final' in name or 'aux_logits' in name)
        ],
        'lr':
        2 * train_args['new_lr']
    }, {
        'params': [
            param for name, param in net.named_parameters()
            if name[-4:] != 'bias' and (
                'ppm' in name or 'final' in name or 'aux_logits' in name)
        ],
        'lr':
        train_args['new_lr'],
        'weight_decay':
        train_args['weight_decay']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] == 'bias'
            and not ('ppm' in name or 'final' in name or 'aux_logits' in name)
        ],
        'lr':
        2 * train_args['pretrained_lr']
    }, {
        'params': [
            param
            for name, param in net.named_parameters() if name[-4:] != 'bias'
            and not ('ppm' in name or 'final' in name or 'aux_logits' in name)
        ],
        'lr':
        train_args['pretrained_lr'],
        'weight_decay':
        train_args['weight_decay']
    }],
                          momentum=0.9,
                          nesterov=True)

    if len(train_args['snapshot']) > 0:
        # Restore optimizer state, then re-apply the configured learning
        # rates (the saved state may carry stale, decayed values).
        optimizer.load_state_dict(
            torch.load(os.path.join(ckpt_path,
                                    'opt_' + train_args['snapshot'])))
        optimizer.param_groups[0]['lr'] = 2 * train_args['new_lr']
        optimizer.param_groups[1]['lr'] = train_args['new_lr']
        optimizer.param_groups[2]['lr'] = 2 * train_args['pretrained_lr']
        optimizer.param_groups[3]['lr'] = train_args['pretrained_lr']

    if not os.path.exists(ckpt_path):
        os.mkdir(ckpt_path)
    if not os.path.exists(os.path.join(ckpt_path, exp_name)):
        os.mkdir(os.path.join(ckpt_path, exp_name))

    for epoch in range(curr_epoch, train_args['epoch_num']):
        train(train_loader, net, criterion, optimizer, epoch)
        validate(val_loader, net, criterion, optimizer, epoch,
                 restore_transform)
コード例 #14
0
def main():
    """Smoke-test PSPNet: push one random 475x475 RGB image through the
    12-class network and print the shapes of the main and auxiliary outputs."""
    model = PSPNet(num_classes=12)
    # Renamed from ``input``, which shadowed the builtin of the same name.
    dummy_input = torch.randn(1, 3, 475, 475)
    output, output_aux = model(dummy_input)
    print(output.shape)
    print(output_aux.shape)
コード例 #15
0
ファイル: predict.py プロジェクト: leigaoyi/DR_lesion
# Select which lesion class to predict; ``class_list`` is defined earlier
# in this file (not visible in this chunk).
class_name = class_list[0]

print('Predict {0} with epoch {1}'.format(model_name, step_num))
# Prediction outputs are written under ./result/<model>/<class>/.
result_dir = './result/{0}/{1}/'.format(model_name, class_name)
if not os.path.exists(result_dir):
    os.makedirs(result_dir)

# Instantiate the architecture selected by ``model_name``; every variant is
# built with 2 output classes (presumably lesion vs. background -- confirm).
# NOTE(review): if ``model_name`` matches none of these, ``model`` stays
# undefined and later use raises NameError -- verify upstream validation.
if model_name == 'UNet':
    model = UNet(2, in_channels=3)

if model_name == 'FCN8':
    model = FCN8(2)

if model_name == 'PSPNet':
    model = PSPNet(2)

if model_name == 'UperNet':
    model = UperNet(2)

if model_name == 'CC_UNet':
    model = CC_UNet(2)

if model_name == 'A_UNet':
    model = A_UNet(2)

device = torch.device('cuda:0')

# Load the checkpoint saved as ./checkpoints/<model>_<step>.pth.
state_dict = torch.load('./checkpoints/{0}_{1}.pth'.format(
    model_name, step_num))