Beispiel #1
0
    def __init__(self, bs, size=256):
        """Load the pretrained UNet11 onto the best device and run one warm-up pass.

        Args:
            bs: batch size of the random warm-up input.
            size: height and width of the random warm-up input (default 256).
        """
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.net = unet11('unet_celeba.pth', pretrained=True).to(self.device)
        self.net.eval()

        # One forward pass on random data of shape (bs, 3, size, size);
        # presumably done to pay lazy CUDA/cuDNN initialisation cost up front
        # -- TODO confirm. The output is discarded, so run under no_grad to
        # avoid recording an autograd graph and holding activation memory.
        # (The old `Variable` wrapper is a deprecated no-op and was dropped.)
        warmup = torch.rand(bs, 3, size, size).to(self.device)
        with torch.no_grad():
            self.net(warmup)
        print('___init___')
Beispiel #2
0
def get_unet_model():
    """Return the Carvana-pretrained UNet11 in eval mode.

    The model is moved to ``cuda:0`` when CUDA is available, otherwise it is
    placed on the CPU.
    """
    target = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    net = unet11(pretrained='carvana')
    net.eval()
    return net.to(target)
Beispiel #3
0
def main():
    """Train UNet11 on the plaque segmentation dataset.

    Command-line options:
        -g/--gpu     GPU index exported via CUDA_VISIBLE_DEVICES (required).
        -c/--config  key into ``configurations`` selecting the hyper-parameters
                     (lr, momentum, weight_decay, max_iteration, ...); default 1.
        --resume     checkpoint path; restores model/optimizer state and the
                     epoch/iteration counters.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('-g', '--gpu', type=int, required=True)
    parser.add_argument('-c',
                        '--config',
                        type=int,
                        default=1,
                        choices=configurations.keys())
    parser.add_argument('--resume', help='Checkpoint path')
    args = parser.parse_args()

    cfg = configurations[args.config]
    out = get_log_dir('unet11', args.config, cfg)
    resume = args.resume

    # Restrict CUDA to the requested GPU; this must happen before CUDA is
    # initialised, hence before the is_available() call below.
    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu)
    cuda = torch.cuda.is_available()

    # Fixed seeds so runs are reproducible.
    torch.manual_seed(1337)
    if cuda:
        torch.cuda.manual_seed(1337)

    # 1. dataset

    root = osp.expanduser('~/data/datasets')
    kwargs = {'num_workers': 4, 'pin_memory': True} if cuda else {}
    train_loader = torch.utils.data.DataLoader(
        plaque.Plaqueseg(root, split='train', transform=True),
        batch_size=1,
        shuffle=True,
        **kwargs)
    val_loader = torch.utils.data.DataLoader(
        plaque.Plaqueseg(root, split='val', transform=True),
        batch_size=1,
        shuffle=False,
        **kwargs)

    # 2. model

    checkpoint = None
    start_epoch = 0
    start_iteration = 0
    if resume:
        model = unet_models.unet11(pretrained=False)
        # map_location='cpu' lets a checkpoint written on a GPU be resumed on a
        # CPU-only machine; the model is moved to the GPU below when available.
        checkpoint = torch.load(resume, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        start_epoch = checkpoint['epoch']
        start_iteration = checkpoint['iteration']
    else:
        # Fresh run: build only the Carvana-pretrained model (previously an
        # extra randomly-initialised model was built and then discarded).
        model = unet_models.unet11(pretrained='carvana')
    if cuda:
        model = model.cuda()

    # 3. optimizer

    optim = torch.optim.SGD(model.parameters(),
                            lr=cfg['lr'],
                            momentum=cfg['momentum'],
                            weight_decay=cfg['weight_decay'])
    if checkpoint is not None:
        # Optimizer.load_state_dict casts the loaded state onto the device of
        # the (already moved) parameters, so the CPU-mapped checkpoint is fine.
        optim.load_state_dict(checkpoint['optim_state_dict'])

    trainer = unet_trainer.Trainer(
        cuda=cuda,
        model=model,
        optimizer=optim,
        train_loader=train_loader,
        val_loader=val_loader,
        out=out,
        max_iter=cfg['max_iteration'],
        interval_validate=cfg.get('interval_validate', len(train_loader)),
    )
    trainer.epoch = start_epoch
    trainer.iteration = start_iteration
    trainer.train()
Beispiel #4
0
def get_model():
    """Return the Carvana-pretrained UNet11 in eval mode, moved to ``device``.

    ``device`` is a module-level name not defined in this block; presumably a
    ``torch.device`` or device string -- confirm against the module.
    """
    # nn.Module.eval() returns the module itself, so the calls chain.
    return unet11(pretrained='carvana').eval().to(device)
Beispiel #5
0
def get_model():
    """Return the Carvana-pretrained UNet11 in eval mode.

    The model is moved to the GPU when CUDA is available. Otherwise it stays on
    the CPU, instead of ``.cuda()`` raising on CPU-only machines.
    """
    import torch  # local import: the enclosing module's imports are not visible

    model = unet11(pretrained='carvana')
    model.eval()
    if torch.cuda.is_available():
        return model.cuda()
    return model
Beispiel #6
0
    return nn.Sequential(*features2)


=======
    features2_net = nn.Sequential(*features2)

    if opt.load_depth_path:
        features2_net.load_state_dict(t.load(opt.load_depth_path))
        print('==> load pretrained depth model from %s' % opt.load_depth_path)




    return features2_net

# NOTE(review): module-level statement -- this builds a UNet11 with
# pretrained='vgg' as soon as the module is loaded. It also sits inside an
# unresolved merge conflict (between the `=======` and `>>>>>>>` markers), so
# resolve the conflict and decide whether this line should exist at all.
model = unet11(pretrained='vgg')

>>>>>>> b43e1a358b5853ffb749ac931c9cd97a6dccf862

class decom_vgg16_2stream(nn.Module):
    def __init__(self):
        # n_class includes the background
        super(decom_vgg16_2stream, self).__init__()
<<<<<<< HEAD
        self.extractor  = decom_vgg16()
        self.extractor2 = decom_vgg16_depth()
        self.NIN = nn.Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1))
    def forward(self, x,x2):
        x1_cnn1 = self.extractor(x)
        x2_cnn2 = self.extractor2(x2)       
        x_concat=t.cat((x1_cnn1,x2_cnn2),1)