class InferenNet(nn.Module):
    """Pose network with flip-test averaging.

    Wraps a ``FastPose('resnet101')`` network whose weights are loaded from
    ``weights_file``. ``forward`` averages the network's output on the input
    with its output on ``flip(x)``, after mapping the latter back with
    ``shuffleLR`` and ``flip``.
    """

    def __init__(self, dataset,
                 weights_file='./Models/sppe/fast_res101_320x256.pth'):
        """Build the network and load its weights.

        Args:
            dataset: passed through to ``shuffleLR`` in ``forward``
                (presumably selects the dataset's left/right joint
                pairing — confirm against ``shuffleLR``).
            weights_file: path to a state dict for ``FastPose('resnet101')``.
        """
        super().__init__()
        self.pyranet = FastPose('resnet101')
        print('Loading pose model from {}'.format(weights_file), flush=True)
        # Deserialize onto the CPU so the checkpoint loads no matter which
        # device it was saved from; load_state_dict then copies the tensors
        # into the module's own parameters.
        self.pyranet.load_state_dict(
            torch.load(weights_file, map_location=torch.device('cpu')))
        self.pyranet.eval()
        self.dataset = dataset

    def forward(self, x):
        """Return the mean of the plain and flip-corrected predictions.

        Only the first 17 channels (dim 1) of each network output are kept.
        """
        out = self.pyranet(x).narrow(1, 0, 17)

        flip_out = self.pyranet(flip(x)).narrow(1, 0, 17)
        # Map the mirrored prediction back into the original frame:
        # shuffleLR presumably swaps left/right joint channels for
        # ``self.dataset``, and flip undoes the input flip — TODO confirm.
        flip_out = flip(shuffleLR(flip_out, self.dataset))

        return (flip_out + out) / 2
class InferenNet_fastRes50(nn.Module):
    """Thin wrapper around a CUDA ``FastPose('resnet50', 17)`` network.

    Loads the weights from ``weights_file``, switches to eval mode and
    returns the network output unchanged from ``forward``.
    """

    def __init__(self, weights_file='./Models/sppe/fast_res50_256x192.pth'):
        """Build the network on the GPU and load its weights.

        Args:
            weights_file: path to a state dict for ``FastPose('resnet50', 17)``.
        """
        super().__init__()
        self.pyranet = FastPose('resnet50', 17).cuda()
        print('Loading pose model from {}'.format(weights_file))
        # Deserialize onto the CPU first (matching InferenNet): this avoids
        # failures when the checkpoint was saved from a CUDA device index that
        # does not exist here. load_state_dict copies into the CUDA params.
        self.pyranet.load_state_dict(
            torch.load(weights_file, map_location=torch.device('cpu')))
        self.pyranet.eval()

    def forward(self, x):
        """Return the wrapped network's output for ``x`` as-is."""
        return self.pyranet(x)
class InferenNet_fast(nn.Module):
    """CUDA ``FastPose('resnet101')`` network without flip averaging.

    ``forward`` keeps only the first 17 channels (dim 1) of the output.
    """

    def __init__(self, weights_file='./Models/sppe/fast_res101_320x256.pth'):
        """Build the network on the GPU and load its weights.

        Args:
            weights_file: path to a state dict for ``FastPose('resnet101')``.
        """
        super().__init__()
        self.pyranet = FastPose('resnet101').cuda()
        print('Loading pose model from {}'.format(weights_file))
        # Deserialize onto the CPU first (matching InferenNet) so checkpoints
        # saved from any CUDA device index load; load_state_dict copies the
        # tensors into the module's CUDA parameters.
        self.pyranet.load_state_dict(
            torch.load(weights_file, map_location=torch.device('cpu')))
        self.pyranet.eval()

    def forward(self, x):
        """Return the network output truncated to its first 17 channels."""
        return self.pyranet(x).narrow(1, 0, 17)