Exemplo n.º 1
0
    def forward(ctx,
                in_feat,
                mode=GlobalPoolingMode.AUTO,
                in_coords_key=None,
                glob_coords_key=None,
                coords_manager=None):
        """Normalize ``in_feat`` per global-pooling group (instance norm).

        Computes ``(X - mean) / sqrt(var + 1e-8)``. The mean and the
        variance (1/N form, per the comments below) are obtained from the
        backend global-pooling kernel over the coordinates of
        ``in_coords_key`` pooled into ``glob_coords_key``. They are then
        broadcast back onto the input coordinates with the backend broadcast
        kernel.

        Args:
            ctx: autograd context. ``mode``, both coordinate keys and the
                coordinate manager are stored on it. ``inv_std`` and the
                normalized output are saved for ``backward``.
            in_feat: input features. The backend kernel variant is chosen
                from it via ``get_postfix`` (presumably a torch tensor whose
                device/dtype selects the kernel -- confirm).
            mode: a ``GlobalPoolingMode``. Its ``.value`` is forwarded to
                the pooling kernel.
            in_coords_key: coordinate key of ``in_feat``.
            glob_coords_key: coordinate key of the pooled result. When
                ``None``, a new ``CoordsKey(in_coords_key.D)`` is created.
            coords_manager: object whose ``CPPCoordsManager`` is passed to
                the backend kernels.

        Returns:
            The normalized features produced by the final broadcast call.

        Raises:
            AssertionError: if ``mode`` is not a ``GlobalPoolingMode``
                (only while assertions are enabled).
        """
        assert isinstance(mode, GlobalPoolingMode), \
            f"Mode must be an instance of GlobalPoolingMode, {mode}"
        if glob_coords_key is None:
            glob_coords_key = CoordsKey(in_coords_key.D)

        gpool_forward = getattr(MEB,
                                'GlobalPoolingForward' + get_postfix(in_feat))
        broadcast_forward = getattr(MEB,
                                    'BroadcastForward' + get_postfix(in_feat))
        add = operation_type_to_int(OperationType.ADDITION)
        multiply = operation_type_to_int(OperationType.MULTIPLICATION)

        cpp_in_coords_key = in_coords_key.CPPCoordsKey
        cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
        cpp_coords_manager = coords_manager.CPPCoordsManager

        # With this calling convention the kernels return their results, so
        # no output buffers are pre-allocated. The pooling kernel returns a
        # pair whose second element (a non-zero count) is not needed here.
        # The `True` flag presumably requests averaging, consistent with the
        # 1/N comments below -- confirm against the backend.
        mean, _ = gpool_forward(in_feat, cpp_in_coords_key,
                                cpp_glob_coords_key, cpp_coords_manager,
                                True, mode.value)
        # X - \mu
        centered_feat = broadcast_forward(in_feat, -mean, add,
                                          cpp_in_coords_key,
                                          cpp_glob_coords_key,
                                          cpp_coords_manager)

        # Variance = 1/N \sum (X - \mu) ** 2
        variance, _ = gpool_forward(centered_feat**2, cpp_in_coords_key,
                                    cpp_glob_coords_key, cpp_coords_manager,
                                    True, mode.value)

        # norm_feat = (X - \mu) / \sigma. The epsilon keeps the reciprocal
        # finite when a group has zero variance.
        inv_std = 1 / (variance + 1e-8).sqrt()
        norm_feat = broadcast_forward(centered_feat, inv_std, multiply,
                                      cpp_in_coords_key, cpp_glob_coords_key,
                                      cpp_coords_manager)

        ctx.mode = mode
        ctx.in_coords_key, ctx.glob_coords_key = in_coords_key, glob_coords_key
        ctx.coords_manager = coords_manager
        # For GPU tensors, must use save_for_backward.
        ctx.save_for_backward(inv_std, norm_feat)
        return norm_feat
