# Example #1
0
class UCF101Cropper(object):
    """Crop a glimpse from a pair of (fc, conv) feature tensors and pool it.

    ``apply`` receives ``image`` as a 2-tuple ``(fc, conv)``:

    * ``fc`` is cropped by a 1-d ``Cropper`` along the first patch axis only.
      No learned layers are applied to it.
    * ``conv`` is cropped by a ``Cropper`` over all of ``patch_shape`` and then
      passed through a small CNN.

    Both results are globally average-pooled over all non-(batch, channel)
    axes and concatenated along the channel axis.
    """

    # Default widths of the two pooled representations. ``output_shape`` is
    # derived from these so it cannot drift from the conv layer specs.
    fc_dim = 4096
    conv_num_filters = 512

    def __init__(self, patch_shape, kernel, hyperparameters,
                 fc_dim=4096, conv_num_filters=512):
        """Build the two croppers and the conv-pathway CNN.

        Args:
            patch_shape: shape of the cropped conv patch. Its first entry
                alone is used as the patch shape for the fc features.
            kernel: forwarded unchanged to both ``Cropper`` instances.
            hyperparameters: forwarded to both ``Cropper`` instances.
                ``hyperparameters["batch_normalize_patch"]`` is also read
                here for the CNN.
            fc_dim: channel count of the fc features. It only affects
                ``output_shape``; defaults to the previously hard-coded 4096.
            conv_num_filters: ``num_filters`` of each conv-pathway layer, and
                hence that pathway's output width. Defaults to the previously
                hard-coded 512.
        """
        self.cropper1d = Cropper(patch_shape[:1], kernel, hyperparameters, name="cropper1d")
        self.cropper3d = Cropper(patch_shape, kernel, hyperparameters, name="cropper3d")
        self.patch_shape = patch_shape
        self.n_spatial_dims = len(patch_shape)
        self.fc_dim = fc_dim
        self.conv_num_filters = conv_num_filters

        # Two layers, each convolving (size 5) and pooling (2, stride 2)
        # along the first patch axis only. A fresh dict is built per layer so
        # construct_cnn never sees two aliases of the same spec.
        layer_specs = [
            dict(size=(5, 1, 1), num_filters=conv_num_filters,
                 pooling_size=(2, 1, 1), pooling_step=(2, 1, 1))
            for _ in range(2)
        ]
        # NOTE(review): this brick is named "fc_conv" although it processes
        # the conv features (looks like a copy-paste slip). It is left
        # unchanged because renaming it would alter parameter paths and could
        # break existing checkpoints or parameter selectors. Rename
        # deliberately if that is acceptable.
        self.conv_conv = masonry.construct_cnn(
            name="fc_conv",
            layer_specs=layer_specs,
            input_shape=patch_shape,
            # Input channel count of the conv features; the shape comment in
            # apply() shows 512. Assumed fixed -- confirm against the data.
            n_channels=512,
            batch_normalize=hyperparameters["batch_normalize_patch"])

    def initialize(self):
        """Initialize the conv-pathway CNN (the only brick with parameters here)."""
        self.conv_conv.initialize()

    def apply(self, image, image_shape, location, scale):
        """Crop both feature streams, pool them, and concatenate them.

        Args:
            image: pair ``(fc, conv)`` of feature tensors.
            image_shape: pair ``(fc_shape, conv_shape)``. Column 0 of each is
                dropped before it is passed to the croppers.
            location: per-example crop locations, indexed ``[:, 0]`` for the
                fc crop and used whole for the conv crop.
            scale: per-example crop scales, indexed like ``location``.

        Returns:
            ``(patch, 0.)``. ``patch`` concatenates the pooled fc and conv
            representations along axis 1. The constant second element is kept
            for interface compatibility; its meaning is defined by callers.
        """
        fc, conv = image
        fc_shape, conv_shape = image_shape
        # (batch, 4096, 16, 1): the 1-d crop gets a trailing broadcastable axis.
        fc_patch = T.shape_padright(self.cropper1d.apply(
            fc, fc_shape[:, 1:],
            location[:, 0, np.newaxis],
            scale[:, 0, np.newaxis],
        )[0])
        # (batch, 512, 16, 1, 1)
        conv_patch = self.cropper3d.apply(
            conv, conv_shape[:, 1:],
            location, scale,
        )[0]
        fc_repr = fc_patch
        conv_repr = self.conv_conv.apply(conv_patch)
        # Global average pooling over every axis after (batch, channel).
        fc_repr = fc_repr.mean(axis=tuple(range(2, fc_repr.ndim)))
        conv_repr = conv_repr.mean(axis=tuple(range(2, conv_repr.ndim)))
        patch = T.concatenate([fc_repr, conv_repr], axis=1)
        return patch, 0.

    @property
    def output_shape(self):
        """Shape of ``patch`` per example: fc width plus conv-pathway width."""
        return (self.fc_dim + self.conv_num_filters,)