コード例 #1
0
    def backward(ctx, grad_output):
        """Propagate ``grad_output`` back through an indice (sparse) convolution.

        Appears to be the ``backward`` of a ``torch.autograd.Function``
        (TODO confirm against the enclosing class). The forward pass is
        expected to have stored ``(indice_pairs, indice_pair_num, features,
        filters)`` in ``ctx.saved_tensors``, plus ``ctx.algo`` and ``ctx.timer``.
        The actual computation is delegated to ``ops.indice_conv_backward``.

        Returns:
            A 7-tuple whose first two entries are the two results of
            ``ops.indice_conv_backward`` (named ``input_bp`` / ``filters_bp``
            upstream, presumably the gradients for ``features`` and
            ``filters``). The remaining five are ``None``, presumably one per
            non-differentiable forward input; verify against ``forward``.
        """
        saved = ctx.saved_tensors
        indice_pairs, indice_pair_num, features, filters = saved
        timer = ctx.timer

        # NOTE(review): the positional ``False`` is presumably the kernel's
        # "inverse" switch (regular, non-inverse convolution); TODO confirm.
        grads = ops.indice_conv_backward(
            features,
            filters,
            grad_output,
            indice_pairs,
            indice_pair_num,
            False,
            algo=ctx.algo,
            timer=timer,
        )
        feature_grad, weight_grad = grads

        # Only the first two slots carry gradients; the other five are None.
        return (feature_grad, weight_grad) + (None,) * 5
コード例 #2
0
ファイル: functional.py プロジェクト: traveller59/spconv
    def backward(ctx, grad_output):
        """Backward pass of an indice (sparse) convolution, with failure diagnostics.

        Reads ``(indice_pairs, indice_pair_num, features, filters)`` from
        ``ctx.saved_tensors``. It passes them, together with ``grad_output``,
        ``ctx.algo`` and ``ctx.timer``, to ``ops.indice_conv_backward``.

        If that call raises, the following happens before the original
        exception is re-raised:

        * a one-line summary of the operand shapes is printed to stderr;
        * ``spconv_save_debug_data`` is asked to dump the index pairs. This
          dump is best-effort: if it fails, a notice is printed to stderr
          instead of masking the original error.

        Returns:
            A 7-tuple ``(input_bp, filters_bp, None, None, None, None, None)``.
            Presumably there is one entry per forward input, with only the first
            two differentiable; TODO confirm against ``forward``.

        Raises:
            Whatever ``ops.indice_conv_backward`` raises, re-raised unchanged.
        """
        indice_pairs, indice_pair_num, features, filters = ctx.saved_tensors
        timer = ctx.timer
        try:
            # NOTE(review): the positional ``False`` is presumably the kernel's
            # "inverse" switch; TODO confirm against ops.indice_conv_backward.
            input_bp, filters_bp = ops.indice_conv_backward(features,
                                                            filters,
                                                            grad_output,
                                                            indice_pairs,
                                                            indice_pair_num,
                                                            False,
                                                            algo=ctx.algo,
                                                            timer=timer)
        except Exception:
            msg = "[Exception|indice_conv_backward]"
            msg += f"feat={features.shape},w={filters.shape},pair={indice_pairs.shape},"
            msg += f"pairnum={indice_pair_num},do={grad_output.shape}"
            print(msg, file=sys.stderr)
            # A failure while dumping debug data must not replace the error
            # the caller actually needs to see, so contain it here.
            try:
                spconv_save_debug_data((indice_pairs, indice_pair_num))
            except Exception as dump_err:
                print(f"[Exception|indice_conv_backward] failed to save debug data: {dump_err!r}",
                      file=sys.stderr)
            # Bare ``raise`` re-raises the original error. Python 3 restores it
            # as the active exception once the inner handler above exits, and
            # its traceback is left as-is (no extra reraise frame).
            raise

        return input_bp, filters_bp, None, None, None, None, None