# Exemplo n.º 1
    def __init__(
            self,
            feature_scale=4,  # to reduce dimensionality
            in_resolution=256,
            output_channels=3,
            is_deconv=True,
            upper_billinear=False,
            lower_billinear=False,
            in_channels=3,
            is_batchnorm=True,
            skip_background=True,
            num_joints=17,
            nb_dims=3,  # encoding transformation
            encoderType='UNet',
            num_encoding_layers=5,
            dimension_bg=256,
            dimension_fg=256,
            dimension_3d=3 * 64,  # needs to be divisible by 3
            latent_dropout=0.3,
            shuffle_fg=True,
            shuffle_3d=True,
            from_latent_hidden_layers=0,
            n_hidden_to3Dpose=2,
            subbatch_size=4,
            implicit_rotation=False,
            nb_stage=1,  # number of U-net stacks
            output_types=None,
            num_cameras=4,
            num_digit_caps=10,
            num_caps_out_channel=160,
            caps_masked=False):
        """Build the encoder/decoder U-net with a latent 3D-pose bottleneck.

        Constructs (via ``setattr``, so the names encode layer index and
        stage) the convolutional encoder, optional background-skip branch,
        the latent-space transformation heads (``to_3d``, ``to_fg``,
        ``from_latent``, ``to_pose``) and the up-convolutional decoder.

        Args:
            feature_scale: divisor applied to the base filter counts to
                reduce dimensionality.
            in_resolution: input image side length; must be divisible by
                2**(num_encoding_layers - 1) to give an integer bottleneck
                resolution.
            output_channels: channels of the final 1x1 conv output.
            is_deconv: use transposed convolutions in the decoder (combined
                with ``upper_billinear``/``lower_billinear`` switches).
            in_channels: input image channels.
            is_batchnorm: enable batch norm inside ``unetConv2`` blocks.
            skip_background: add a separate first-layer conv for the
                background image and a skip connection in the last up-conv.
            num_joints, nb_dims: output pose size of the fc head
                (num_joints * nb_dims values).
            encoderType: 'UNet' (default conv stack) or 'ResNet'
                (pretrained ResNet-50 VNECT variant).
            num_encoding_layers: depth of the encoder/decoder pyramid.
            dimension_bg, dimension_fg, dimension_3d: latent sizes;
                ``dimension_3d`` must be divisible by 3.
            latent_dropout: dropout probability used on all latent heads.
            output_types: list of requested output keys; defaults to
                ['3D', 'img_crop', 'shuffled_pose', 'shuffled_appearance'].
                (Default handled with a None sentinel to avoid the shared
                mutable-default pitfall.)
            implicit_rotation: learn the camera rotation as an MLP instead
                of applying it explicitly to the latent vector.
            num_digit_caps, num_caps_out_channel, caps_masked: forwarded to
                the ResNet capsule encoder (only used when
                encoderType == 'ResNet').
        """
        super(unet, self).__init__()
        # NOTE: avoid mutable default argument — bind the default here.
        if output_types is None:
            output_types = [
                '3D', 'img_crop', 'shuffled_pose', 'shuffled_appearance'
            ]
        self.in_resolution = in_resolution
        self.is_deconv = is_deconv
        self.in_channels = in_channels
        self.is_batchnorm = is_batchnorm
        self.feature_scale = feature_scale
        self.nb_stage = nb_stage
        self.dimension_bg = dimension_bg
        self.dimension_fg = dimension_fg
        self.dimension_3d = dimension_3d
        self.shuffle_fg = shuffle_fg
        self.shuffle_3d = shuffle_3d
        self.num_encoding_layers = num_encoding_layers
        self.output_types = output_types
        self.encoderType = encoderType
        assert dimension_3d % 3 == 0
        self.implicit_rotation = implicit_rotation
        self.num_cameras = num_cameras

        self.skip_connections = False
        self.skip_background = skip_background
        self.subbatch_size = subbatch_size
        self.latent_dropout = latent_dropout

        # filters = [64, 128, 256, 512, 1024]
        self.filters = [64, 128, 256, 512, 512, 512]  # HACK
        # self.filters = [64, 128, 728, 512, 512, 512] # HACK
        self.filters = [int(x / self.feature_scale) for x in self.filters]
        # Spatial resolution at the bottleneck after the pooling pyramid.
        self.bottleneck_resolution = in_resolution // (2**(
            num_encoding_layers - 1))
        num_output_features = self.bottleneck_resolution**2 * self.filters[
            num_encoding_layers - 1]
        print('bottleneck_resolution', self.bottleneck_resolution,
              'num_output_features', num_output_features)

        ####################################
        ############ encoder ###############
        # self.num_digit_caps = num_digit_caps  # 20
        # self.num_caps_out_channel = num_caps_out_channel
        # self.masked = masked
        if self.encoderType == "ResNet":
            self.encoder = resnet_VNECT_3Donly.resnet50(
                pretrained=True,
                input_key='img_crop',
                output_keys=['latent_3d', '2D_heat'],
                input_width=in_resolution,
                num_classes=self.dimension_fg + self.dimension_3d,
                num_digit_caps=num_digit_caps,
                num_caps_out_channel=num_caps_out_channel,
                caps_masked=caps_masked)

        ns = 0
        setattr(
            self, 'conv_1_stage' + str(ns),
            unetConv2(self.in_channels,
                      self.filters[0],
                      self.is_batchnorm,
                      padding=1))
        setattr(self, 'pool_1_stage' + str(ns), nn.MaxPool2d(kernel_size=2))
        for li in range(
                2, num_encoding_layers
        ):  # note, first layer(li==1) is already created, last layer(li==num_encoding_layers) is created externally
            setattr(
                self, 'conv_' + str(li) + '_stage' + str(ns),
                unetConv2(self.filters[li - 2],
                          self.filters[li - 1],
                          self.is_batchnorm,
                          padding=1))
            setattr(self, 'pool_' + str(li) + '_stage' + str(ns),
                    nn.MaxPool2d(kernel_size=2))

        # Last encoding layer: with hidden layers after the latent space the
        # conv is followed by an extra pooling step.
        if from_latent_hidden_layers:
            setattr(
                self, 'conv_' + str(num_encoding_layers) + '_stage' + str(ns),
                nn.Sequential(
                    unetConv2(self.filters[num_encoding_layers - 2],
                              self.filters[num_encoding_layers - 1],
                              self.is_batchnorm,
                              padding=1), nn.MaxPool2d(kernel_size=2)))
        else:
            setattr(
                self, 'conv_' + str(num_encoding_layers) + '_stage' + str(ns),
                unetConv2(self.filters[num_encoding_layers - 2],
                          self.filters[num_encoding_layers - 1],
                          self.is_batchnorm,
                          padding=1))

        ####################################
        ############ background ###############
        if skip_background:
            setattr(
                self, 'conv_1_stage_bg' + str(ns),
                unetConv2(self.in_channels,
                          self.filters[0],
                          self.is_batchnorm,
                          padding=1))

        ###########################################################
        ############ latent transformation and pose ###############
        # The foreground/appearance part must leave room for the 3D part in
        # the bottleneck feature maps.
        assert self.dimension_fg < self.filters[num_encoding_layers - 1]
        num_output_features_3d = self.bottleneck_resolution**2 * (
            self.filters[num_encoding_layers - 1] - self.dimension_fg)
        # setattr(self, 'fc_1_stage' + str(ns), Linear(num_output_features, 1024))
        setattr(self, 'fc_1_stage' + str(ns), Linear(self.dimension_3d, 128))
        setattr(self, 'fc_2_stage' + str(ns), Linear(128,
                                                     num_joints * nb_dims))

        # self.to_pose = MLP.MLP_fromLatent(d_in=self.dimension_3d, d_hidden=2048, d_out=51, n_hidden=n_hidden_to3Dpose, dropout=0.5)
        # FIXME: fix hardcode
        self.to_pose = MLP.MLP_fromLatent(d_in=256,
                                          d_hidden=2048,
                                          d_out=51,
                                          n_hidden=n_hidden_to3Dpose,
                                          dropout=0.5)

        self.to_3d = nn.Sequential(
            Linear(num_output_features, self.dimension_3d),
            Dropout(inplace=True,
                    p=self.latent_dropout)  # removing dropout degrades results
        )

        if self.implicit_rotation:
            print("WARNING: doing implicit rotation!")
            rotation_encoding_dimension = 128
            # Encode the 3x3 camera rotation matrix into a latent code ...
            self.encode_angle = nn.Sequential(
                Linear(3 * 3, rotation_encoding_dimension // 2),
                Dropout(inplace=True, p=self.latent_dropout),
                ReLU(inplace=False),
                Linear(rotation_encoding_dimension // 2,
                       rotation_encoding_dimension),
                Dropout(inplace=True, p=self.latent_dropout),
                ReLU(inplace=False),
                Linear(rotation_encoding_dimension,
                       rotation_encoding_dimension),
            )

            # ... then apply it to the 3D latent with a learned mapping.
            self.rotate_implicitely = nn.Sequential(
                Linear(self.dimension_3d + rotation_encoding_dimension,
                       self.dimension_3d),
                Dropout(inplace=True, p=self.latent_dropout),
                ReLU(inplace=False))

        if from_latent_hidden_layers:
            hidden_layer_dimension = 1024
            if self.dimension_fg > 0:
                self.to_fg = nn.Sequential(
                    Linear(num_output_features, 256),  # HACK pooling
                    Dropout(inplace=True, p=self.latent_dropout),
                    ReLU(inplace=False),
                    Linear(256, self.dimension_fg),
                    Dropout(inplace=True, p=self.latent_dropout),
                    ReLU(inplace=False))
            self.from_latent = nn.Sequential(
                Linear(self.dimension_3d, hidden_layer_dimension),
                Dropout(inplace=True, p=self.latent_dropout),
                ReLU(inplace=False),
                Linear(hidden_layer_dimension, num_output_features_3d),
                Dropout(inplace=True, p=self.latent_dropout),
                ReLU(inplace=False))
        else:
            if self.dimension_fg > 0:
                self.to_fg = nn.Sequential(
                    Linear(num_output_features, self.dimension_fg),
                    Dropout(inplace=True, p=self.latent_dropout),
                    ReLU(inplace=False))
            self.from_latent = nn.Sequential(
                Linear(self.dimension_3d + self.dimension_fg,
                       num_output_features_3d),
                Dropout(inplace=True, p=self.latent_dropout),
                ReLU(inplace=False))
            # self.from_latent =  nn.Sequential( Linear(self.dimension_3d, num_output_features_3d),
            #                  Dropout(inplace=True, p=self.latent_dropout),
            #                  ReLU(inplace=False))

        ####################################
        ############ decoder ###############
        upper_conv = self.is_deconv and not upper_billinear
        lower_conv = self.is_deconv and not lower_billinear
        if self.skip_connections:
            for li in range(1, num_encoding_layers - 1):
                setattr(
                    self, 'upconv_' + str(li) + '_stage' + str(ns),
                    unetUp(self.filters[num_encoding_layers - li],
                           self.filters[num_encoding_layers - li - 1],
                           upper_conv,
                           padding=1))
                # setattr(self, 'upconv_2_stage' + str(ns), unetUp(self.filters[2], self.filters[1], upper_conv, padding=1))
        else:
            for li in range(1, num_encoding_layers - 1):
                setattr(
                    self, 'upconv_' + str(li) + '_stage' + str(ns),
                    unetUpNoSKip(self.filters[num_encoding_layers - li],
                                 self.filters[num_encoding_layers - li - 1],
                                 upper_conv,
                                 padding=1))
            # setattr(self, 'upconv_2_stage' + str(ns), unetUpNoSKip(self.filters[2], self.filters[1], upper_conv, padding=1))

        # Final up-conv keeps the skip path when either full skip
        # connections or the background skip branch is enabled.
        if self.skip_connections or self.skip_background:
            setattr(
                self,
                'upconv_' + str(num_encoding_layers - 1) + '_stage' + str(ns),
                unetUp(self.filters[1], self.filters[0], lower_conv,
                       padding=1))
        else:
            setattr(
                self,
                'upconv_' + str(num_encoding_layers - 1) + '_stage' + str(ns),
                unetUpNoSKip(self.filters[1],
                             self.filters[0],
                             lower_conv,
                             padding=1))

        setattr(self, 'final_stage' + str(ns),
                nn.Conv2d(self.filters[0], output_channels, 1))

        self.relu = ReLU(inplace=True)
        self.relu2 = ReLU(inplace=False)
        self.dropout = Dropout(inplace=True, p=0.3)