コード例 #1
0
ファイル: train.py プロジェクト: huziling/fcn_pytorch
                                         pin_memory=True)

# test_data = data_loader.ClassSeg(root=data_path, split='test', transform=True)
# test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False, num_workers=workNmber)

print('load model....')

if model_Type == 0:
    print("Using FCN32s")
    vgg_model = models.VGGNet(model='vgg16', pretrained=True)
    fcn_model = models.FCN32s(pretrained_net=vgg_model, n_class=n_class)
    test_model = models.FCN32s(pretrained_net=vgg_model, n_class=n_class)
elif model_Type == 1:
    print("Using FCN16s")
    vgg_model = models.VGGNet(model='vgg16', pretrained=True)
    fcn_model = models.FCN16s(pretrained_net=vgg_model, n_class=n_class)
    test_model = models.FCN16s(pretrained_net=vgg_model, n_class=n_class)
elif model_Type == 2:
    print("Using FCN8s")
    vgg_model = models.VGGNet(model='vgg16', pretrained=True)
    fcn_model = models.FCN8s(pretrained_net=vgg_model, n_class=n_class)
    test_model = models.FCN8s(pretrained_net=vgg_model, n_class=n_class)
elif model_Type == 3:
    print("Using FCN1s")
    vgg_model = models.VGGNet(model='vgg16', pretrained=True)
    fcn_model = models.FCNss(pretrained_net=vgg_model,
                             n_class=n_class,
                             Time=False,
                             Space=False)
    test_model = models.FCNss(pretrained_net=vgg_model,
                              n_class=n_class,
コード例 #2
0
def load_img(img_path):
    """Read an image file into a ``uint8`` NumPy array.

    The image is opened in a ``with`` block so the underlying file handle is
    released as soon as the pixel data has been copied into the array,
    instead of staying open until the PIL object is garbage-collected.

    Args:
        img_path: Path of the image file, passed to ``PIL.Image.open``.

    Returns:
        A ``numpy.ndarray`` of dtype ``uint8`` holding the decoded pixels.
    """
    with PIL.Image.open(img_path) as pil_img:
        # np.array() forces the (lazy) decode while the file is still open.
        img = np.array(pil_img, dtype=np.uint8)

    return img

def transform_img(img):
    """Convert an image array into a mean-subtracted float tensor.

    The input is left untouched; all work happens on a float32 copy.

    Args:
        img: NumPy image array. The channel flip and the 3-value mean imply
            an H x W x 3 array in RGB channel order.

    Returns:
        A ``torch`` float32 tensor laid out as C x H x W, with channels in
        BGR order and the per-channel BGR mean subtracted.
    """
    mean_bgr = np.array([104.00699, 116.66877, 122.67892])

    # Flip the channel axis (RGB -> BGR); astype() makes a fresh float32 copy,
    # so the caller's array is never modified by the in-place subtraction.
    bgr = img[:, :, ::-1].astype(np.float32)
    bgr -= mean_bgr

    # H x W x C -> C x H x W
    chw = np.transpose(bgr, (2, 0, 1))
    return torch.from_numpy(chw).float()

# Run one image through a pretrained FCN and save the prediction overlaid on
# the input image.
img_path = os.path.join(img_dir, img_name)
img = load_img(img_path)
# Indexing with np.newaxis prepends a batch axis to the C x H x W tensor.
img_transformed = transform_img(img)[np.newaxis]
model = models.FCN32s() if fcn_type == 'fcn32' else models.FCN16s()
model_weight = torch.load(pretrained_model)
model.load_state_dict(model_weight)
# Inference only: put the module in eval mode so that any train-time-only
# layers (e.g. dropout) behave deterministically.
model.eval()
with torch.no_grad():
    img_transformed = Variable(img_transformed)
    # The forward pass must run inside no_grad; previously only the Variable
    # wrapping was inside the block, so autograd still tracked the inference.
    score = model(img_transformed)
# Apply expit to the raw scores, scale to 0-255, and keep channel 0 of the
# first (only) batch item as the 2-D prediction map.
lbl_pred = (expit(score.data.cpu().numpy()) * 255).astype(np.uint8)[0][0]
save_path = os.path.join(save_dir, 'test_' + img_name)
utils.overlay_imp_on_img(img, lbl_pred, save_path, colormap='jet')
コード例 #3
0
def main():
    """Command-line entry point: train or evaluate an FCN32s/FCN16s model.

    Steps, in order:
      1. Parse the command-line options and create the output directories.
      2. Select the GPU, seed the RNGs, and build the train/validation
         data loaders for the chosen dataset.
      3. Build the model and load weights, either from a ``--resume``
         checkpoint or from the ``--pretrained_model`` file.
      4. Build the SGD optimizer and the step learning-rate scheduler.
      5. Hand everything to ``Trainer`` and run ``train()`` or, with
         ``--eval_only``, ``validate()``.

    Raises:
        ValueError: If the ``--resume`` checkpoint was saved for a different
            architecture than the one selected by ``--fcn_type``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--dataset', type=str, default='gdi', help='name of dataset, gdi or massvis (default: gdi)')
    parser.add_argument('--dataset_dir', type=str, default='/path/to/datase_dirt', help='dataset directory')
    # Help text previously claimed "(default: gdi)", which is the dataset
    # default, not the FCN type default.
    parser.add_argument('--fcn_type', type=str, help='FCN type, fcn32 or fcn16 (default: fcn32)', default='fcn32',
                        choices=['fcn32', 'fcn16'])
    parser.add_argument('--overlaid_img_dir', type=str, default='/path/to/overlaid_img_dir',
                        help='output directory path for images with heatpmap overlaid onto input images')
    parser.add_argument('--pretrained_model', type=str, default='/path/to/pretrained_model',
                        help='pretrained model converted from Caffe models')
    parser.add_argument('--config', type=int, default=1, choices=configurations.keys(),
                        help='configuration for training where several hyperparameters are defined')
    parser.add_argument('--log_file', type=str, default='F:/dataset/visimportance/log', help='/path/to/log_file')
    parser.add_argument('--resume', type=str, default='',
                        help='checkpoint file to be loaded when retraining models')
    parser.add_argument('--checkpoint_dir', type=str, default='/path/to/checkpoint_dir',
                        help='checkpoint file to be saved in each epoch')
    parser.add_argument('--eval_only', action='store_true', help='evaluation only')
    parser.add_argument('--gpu', type=int, default=0, help='GPU id (default: 0)')
    args = parser.parse_args()

    utils.create_dir(os.path.join(args.overlaid_img_dir, "train"))
    utils.create_dir(os.path.join(args.overlaid_img_dir, "valid"))
    if not args.eval_only:
        utils.create_dir(args.checkpoint_dir)
    print(args)

    gpu = args.gpu
    cfg = configurations[args.config]
    log_file = args.log_file
    resume = args.resume

    # Restrict CUDA to the requested device before querying availability.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu)
    cuda = torch.cuda.is_available()
    args.cuda = cuda
    if args.cuda:
        print("torch.backends.cudnn.version: {}".format(torch.backends.cudnn.version()))

    # Fixed seed for reproducible runs.
    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset

    root = os.path.expanduser(args.dataset_dir)
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    train_loader = None  # stays None in --eval_only mode
    if not args.eval_only:  # training + validation
        if args.dataset == 'gdi':
            dt = GDI(root, image_dir="gd_train", imp_dir="gd_imp_train", split='train', transform=True)
        else:
            dt = Massvis(root, image_dir="train", imp_dir="train_imp", split='train', transform=True)
        train_loader = torch.utils.data.DataLoader(dt, batch_size=1, shuffle=True, **kwargs)
        print("no of images in training", len(train_loader))

    if args.dataset == 'gdi':  # validation
        dv = GDI(root, image_dir="gd_val", imp_dir="gd_imp_val", split='valid', transform=True)
    else:
        dv = Massvis(root, image_dir="valid", imp_dir="valid_imp", split='valid', transform=True)
    val_loader = torch.utils.data.DataLoader(dv, batch_size=1, shuffle=False, **kwargs)
    print("no of images in evaluation", len(val_loader))

    # 2. model

    model = models.FCN32s() if args.fcn_type == 'fcn32' else models.FCN16s()

    start_epoch = 0
    start_iteration = 0
    if resume:
        checkpoint = torch.load(resume)
        # Validate the architecture *before* loading the weights, with an
        # explicit exception (asserts are stripped under ``python -O``), so a
        # mismatched checkpoint is reported clearly.
        expected_arch = 'FCN32s' if args.fcn_type == 'fcn32' else 'FCN16s'
        if checkpoint['arch'] != expected_arch:
            raise ValueError(
                "checkpoint {!r} was saved for arch {!r}, but --fcn_type={} expects {!r}".format(
                    resume, checkpoint['arch'], args.fcn_type, expected_arch))
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    else:
        # argparse ``choices`` limits fcn_type to 'fcn32'/'fcn16', so this is
        # the only reachable non-resume path (the former fallback branch that
        # copied parameters from an FCN32s checkpoint could never execute).
        model_weight = torch.load(args.pretrained_model)
        model.load_state_dict(model_weight)
        if not args.eval_only:
            # NOTE(review): presumably re-initialises only the layers that are
            # trained from scratch -- confirm it does not overwrite the
            # pretrained weights loaded just above.
            model._initialize_weights()

    if cuda:
        model = model.cuda()

    # 3. optimizer

    # Two parameter groups: the bias group uses twice the base learning rate
    # and no weight decay.
    optim = torch.optim.SGD(
        [
            {'params': get_parameters(model, bias=False, fcn_type=args.fcn_type)},
            {'params': get_parameters(model, bias=True, fcn_type=args.fcn_type), 'lr': cfg['lr'] * 2, 'weight_decay': 0},
        ],
        lr=cfg['lr'],
        momentum=cfg['momentum'],
        weight_decay=cfg['weight_decay'])
    if resume:
        optim.load_state_dict(checkpoint['optim_state_dict'])

    # lr_policy: step
    # When resuming, the scheduler continues from the saved iteration count.
    last_epoch = start_iteration if resume else -1
    lr_scheduler = torch.optim.lr_scheduler.StepLR(optim,  cfg['step_size'], gamma=cfg['gamma'], last_epoch=last_epoch)

    # train_loader is None in eval-only mode, so len() is only taken when
    # training.
    if args.eval_only:
        interval_validate = 0
    else:
        interval_validate = cfg.get('interval_validate', len(train_loader))

    trainer = Trainer(
        cuda=cuda,
        model=model,
        optimizer=optim,
        lr_scheduler=lr_scheduler,
        train_loader=train_loader,
        val_loader=val_loader,
        checkpoint_dir=args.checkpoint_dir,
        log_file=log_file,
        max_iter=cfg['max_iteration'],
        iter_size=cfg['iter_size'],
        interval_validate=interval_validate,
        overlaid_img_dir=args.overlaid_img_dir,
        dataset=args.dataset,
        eval_only=args.eval_only,
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    if not args.eval_only:
        trainer.train()
    else:
        trainer.validate()
コード例 #4
0
def main():
    """Run a trained FCN variant over the 'val' split and save figures.

    The network architecture is chosen by the module-level ``model_Type``
    (0-5), its weights are loaded from ``./models/temp.pth``, and every
    sample of the dataset is pushed through the network. For each sample a
    figure is written under ``./image/``.

    Raises:
        ValueError: If ``model_Type`` is not one of the supported values.
    """
    use_cuda = torch.cuda.is_available()
    # NOTE(review): the dataset below is given ``data_path`` rather than this
    # expanded path -- confirm which one is intended.
    path = os.path.expanduser(data_path)

    dataset = data_loader.ClassSeg(root=data_path, split='val', transform=True)

    if model_Type == 0:
        print("Using FCN32s")
        vgg_model = models.VGGNet(model='vgg16', pretrained=True)
        fcn_model = models.FCN32s(pretrained_net=vgg_model, n_class=n_class)

    elif model_Type == 1:
        print("Using FCN16s")
        vgg_model = models.VGGNet(model='vgg16', pretrained=True)
        fcn_model = models.FCN16s(pretrained_net=vgg_model, n_class=n_class)

    elif model_Type == 2:
        print("Using FCN8s")
        vgg_model = models.VGGNet(model='vgg16', pretrained=True)
        fcn_model = models.FCN8s(pretrained_net=vgg_model, n_class=n_class)

    elif model_Type == 3:
        print("Using FCN1s")
        vgg_model = models.VGGNet(model='vgg16', pretrained=False)
        fcn_model = models.FCNss(pretrained_net=vgg_model,
                                 n_class=n_class,
                                 Time=False,
                                 Space=False)

    elif model_Type == 4:
        print("Using FCNs")
        vgg_model = models.VGGNet(model='vgg_self', pretrained=True)
        fcn_model = models.FCNs(pretrained_net=vgg_model, n_class=n_class)

    elif model_Type == 5:
        # NOTE(review): prints the same label as model_Type 4 although it
        # builds FCNss -- confirm whether the message is a copy-paste slip.
        print("Using FCNs")
        vgg_model = models.VGGNet(model='vgg16', pretrained=True)
        fcn_model = models.FCNss(pretrained_net=vgg_model, n_class=n_class)

    else:
        # Fail with a clear message instead of a NameError on ``fcn_model``.
        raise ValueError("unsupported model_Type: %r" % (model_Type,))

    # fcn_model.load_state_dict(torch.load('models/model50.pth'))
    fcn_model.load_state_dict(torch.load('./models/temp.pth'))

    fcn_model.eval()

    if use_cuda:
        fcn_model.cuda()

    for i in range(len(dataset)):
        idx = i
        img, label, Image_Path = dataset[idx]
        print("deal %s" % Image_Path[-14:])

        labelImage = tools.labelToimg(label)

        if use_cuda:
            img = img.cuda()
        # Add the batch dimension on every device; previously this line was
        # nested under ``if use_cuda``, so the CPU path fed the network an
        # unbatched, unwrapped tensor.
        img = Variable(img.unsqueeze(0))
        out = fcn_model(img)  # (1, 21, 320, 320)

        # Arg-max over the class axis gives the predicted label per pixel.
        net_out = out.data.max(1)[1].squeeze_(0)  # 320, 320
        if use_cuda:
            net_out = net_out.cpu()

        # outImage = tools.labelToimg(net_out)  # convert the network output to an image
        # Post-processing
        data = net_out.numpy()

        outImage = tools.labelToimg(torch.from_numpy(data))  # convert the network output to an image
        plt.imshow(labelImage)
        plt.axis('off')
        plt.savefig('./image/%s_P.jpg' % Image_Path[-14:-4],
                    bbox_inches='tight')