Exemplo n.º 2
0
    def forward(ctx,
                in_feat,
                batch_size=0,
                in_coords_key=None,
                glob_coords_key=None,
                coords_manager=None):
        """Instance-normalize ``in_feat`` using out-parameter backend kernels.

        Produces ``(X - mean) / sqrt(var + 1e-8)``, where the group-wise
        statistics come from the backend global-pooling kernel. In this API
        generation every kernel writes its result into a caller-supplied
        tensor, which is created with ``in_feat.new()``.

        Args:
            ctx: autograd context. ``batch_size``, both coordinate keys and
                the coordinate manager are stored on it. ``inv_std`` and the
                normalized output are saved for ``backward``.
            in_feat: input features (``.new()`` is used to create output
                buffers of the same kind).
            batch_size: forwarded unchanged to the pooling kernel.
            in_coords_key: coordinate key of ``in_feat``.
            glob_coords_key: coordinate key of the pooled statistics. It
                defaults to a fresh ``CoordsKey(in_coords_key.D)``.
            coords_manager: object whose ``CPPCoordsManager`` is passed to
                the backend.

        Returns:
            The normalized feature tensor.
        """
        if glob_coords_key is None:
            glob_coords_key = CoordsKey(in_coords_key.D)

        pool_kernel = getattr(
            MEB, 'GlobalPoolingForward' + get_postfix(in_feat))
        bcast_kernel = getattr(
            MEB, 'BroadcastForward' + get_postfix(in_feat))
        op_add = operation_type_to_int(OperationType.ADDITION)
        op_mul = operation_type_to_int(OperationType.MULTIPLICATION)

        counts = in_feat.new()
        dim = in_coords_key.D
        src_key = in_coords_key.CPPCoordsKey
        dst_key = glob_coords_key.CPPCoordsKey
        manager = coords_manager.CPPCoordsManager

        def pooled(values):
            # The kernel fills `result` (and the shared `counts` buffer).
            result = in_feat.new()
            pool_kernel(dim, values, result, counts, src_key, dst_key,
                        manager, batch_size, True)
            return result

        def broadcast(lhs, rhs, op):
            # Combine per-point `lhs` with per-group `rhs` into a new buffer.
            result = in_feat.new()
            bcast_kernel(dim, lhs, rhs, result, op, src_key, dst_key, manager)
            return result

        group_mean = pooled(in_feat)
        centered = broadcast(in_feat, -group_mean, op_add)      # X - mu
        group_var = pooled(centered**2)                         # 1/N sum (X - mu)^2
        inv_std = 1 / (group_var + 1e-8).sqrt()
        normalized = broadcast(centered, inv_std, op_mul)       # (X - mu) / sigma

        ctx.batch_size = batch_size
        ctx.in_coords_key = in_coords_key
        ctx.glob_coords_key = glob_coords_key
        ctx.coords_manager = coords_manager
        # Tensors needed by backward go through save_for_backward (required
        # for GPU tensors).
        ctx.save_for_backward(inv_std, normalized)
        return normalized
