示例#1
0
    def test_pooling_3d(self, stride, pad, kernel, size, input_channels,
                        batch_size, order, op_type, engine, gc, dc):
        """Run a 3-D pooling operator with explicit strides, kernels and pads.

        Builds ``op_type`` with:
        - the same stride and kernel on all three spatial axes;
        - the same pad value repeated six times.

        It feeds the operator a random float32 input and compares outputs
        across the devices in ``dc``. Gradients are also checked on ``gc``
        unless ``op_type`` contains ``'MaxPool'``.

        NOTE(review): the arguments are presumably drawn by a Hypothesis
        ``@given`` decorator defined outside this view -- confirm there.
        """
        # Discard draws where the pad reaches the kernel size, or where the
        # padded extent is smaller than one kernel window.
        assume(pad < kernel)
        assume(size + 2 * pad >= kernel)

        # Currently MIOpen Pooling only supports pooling with NCHW order.
        # On HIP, the CUDNN engine is likewise skipped for HIP versions < 303.
        running_on_hip = hiputl.run_in_hip(gc, dc)
        if running_on_hip and (workspace.GetHIPVersion() < 303
                               or order == "NHWC"):
            assume(engine != "CUDNN")

        # Some draws coincide with global pooling but are still routed
        # through the general implementation. That path is slower, but it
        # should still give correct results.
        spatial_dims = 3
        pool_op = core.CreateOperator(
            op_type,
            ["X"],
            ["Y"],
            strides=[stride] * spatial_dims,
            kernels=[kernel] * spatial_dims,
            pads=[pad] * (2 * spatial_dims),  # two pad entries per axis
            order=order,
            engine=engine,
        )

        # Build the input in NHWC layout. Convert it with utils.NHWC2NCHW
        # only when the operator runs in NCHW order.
        nhwc_shape = (batch_size, size, size, size, input_channels)
        inputs = np.random.rand(*nhwc_shape).astype(np.float32)
        if order == "NCHW":
            inputs = utils.NHWC2NCHW(inputs)

        self.assertDeviceChecks(dc, pool_op, [inputs], [0], threshold=0.001)
        if "MaxPool" in op_type:
            # Gradient checks are skipped for max-pooling ops (presumably
            # because numeric gradients are unreliable at ties -- TODO
            # confirm).
            return
        self.assertGradientChecks(gc, pool_op, [inputs], 0, [0],
                                  threshold=0.001)
示例#2
0
    def test_global_pooling_3d(self, kernel, size, input_channels,
                               batch_size, order, op_type, engine, gc, dc):
        """Run a 3-D pooling operator with ``global_pooling=True``.

        Only ``kernels`` is passed explicitly. Pads and strides are left to
        the operator's global-pooling inference. The input is a random
        float32 tensor, and outputs are compared across the devices in
        ``dc``. Gradients are also checked on ``gc`` unless ``op_type``
        contains ``'MaxPool'``.

        NOTE(review): the arguments are presumably drawn by a Hypothesis
        ``@given`` decorator defined outside this view -- confirm there.
        """
        # Currently MIOpen Pooling only supports pooling with NCHW order.
        # On HIP, the CUDNN engine is likewise skipped for HIP versions < 303.
        running_on_hip = hiputl.run_in_hip(gc, dc)
        if running_on_hip and (workspace.GetHIPVersion() < 303
                               or order == "NHWC"):
            assume(engine != "CUDNN")

        # pad and stride are intentionally omitted: they are inferred when
        # global_pooling is set.
        spatial_dims = 3
        global_op = core.CreateOperator(
            op_type,
            ["X"],
            ["Y"],
            kernels=[kernel] * spatial_dims,
            order=order,
            global_pooling=True,
            engine=engine,
        )

        # Build the input in NHWC layout. Convert it with utils.NHWC2NCHW
        # only when the operator runs in NCHW order.
        nhwc_shape = ((batch_size,) + (size,) * spatial_dims
                      + (input_channels,))
        inputs = np.random.rand(*nhwc_shape).astype(np.float32)
        if order == "NCHW":
            inputs = utils.NHWC2NCHW(inputs)

        self.assertDeviceChecks(dc, global_op, [inputs], [0], threshold=0.001)
        if "MaxPool" in op_type:
            # Gradient checks are skipped for max-pooling ops (presumably
            # because numeric gradients are unreliable at ties -- TODO
            # confirm).
            return
        self.assertGradientChecks(gc, global_op, [inputs], 0, [0],
                                  threshold=0.001)