コード例 #1
0
ファイル: test_slice_v2.py プロジェクト: zzk0/oneflow
 def test_slice_update_grad(test_case):
     """Check slice_update output and gradients on a 2-D input.

     Columns 1:4 of a (2, 7) random input are overwritten with random values.
     The expected forward result and expected gradients for both the input and
     the update are built here with numpy and handed to
     ``_test_slice_update_grad`` once per (dtype, device_tag) combination.
     """
     in_arr = np.random.rand(2, 7)
     # Single description of the overwritten region; the (start, stop, step)
     # spec below is derived from it so the numpy reference and the slice
     # arguments cannot drift apart.
     region = (slice(None), slice(1, 4))
     update = np.random.rand(*in_arr[region].shape)
     # Expected gradient for ``update``: all ones (assumes the helper
     # back-propagates a ones upstream gradient -- confirm in the helper).
     update_diff = np.ones(update.shape)
     # Input elements inside the region are replaced by ``update`` and so get
     # zero gradient; every other element passes straight through.
     input_diff = np.ones(in_arr.shape)
     input_diff[region] = 0
     output = in_arr.copy()
     output[region] = update
     slice_tup_list = [(s.start, s.stop, s.step) for s in region]

     arg_dict = collections.OrderedDict(
         [
             ("dtype", [flow.float32, flow.float64]),
             ("device_tag", ["cpu", "gpu"]),
             ("verbose", [False]),
         ]
     )
     for kwarg in test_util.GenArgDict(arg_dict):
         _test_slice_update_grad(
             test_case,
             in_arr,
             update,
             slice_tup_list,
             output,
             input_diff,
             update_diff,
             **kwarg
         )
コード例 #2
0
    def test_slice_with_grad(test_case):
        """Check slice forward output and gradient on a 3-D input.

        Takes ``[:, 2:-2, :]`` of a (2, 5, 4) random input. The expected
        gradient is 1 for elements inside the slice and 0 elsewhere. Runs once
        per (dtype, device_tag) combination via ``_test_slice_with_grad``.
        """
        in_arr = np.random.rand(2, 5, 4)
        # Only the first two axes are specified; numpy takes the trailing axis
        # whole (presumably the helper treats the 2-entry spec the same way).
        region = (slice(None), slice(2, -2))
        slice_tup_list = [(s.start, s.stop, s.step) for s in region]
        output = in_arr[region]
        # Selected elements receive gradient 1, unselected ones 0.
        diff = np.zeros_like(in_arr)
        diff[region] = 1

        arg_dict = collections.OrderedDict(
            [
                ("dtype", [flow.float32, flow.float64]),
                ("device_tag", ["cpu", "gpu"]),
                ("verbose", [False]),
            ]
        )
        for kwarg in test_util.GenArgDict(arg_dict):
            _test_slice_with_grad(
                test_case, in_arr, slice_tup_list, output, diff, **kwarg
            )
コード例 #3
0
    def test_slice_update(test_case):
        """Check slice_update forward output with a strided 3-D slice.

        Overwrites ``[5:, :-1, ::2]`` of a (10, 5, 4) random input with random
        values and verifies the result via ``_test_slice_update`` for every
        (dtype, device_tag) combination.
        """
        in_arr = np.random.rand(10, 5, 4)
        # One description of the region drives both the numpy reference and
        # the (start, stop, step) spec passed to the helper.
        region = (slice(5, None), slice(None, -1), slice(None, None, 2))
        update = np.random.rand(*in_arr[region].shape)
        output = in_arr.copy()
        output[region] = update
        slice_tup_list = [(s.start, s.stop, s.step) for s in region]

        arg_dict = collections.OrderedDict(
            [
                ("dtype", [flow.float32, flow.float64]),
                ("device_tag", ["cpu", "gpu"]),
                ("verbose", [False]),
            ]
        )
        for kwarg in test_util.GenArgDict(arg_dict):
            _test_slice_update(
                test_case, in_arr, update, slice_tup_list, output, **kwarg
            )
コード例 #4
0
ファイル: test_slice_v2.py プロジェクト: zzk0/oneflow
 def test_slice_base(test_case):
     """Check a 1-D strided slice ``[1:7:2]`` across integer and float dtypes.

     The input holds integer values in [0, 100), stored as float64, instead of
     ``np.random.rand`` values in [0, 1). Those would all truncate to 0 once
     converted to an integer dtype. The uint8/int8/int32/int64 runs would then
     compare all-zero arrays and pass no matter which elements were sliced.
     Values below 100 are exact in every dtype listed, including int8.
     """
     in_arr = np.random.randint(0, 100, size=10).astype(np.float64)
     slice_args = [[(1, 7, 2)]]
     outputs = [in_arr[1:7:2]]
     arg_dict = collections.OrderedDict()
     arg_dict["dtype"] = [
         flow.uint8,
         flow.int8,
         flow.int32,
         flow.int64,
         flow.float32,
         flow.float64,
     ]
     arg_dict["device_tag"] = ["cpu", "gpu"]
     for kwarg in test_util.GenArgDict(arg_dict):
         _test_slice(test_case, in_arr, slice_args, outputs, **kwarg)
コード例 #5
0
ファイル: test_slice_v2.py プロジェクト: zzk0/oneflow
 def test_slice_dynamic_base(test_case):
     """Check slicing ``[:, 1:, :]`` of a dynamically shaped 3-D input.

     Runs over integer and float dtypes. The input holds integer values in
     [0, 100), stored as float64, instead of ``np.random.rand`` values in
     [0, 1), which would all truncate to 0 under an integer dtype and make
     the integer runs unable to detect a wrong slice. Values below 100 are
     exact in every dtype listed.
     """
     in_arr = np.random.randint(0, 100, size=(2, 4, 4)).astype(np.float64)
     slice_args = [[(None, None, None), (1, None, None)]]
     outputs = [in_arr[:, 1:, :]]
     arg_dict = collections.OrderedDict()
     arg_dict["dtype"] = [
         flow.uint8,
         flow.int8,
         flow.int32,
         flow.int64,
         flow.float32,
         flow.float64,
     ]
     arg_dict["device_tag"] = ["cpu", "gpu"]
     for kwarg in test_util.GenArgDict(arg_dict):
         # static_shape (2, 5, 5) exceeds the actual (2, 4, 4) input; presumably
         # it is the upper bound for the dynamic placeholder -- confirm in
         # _test_slice_dynamic.
         _test_slice_dynamic(
             test_case, in_arr, slice_args, outputs, static_shape=(2, 5, 5), **kwarg
         )