Example #1
0
# Build model
# Instantiate the network; the constructor argument 21 is passed through
# unchanged (presumably the number of keypoint heatmaps — TODO confirm
# against CPM's definition).
net = CPM(21)
if cuda:
    net = net.cuda()
    # NOTE: nn.DataParallel exposes the wrapped model as `.module`, so its
    # state_dict keys carry a 'module.' prefix. The checkpoint loaded below
    # must therefore match the wrapping used here (wrapped iff cuda is set).
    net = nn.DataParallel(net, device_ids=device_ids)  # multi-Gpu

# Checkpoint path: ckpt/model_epoch<best_model>.pth
model_path = os.path.join('ckpt', 'model_epoch' + str(best_model) + '.pth')
# Without map_location, torch.load restores each tensor onto the device it was
# saved from, which raises on a machine without that device. When running on
# CPU, remap every storage to CPU; on GPU keep the default behaviour.
state_dict = torch.load(
    model_path,
    map_location=None if cuda else (lambda storage, loc: storage)
)
net.load_state_dict(state_dict)


# **************************************** test all images ****************************************

# A single parenthesised string prints identically under Python 2 and 3;
# the bare `print '...'` statement form is a SyntaxError on Python 3.
print('********* test data *********')
# Put the network in evaluation mode before running inference.
net.eval()

# Per-image PCK accumulator, presumably keyed by image id and filled further
# down the script (not visible here). Example shape from the original author:
# {'0005': [[], [], []], '0011': [[], [], []]}
all_pcks = {}

# Iterate over the test set. test_dataset presumably yields batches of
# (image, label_map, center_map, imgs); imgs is presumably the image
# names/ids used as keys in all_pcks — TODO confirm against the loader.
# NOTE(review): inputs are wrapped in Variable without volatile=True or
# torch.no_grad(), so autograd presumably records this evaluation forward
# pass — consider disabling gradient tracking to save memory.
for step, (image, label_map, center_map, imgs) in enumerate(test_dataset):
    # Move the image to GPU when cuda is set, then wrap it for the autograd API.
    image = Variable(image.cuda() if cuda else image)   # 4D Tensor
    # Batch_size  *  3  *  width(368)  *  height(368)
    # Stack 6 copies of the ground-truth maps along a new dim 1 (one per
    # stage, matching stages(6) in pred_6 below): 4D -> 5D.
    label_map = torch.stack([label_map] * 6, dim=1)     # 4D Tensor to 5D Tensor
    # Batch_size  *   6 *   41  *  45  *  45
    # label_map is not used in the lines visible here; presumably consumed
    # by the PCK computation further down.
    label_map = Variable(label_map.cuda() if cuda else label_map)  # 5D Tensor

    center_map = Variable(center_map.cuda() if cuda else center_map)  # 4D Tensor
    # Batch_size  *  width(368) * height(368)
    # NOTE(review): the shape above lists 3 dims for a 4D tensor; presumably
    # a size-1 channel dim is omitted — TODO confirm.

    # Forward pass: the network takes both the image and the center map.
    pred_6 = net(image, center_map)  # 5D tensor:  batch size * stages(6) * 41 * 45 * 45
    # calculate pck