def test_minmax_grad(tag):
    """Return the MinimumGrad/MaximumGrad test function registered under ``tag``.

    Each nested function below is registered into a ``FnDict`` through the
    ``@fns`` decorator, and only the one selected by ``tag`` is returned.
    Every variant passes ``(x, y, dout)`` to a gradient primitive and reads
    element 0, element 1, or both elements of its output via ``tuple_getitem``.

    NOTE(review): the registry key is presumably the nested function's name
    (e.g. ``"before_11"``), so those names must stay as they are -- confirm
    against ``FnDict`` and the callers that request these tags.
    """
    fns = FnDict()

    # One instance of each primitive, shared by the closures that follow.
    min_grad = G.MinimumGrad()
    max_grad = G.MaximumGrad()

    # ---- MinimumGrad variants ----------------------------------------------
    @fns
    def before_11(x, y, dout):
        # Consume only output 0 of MinimumGrad.
        return tuple_getitem(min_grad(x, y, dout), 0)

    @fns
    def before_12(x, y, dout):
        # Consume only output 1 of MinimumGrad.
        return tuple_getitem(min_grad(x, y, dout), 1)

    @fns
    def before_2(x, y, dout):
        # Consume outputs 0 and 1 from a single MinimumGrad call.
        # NOTE(review): explicit tuple_getitem (not tuple unpacking) is kept
        # deliberately; the resulting graph shape presumably matters to the
        # tests that consume it -- confirm before changing.
        grads = min_grad(x, y, dout)
        return tuple_getitem(grads, 0), tuple_getitem(grads, 1)

    # ---- MaximumGrad variants ----------------------------------------------
    @fns
    def before_31(x, y, dout):
        # Consume only output 0 of MaximumGrad.
        return tuple_getitem(max_grad(x, y, dout), 0)

    @fns
    def before_32(x, y, dout):
        # Consume only output 1 of MaximumGrad.
        return tuple_getitem(max_grad(x, y, dout), 1)

    @fns
    def before_4(x, y, dout):
        # Consume outputs 0 and 1 from a single MaximumGrad call.
        grads = max_grad(x, y, dout)
        return tuple_getitem(grads, 0), tuple_getitem(grads, 1)

    return fns[tag]
def __init__(self):
    """Initialise the base class, then attach a ``G.MinimumGrad()`` instance.

    The instance is stored on ``self.minimum_grad``.
    """
    # Base-class initialisation happens first, before any attribute is set.
    super().__init__()
    grad_op = G.MinimumGrad()
    self.minimum_grad = grad_op
'desc_inputs': [[2, 3, 3, 5], [3, 3, 5]], 'desc_bprop': [[2, 3, 3, 5]]}), ('Maximum', { 'block': P.Maximum(), 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5]], 'desc_bprop': [[2, 3, 3, 5]]}), ('Maximum_0', { 'block': P.Maximum(), 'desc_inputs': [[3, 5], [2, 3, 3, 5]], 'desc_bprop': [[2, 3, 3, 5]]}), ('MaximumGrad', { 'block': G.MaximumGrad(), 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5], [2, 3, 3, 5]], 'skip': ['backward']}), ('MinimumGrad', { 'block': G.MinimumGrad(), 'desc_inputs': [[2, 3, 3, 5], [2, 3, 3, 5], [2, 3, 3, 5]], 'skip': ['backward']}), ('StridedSlice', { 'block': P.StridedSlice(), 'desc_const': [(0, 1, 2, 1), (2, 3, 3, 4), (1, 1, 1, 1)], 'desc_inputs': [[2, 3, 3, 5]], 'desc_bprop': [[2, 2, 1, 3]]}), ('Slice_1', { 'block': P.Slice(), 'desc_const': [(0, 1, 2, 1), (1, 1, 1, 2)], 'desc_inputs': [[2, 3, 3, 5]], 'desc_bprop': [[1, 1, 1, 2]]}),