Ejemplo n.º 1
0
def get_root_net(cfg, is_train):
    """Assemble a ``ResRootNet`` from a ResNet backbone and a ``RootNet`` head.

    Args:
        cfg: Configuration object; only ``cfg.resnet_type`` is read here.
        is_train: When truthy, ``init_weights()`` is invoked on the backbone
            first and then on the root head, before they are wrapped.

    Returns:
        A ``ResRootNet`` wrapping the two sub-networks.
    """
    resnet = ResNetBackbone(cfg.resnet_type)
    head = RootNet()

    # Weight initialization is only requested for training runs.
    if is_train:
        for part in (resnet, head):
            part.init_weights()

    return ResRootNet(resnet, head)
Ejemplo n.º 2
0
def get_pose_net(cfg, is_train, joint_num):
    """Build a ``ResPoseNet`` made of a ResNet backbone and a ``HeadNet``.

    Args:
        cfg: Configuration object; ``cfg.resnet_type`` selects the backbone.
        is_train: If truthy, ``init_weights()`` is called on the backbone and
            then on the head.
        joint_num: Number of joints, forwarded to both ``HeadNet`` and
            ``ResPoseNet``.

    Returns:
        The assembled ``ResPoseNet``.
    """
    backbone_net = ResNetBackbone(cfg.resnet_type)
    head = HeadNet(joint_num)
    if is_train:
        backbone_net.init_weights()
        head.init_weights()
    return ResPoseNet(backbone_net, head, joint_num)
Ejemplo n.º 3
0
class BackboneNet(nn.Module):
    """Thin ``nn.Module`` wrapper that returns a ResNet backbone's output.

    The backbone variant is taken from the module-level ``cfg.resnet_type``.
    """

    def __init__(self):
        # Zero-argument super() binds to this class through the compiler's
        # __class__ cell. The explicit ``super(BackboneNet, self)`` form looks
        # the global name up at call time; because ``BackboneNet`` is rebound
        # to a different class later in this module, that lookup would yield
        # the wrong class and raise TypeError on construction.
        super().__init__()
        self.resnet = ResNetBackbone(cfg.resnet_type)

    def init_weights(self):
        """Delegate weight initialization to the wrapped ResNet."""
        self.resnet.init_weights()

    def forward(self, img):
        """Return the backbone output for ``img`` without further processing."""
        return self.resnet(img)
Ejemplo n.º 4
0
class BackboneNet(nn.Module):
    """ResNet backbone followed by global average pooling and a linear layer.

    The backbone variant and output width come from the module-level ``cfg``
    (``cfg.resnet_type`` and ``cfg.backbone_img_feat_dim``).
    """

    def __init__(self):
        super().__init__()
        self.resnet = ResNetBackbone(cfg.resnet_type)
        # Input width 2048 assumes the backbone's final feature map has 2048
        # channels — TODO confirm for every supported ``resnet_type``.
        self.fc = make_linear_layers([2048, cfg.backbone_img_feat_dim])

    def init_weights(self):
        """Initialize the ResNet backbone; ``self.fc`` is not touched here."""
        self.resnet.init_weights()

    def forward(self, img):
        """Encode ``img`` into one feature vector per sample.

        Args:
            img: Input batch, passed straight to the ResNet backbone.

        Returns:
            ``self.fc`` applied to the backbone feature map after averaging
            over its full spatial extent and flattening to
            ``(batch, channels)``.
        """
        img_feat = self.resnet(img)
        # Kernel covers the whole H x W map: global average pooling.
        img_feat = F.avg_pool2d(img_feat, (img_feat.shape[2], img_feat.shape[3]))
        # Keep the batch dimension explicit. ``view(-1, 2048)`` silently merged
        # samples whenever the channel count was not 2048 (e.g. 4 samples x 512
        # channels became a single 2048-wide row) instead of failing loudly.
        img_feat = img_feat.view(img_feat.shape[0], -1)
        return self.fc(img_feat)
Ejemplo n.º 5
0
def get_pose_net(cfg, is_train, joint_num):
    """Build a ``ResPoseNet`` whose ResNet backbone takes a conv builder.

    Args:
        cfg: Configuration object. ``cfg.acb`` chooses the builder
            (``ACNetBuilder`` when truthy, ``ConvBuilder`` otherwise), and
            ``cfg.resnet_type`` selects the backbone variant.
        is_train: If truthy, ``init_weights()`` is called on the backbone and
            then on the head.
        joint_num: Number of joints, forwarded to ``HeadNet`` and
            ``ResPoseNet``.

    Returns:
        The assembled ``ResPoseNet``.
    """
    conv_builder = (
        ACNetBuilder(base_config=None, deploy=False, gamma_init=1)
        if cfg.acb
        else ConvBuilder(base_config=None)
    )

    resnet = ResNetBackbone(conv_builder, cfg.resnet_type)
    pose_head = HeadNet(joint_num)

    if is_train:
        for sub_net in (resnet, pose_head):
            sub_net.init_weights()

    return ResPoseNet(resnet, pose_head, joint_num)
Ejemplo n.º 6
0
def get_model(vertex_num, joint_num, mode):
    """Assemble the full ``Model`` from its six sub-networks.

    Two separate ``ResNetBackbone`` instances are created, both from the
    module-level ``cfg.resnet_type``.

    Args:
        vertex_num: Forwarded to ``MeshNet``.
        joint_num: Forwarded to ``PoseNet``, ``Pose2Feat`` and
            ``ParamRegressor``.
        mode: When equal to ``'train'``, each backbone's own
            ``init_weights()`` is called, and the module-level
            ``init_weights`` function is passed to ``.apply`` on every other
            sub-network.

    Returns:
        ``Model`` built from (pose backbone, PoseNet, Pose2Feat,
        mesh backbone, MeshNet, ParamRegressor), in that order.
    """
    pose_resnet = ResNetBackbone(cfg.resnet_type)
    pose_head = PoseNet(joint_num)
    pose_to_feat = Pose2Feat(joint_num)
    mesh_resnet = ResNetBackbone(cfg.resnet_type)
    mesh_head = MeshNet(vertex_num)
    regressor = ParamRegressor(joint_num)

    if mode == 'train':
        # Each backbone is initialized first, then the modules that follow it;
        # the call order matches the construction order above.
        schedule = (
            (pose_resnet, (pose_head, pose_to_feat)),
            (mesh_resnet, (mesh_head, regressor)),
        )
        for resnet, followers in schedule:
            resnet.init_weights()
            for module in followers:
                module.apply(init_weights)

    return Model(pose_resnet, pose_head, pose_to_feat,
                 mesh_resnet, mesh_head, regressor)