Example #1
models = {
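    # Each value is a lambda so only the selected backend's network is built.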
    'resnet101':
    lambda: PSPNet(sizes=(1, 2, 3, 6),
                   psp_size=2048,
                   deep_features_size=1024,
                   backend='resnet101'),
    'resnet152':
    lambda: PSPNet(sizes=(1, 2, 3, 6),
                   psp_size=2048,
                   deep_features_size=1024,
                   backend='resnet152')
}

# Create dataloader
dataset = PanoDataset(root_dir=args.root_dir,
                      cat_list=[*args.input_cat, 'edge', 'cor'],
                      flip=False,
                      rotate=False,
                      gamma=False,
                      return_filenames=True)

# Prepare model
backend = args.backend.lower()
net = models[backend]()
net = nn.DataParallel(net).to(device)
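# The checkpoint was saved from a DataParallel model, so wrap the network
# first to match the 'module.'-prefixed keys in its state_dict.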
net.load_state_dict(torch.load(args.ckpt))

# Start evaluation
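# Running-loss trackers: one overall, plus separate 'pano' and '2d3d' subsets.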
test_losses = StatisticDict()
test_pano_losses = StatisticDict()
test_2d3d_losses = StatisticDict()
for ith, datas in enumerate(dataset):
    print('processed %d batches out of %d' % (ith, len(dataset)),
          end='\r', flush=True)
Example #2
parser.add_argument('--input_channels',
                    type=int,
                    help='number of input channels')
parser.add_argument('--num_workers',
                    default=6,
                    type=int,
                    help='number of workers for the dataloaders')
parser.add_argument('--batch_size',
                    default=4,
                    type=int,
                    help='mini-batch size')
args = parser.parse_args()
device = torch.device(args.device)

# Create dataloader
dataset = PanoDataset(root_dir=args.root_dir,
                      cat_list=[*args.input_cat, 'edge', 'cor'],
                      flip=False,
                      rotate=False)
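# Pinned host memory speeds up host-to-GPU copies, so enable it only when
# running on a CUDA device.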
loader = DataLoader(dataset,
                    args.batch_size,
                    shuffle=False,
                    drop_last=False,
                    num_workers=args.num_workers,
                    pin_memory=args.device.startswith('cuda'))

# Prepare model
encoder = Encoder(args.input_channels).to(device)
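# Two decoders consume the shared encoder features: a 3-channel edge-map
# decoder and a 1-channel corner-map decoder.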
edg_decoder = Decoder(skip_num=2, out_planes=3).to(device)
cor_decoder = Decoder(skip_num=3, out_planes=1).to(device)
encoder.load_state_dict(torch.load('%s_encoder.pth' % args.path_prefix))
edg_decoder.load_state_dict(torch.load('%s_edg_decoder.pth' %
                                       args.path_prefix))
cor_decoder.load_state_dict(torch.load('%s_cor_decoder.pth' %
                                       args.path_prefix))
Example #3
                    default=20,
                    help='display frequency in iterations')
parser.add_argument('--save_every',
                    type=int,
                    default=5,
                    help='frequency in epochs to save the state_dict')
args = parser.parse_args()
device = torch.device('cpu' if args.no_cuda else 'cuda')
np.random.seed(args.seed)
torch.manual_seed(args.seed)
os.makedirs(os.path.join(args.ckpt, args.id), exist_ok=True)

# Create dataloader
dataset_train = PanoDataset(root_dir=args.root_dir_train,
                            cat_list=[*args.input_cat, 'edge', 'cor'],
                            flip=not args.no_flip,
                            rotate=not args.no_rotate,
                            gamma=args.gamma,
                            noise=args.noise)
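# Validation data is kept deterministic: no flip/rotate/gamma augmentation.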
dataset_valid = PanoDataset(root_dir=args.root_dir_valid,
                            cat_list=[*args.input_cat, 'edge', 'cor'],
                            flip=False,
                            rotate=False,
                            gamma=False)
loader_train = DataLoader(dataset_train,
                          args.batch_size_train,
                          shuffle=True,
                          drop_last=True,
                          num_workers=args.num_workers,
                          pin_memory=not args.no_cuda)
loader_valid = DataLoader(dataset_valid,
                          args.batch_size_valid,
                          shuffle=False,
                          drop_last=False,
                          num_workers=args.num_workers,
                          pin_memory=not args.no_cuda)
Example #4
parser.add_argument('--save_every',
                    help='frequency in epochs to save the state_dict')
args = parser.parse_args()

device = torch.device(
    'cuda' if torch.cuda.is_available() and not args.no_cuda else 'cpu')
print('device: {}'.format(device))

np.random.seed(args.seed)
torch.manual_seed(args.seed)
os.makedirs(os.path.join(args.ckpt, args.id), exist_ok=True)
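# Each experiment id gets its own checkpoint directory.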

# Create dataloader
dataset_train = PanoDataset(root_dir=args.root_dir_train,
                            cat_list=[*args.input_cat, 'mfc'],
                            flip=not args.no_flip,
                            rotate=not args.no_rotate,
                            gamma=not args.no_gamma,
                            noise=args.noise,
                            contrast=args.contrast)
dataset_valid = PanoDataset(root_dir=args.root_dir_valid,
                            cat_list=[*args.input_cat, 'mfc'],
                            flip=False,
                            rotate=False,
                            gamma=False,
                            noise=False,
                            contrast=False)
loader_train = DataLoader(dataset_train,
                          args.batch_size_train,
                          shuffle=True,
                          drop_last=True,
                          num_workers=args.num_workers,
                          pin_memory=not args.no_cuda)