Exemplo n.º 3
0
    def backward(ctx, out_grad):
        """Gradient of the batch_size-based instance normalization.

        With ``y`` the normalized forward output and group means taken by the
        global-pooling kernel, the input gradient is::

            dx = inv_std * (dy - mean(dy) - y * mean(dy * y))

        Args:
            ctx: autograd context populated by ``forward``. It supplies
                ``batch_size``, the coordinate keys, the coordinate manager
                and the saved ``(inv_std, norm_feat)`` tensors.
            out_grad: gradient w.r.t. the normalized output.

        Returns:
            A 5-tuple: the gradient w.r.t. ``in_feat`` followed by ``None``
            for the four remaining forward inputs (``batch_size``,
            ``in_coords_key``, ``glob_coords_key``, ``coords_manager``).
        """
        # https://kevinzakka.github.io/2016/09/14/batch_normalization/
        batch_size = ctx.batch_size
        in_coords_key, glob_coords_key = ctx.in_coords_key, ctx.glob_coords_key
        coords_manager = ctx.coords_manager

        # Retrieve the tensors saved by forward via save_for_backward.
        # `saved_tensors` is the supported accessor; `saved_variables` is
        # deprecated.
        inv_std, norm_feat = ctx.saved_tensors
        D = in_coords_key.D

        gpool_forward = getattr(MEB,
                                'GlobalPoolingForward' + get_postfix(out_grad))
        broadcast_forward = getattr(MEB,
                                    'BroadcastForward' + get_postfix(out_grad))
        add = operation_type_to_int(OperationType.ADDITION)
        multiply = operation_type_to_int(OperationType.MULTIPLICATION)

        cpp_in_coords_key = in_coords_key.CPPCoordsKey
        cpp_glob_coords_key = glob_coords_key.CPPCoordsKey
        cpp_coords_manager = coords_manager.CPPCoordsManager

        # 1/N \sum dout  (kernel writes into the pre-allocated outputs)
        num_nonzero = out_grad.new()
        mean_dout = out_grad.new()
        gpool_forward(D, out_grad, mean_dout, num_nonzero, cpp_in_coords_key,
                      cpp_glob_coords_key, cpp_coords_manager, batch_size,
                      True)

        # 1/N \sum (dout * out)
        mean_dout_feat = out_grad.new()
        gpool_forward(D, out_grad * norm_feat, mean_dout_feat, num_nonzero,
                      cpp_in_coords_key, cpp_glob_coords_key,
                      cpp_coords_manager, batch_size, True)

        # out * 1/N \sum (dout * out)
        feat_mean_dout_feat = out_grad.new()
        broadcast_forward(D, norm_feat, mean_dout_feat, feat_mean_dout_feat,
                          multiply, cpp_in_coords_key, cpp_glob_coords_key,
                          cpp_coords_manager)

        # dout - out * mean(dout * out) - mean(dout)
        unnorm_din = out_grad.new()
        broadcast_forward(D, out_grad - feat_mean_dout_feat, -mean_dout,
                          unnorm_din, add, cpp_in_coords_key,
                          cpp_glob_coords_key, cpp_coords_manager)

        # Scale by 1 / sigma.
        norm_din = out_grad.new()
        broadcast_forward(D, unnorm_din, inv_std, norm_din, multiply,
                          cpp_in_coords_key, cpp_glob_coords_key,
                          cpp_coords_manager)

        return norm_din, None, None, None, None
    def backward(ctx, out_grad):
        """Gradient of the mode-based instance normalization.

        Applies ``dx = inv_std * (dy - mean(dy) - y * mean(dy * y))``. Here
        ``y`` is the normalized output saved by ``forward``. The group-wise
        means come from the global-pooling kernel (called with
        ``ctx.mode.value``) and are broadcast back onto the input
        coordinates.

        Args:
            ctx: autograd context populated by ``forward``.
            out_grad: gradient w.r.t. the normalized output.

        Returns:
            A 5-tuple: the gradient w.r.t. ``in_feat`` followed by ``None``
            for ``mode``, ``in_coords_key``, ``glob_coords_key`` and
            ``coords_manager``.
        """
        # https://kevinzakka.github.io/2016/09/14/batch_normalization/
        in_key = ctx.in_coords_key
        out_key = ctx.glob_coords_key
        manager = ctx.coords_manager

        # Tensors stored by forward via save_for_backward.
        inv_std, norm_feat = ctx.saved_tensors

        pool = get_minkowski_function('GlobalPoolingForward', out_grad)
        spread_kernel = get_minkowski_function('BroadcastForward', out_grad)
        op_add = operation_type_to_int(OperationType.ADDITION)
        op_mul = operation_type_to_int(OperationType.MULTIPLICATION)

        src_key = in_key.CPPCoordsKey
        dst_key = out_key.CPPCoordsKey
        cpp_manager = manager.CPPCoordsManager

        def group_mean(values):
            # The kernel returns (pooled, count); only the pooled part is used.
            pooled, _ = pool(values, src_key, dst_key, cpp_manager, True,
                             ctx.mode.value)
            return pooled

        def spread(lhs, rhs, op):
            # Combine per-point `lhs` with per-group `rhs`.
            return spread_kernel(lhs, rhs, op, src_key, dst_key, cpp_manager)

        grad_mean = group_mean(out_grad)                      # mean(dy)
        grad_out_mean = group_mean(out_grad * norm_feat)      # mean(dy * y)
        projection = spread(norm_feat, grad_out_mean, op_mul)  # y * mean(dy * y)
        shifted = spread(out_grad - projection, -grad_mean, op_add)
        grad_in = spread(shifted, inv_std, op_mul)            # scale by 1/sigma
        return grad_in, None, None, None, None