def create_model():
    """Build the temporal 3D pose model described by the command-line arguments.

    Reads the module-level ``args``, ``poses_valid_2d``, ``dataset`` and
    ``device`` objects. Constructs a ``TemporalModel``, prints its receptive
    field and parameter count, and moves it to ``device``.

    Returns:
        tuple: ``(model_eval, causal_shift, pad)``.
        ``pad`` is ``(receptive_field - 1) // 2``, the padding on each side.
        ``causal_shift`` equals ``pad`` when ``args.causal`` is set, otherwise 0.
    """
    widths = [int(w) for w in args.architecture.split(',')]
    first_clip = poses_valid_2d[0]
    # The input joint count and the per-joint feature size are taken from the
    # last two axes of the first validation sequence.
    # NOTE(review): assumes a (..., joints, features) layout; confirm.
    model_eval = TemporalModel(
        first_clip.shape[-2],
        first_clip.shape[-1],
        dataset.skeleton().num_joints(),
        filter_widths=widths,
        causal=args.causal,
        dropout=args.dropout,
        channels=args.channels,
        dense=args.dense,
    )

    field = model_eval.receptive_field()
    print('INFO: Receptive field: {} frames'.format(field))
    pad = (field - 1) // 2  # Padding on each side

    if not args.causal:
        causal_shift = 0
    else:
        print('INFO: Using causal convolutions')
        causal_shift = pad

    # Counts every parameter; there is no requires_grad filter.
    n_params = sum(p.numel() for p in model_eval.parameters())
    print('INFO: Trainable parameter count:', n_params)

    model_eval.to(device)
    return model_eval, causal_shift, pad
def get_pose3d_predictor(ckpt_dir, ckpt_name, filter_widths, causal=False, channels=1024,
                         num_joints_in=17, in_features=2, num_joints_out=17):
    """Load a 3D joint-coordinate predictor from a checkpoint file.

    Args:
        ckpt_dir: Directory containing the checkpoint.
        ckpt_name: File name of the checkpoint inside ``ckpt_dir``.
        filter_widths: Temporal convolution filter widths passed to
            ``TemporalModel``. They must match the architecture stored in the
            checkpoint, or ``load_state_dict`` will fail.
        causal: Build the model with causal convolutions.
        channels: Number of convolution channels.
        num_joints_in: Number of input 2D joints (first ``TemporalModel``
            positional argument). Default 17.
        in_features: Number of features per input joint (second positional
            argument). Default 2.
        num_joints_out: Number of predicted 3D joints (third positional
            argument). Default 17.

    Returns:
        pose3d_predictor: ``TemporalModel`` loaded with the checkpoint's
        ``'model_pos'`` weights, moved to ``device`` and switched to eval mode.

    Raises:
        KeyError: If the checkpoint has no ``'model_pos'`` entry.
    """
    ckpt_path = os.path.join(ckpt_dir, ckpt_name)
    print('Loading checkpoint', ckpt_path)
    # The identity map_location keeps every tensor on the CPU while loading;
    # the whole model is moved to `device` at the end.
    # NOTE: torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location=lambda storage, loc: storage)

    # The epoch is informational only; weights-only checkpoints may omit it.
    epoch = checkpoint.get('epoch')
    if epoch is not None:
        print('This model was trained for {} epochs'.format(epoch))
    else:
        print('INFO: checkpoint does not record the training epoch')

    pose3d_predictor = TemporalModel(num_joints_in, in_features, num_joints_out,
                                     filter_widths=filter_widths, causal=causal,
                                     channels=channels)
    receptive_field = pose3d_predictor.receptive_field()
    print('INFO: Receptive field: {} frames'.format(receptive_field))

    pose3d_predictor.load_state_dict(checkpoint['model_pos'])
    return pose3d_predictor.to(device).eval()