# Build model
# NOTE(review): this section was recovered from a single collapsed line, so the
# nesting below (DataParallel inside the `if cuda:` branch, and every statement
# after the `for` header inside the loop) is reconstructed -- confirm against
# the original script.
net = CPM(21)
if cuda:
    net = net.cuda()
    net = nn.DataParallel(net, device_ids=device_ids)  # multi-Gpu

# Load the weights stored in the checkpoint file selected by `best_model`.
model_path = os.path.join('ckpt/model_epoch' + str(best_model) + '.pth')
state_dict = torch.load(model_path)
net.load_state_dict(state_dict)

# **************************************** test all images ****************************************

# The parenthesised form prints the same text under the Python 2 print
# statement and the Python 3 print function.
print('********* test data *********')

net.eval()  # switch the network to evaluation mode before running the test set

# Accumulator declared here; it is not filled in the visible part of the loop.
all_pcks = {}  # {0005:[[], [],[]], 0011:[[], [],[]] }
for step, (image, label_map, center_map, imgs) in enumerate(test_dataset):
    # Each input is moved to the GPU only when `cuda` is set, then wrapped in
    # a Variable. Shape remarks below are the original author's notes and are
    # not verified by anything in this section.
    if cuda:
        image = image.cuda()
    image = Variable(image)  # 4D Tensor: Batch_size * 3 * width(368) * height(368)

    # Repeat the label map 6 times along a new dimension 1.
    label_map = torch.stack([label_map] * 6, dim=1)  # 4D Tensor to 5D Tensor: Batch_size * 6 * 41 * 45 * 45
    if cuda:
        label_map = label_map.cuda()
    label_map = Variable(label_map)  # 5D Tensor

    if cuda:
        center_map = center_map.cuda()
    center_map = Variable(center_map)  # 4D Tensor: Batch_size * width(368) * height(368)

    # Forward pass; the network takes the image together with the center map.
    pred_6 = net(image, center_map)  # 5D tensor: batch size * stages(6) * 41 * 45 * 45

    # calculate pck