def verify_expand_like(in_shape, out_shape, axis, exclude):
    """Check ``sym.expand_like`` against a NumPy reference forward/backward.

    Args:
        in_shape: shape given to the input ``x`` that gets expanded.
        out_shape: shape given to ``y``; the reference repeats ``x``
            ``out_shape[i]`` times along every expanded output axis ``i``.
        axis: output axes (negative values allowed) along which ``x`` is
            expanded, or, when ``exclude`` is true, the axes that are NOT
            expanded.
        exclude: if true, expand along every output axis not in ``axis``.
    """
    x = sym.Variable("x")
    y = sym.Variable("y")
    z = sym.expand_like(x, y, axis=axis, exclude=exclude)
    odim = len(out_shape)

    def real_axes():
        # Normalize negative axes, then resolve ``exclude`` into an explicit
        # list of expanded output axes. The set difference is sorted because
        # set iteration order is not guaranteed to be ascending, and
        # expand_dims below must insert axes in ascending order.
        axes = sorted(i if i >= 0 else i + odim for i in axis)
        if exclude:
            axes = sorted(set(range(odim)) - set(axes))
        return axes

    def forward(x, y):
        axes = real_axes()
        # Insert a size-1 axis at each expanded position...
        for i in axes:
            x = np.expand_dims(x, i)
        # ...then tile it out to the matching extent of the output shape.
        for i in axes:
            x = np.concatenate([x] * out_shape[i], axis=i)
        return x

    def backward(head_grads, x, y):
        # The gradient w.r.t. x sums the upstream gradient over the repeated
        # axes; the reference treats y as shape-only, so its gradient is zero.
        return [np.sum(head_grads, axis=tuple(real_axes())), np.zeros_like(y)]

    dtype = "float32"
    inputs = [('x', in_shape, x), ('y', out_shape, y)]
    helper(z, inputs, dtype, forward, backward, need_input=False)
def verify_expand_like(in_shape, out_shape, axis, exclude):
    """Check ``sym.expand_like`` against a NumPy reference forward/backward.

    Args:
        in_shape: shape given to the input ``x`` that gets expanded.
        out_shape: shape given to ``y``; the reference repeats ``x``
            ``out_shape[i]`` times along every expanded output axis ``i``.
        axis: output axes (negative values allowed) along which ``x`` is
            expanded, or, when ``exclude`` is true, the axes that are NOT
            expanded.
        exclude: if true, expand along every output axis not in ``axis``.
    """
    x = sym.Variable("x")
    y = sym.Variable("y")
    z = sym.expand_like(x, y, axis=axis, exclude=exclude)
    odim = len(out_shape)

    def real_axes():
        # Normalize negative axes, then resolve ``exclude`` into an explicit
        # list of expanded output axes. The set difference is sorted because
        # set iteration order is not guaranteed to be ascending, and
        # expand_dims below must insert axes in ascending order.
        axes = sorted(i if i >= 0 else i + odim for i in axis)
        if exclude:
            axes = sorted(set(range(odim)) - set(axes))
        return axes

    def forward(x, y):
        # Same rank: the reference is plain NumPy broadcasting.
        if len(x.shape) == len(y.shape):
            return np.broadcast_to(x, y.shape)
        # A (1,) input is treated as a 0-d scalar before axes are inserted.
        if x.shape == (1,) and len(y.shape) == odim:
            x = np.reshape(x, ())
        axes = real_axes()
        for i in axes:
            x = np.expand_dims(x, i)
        for i in axes:
            x = np.concatenate([x] * out_shape[i], axis=i)
        return x

    def backward(head_grads, x, y):
        # When ranks match, forward broadcast x in place, so the reduction
        # keeps the size-1 axes to line up with x's shape.
        keepdims = len(x.shape) == len(y.shape)
        grad_x = np.sum(head_grads, axis=tuple(real_axes()), keepdims=keepdims)
        if x.shape == (1,) and grad_x.shape == ():
            # Forward treated a (1,) input as a scalar; give the reference
            # gradient back the input's (1,) shape.
            grad_x = np.reshape(grad_x, x.shape)
        # y is shape-only in the reference, so its gradient is zero.
        return [grad_x, np.zeros_like(y)]

    shape = {'x': in_shape, 'y': out_shape}
    check_function(z, forward, backward, shape=shape)
def verify_expand_like(in_shape, out_shape, axis, exclude):
    """Check ``sym.expand_like`` against a NumPy reference forward/backward.

    Args:
        in_shape: shape given to the input ``x`` that gets expanded.
        out_shape: shape given to ``y``; the reference repeats ``x``
            ``out_shape[i]`` times along every expanded output axis ``i``.
        axis: output axes (negative values allowed) along which ``x`` is
            expanded, or, when ``exclude`` is true, the axes that are NOT
            expanded.
        exclude: if true, expand along every output axis not in ``axis``.
    """
    x = sym.Variable("x")
    y = sym.Variable("y")
    z = sym.expand_like(x, y, axis=axis, exclude=exclude)
    odim = len(out_shape)

    def real_axes():
        # Normalize negative axes, then resolve ``exclude`` into an explicit
        # list of expanded output axes. The set difference is sorted because
        # set iteration order is not guaranteed to be ascending, and
        # expand_dims below must insert axes in ascending order.
        axes = sorted(i if i >= 0 else i + odim for i in axis)
        if exclude:
            axes = sorted(set(range(odim)) - set(axes))
        return axes

    def forward(x, y):
        # Same rank: the reference is plain NumPy broadcasting.
        if len(x.shape) == len(y.shape):
            return np.broadcast_to(x, y.shape)
        # A (1,) input is treated as a 0-d scalar before axes are inserted.
        if x.shape == (1,) and len(y.shape) == odim:
            x = np.reshape(x, ())
        axes = real_axes()
        for i in axes:
            x = np.expand_dims(x, i)
        for i in axes:
            x = np.concatenate([x] * out_shape[i], axis=i)
        return x

    def backward(head_grads, x, y):
        # When ranks match, forward broadcast x in place, so the reduction
        # keeps the size-1 axes to line up with x's shape.
        keepdims = len(x.shape) == len(y.shape)
        grad_x = np.sum(head_grads, axis=tuple(real_axes()), keepdims=keepdims)
        if x.shape == (1,) and grad_x.shape == ():
            # Forward treated a (1,) input as a scalar; give the reference
            # gradient back the input's (1,) shape.
            grad_x = np.reshape(grad_x, x.shape)
        # y is shape-only in the reference, so its gradient is zero.
        return [grad_x, np.zeros_like(y)]

    shape = {'x': in_shape, 'y': out_shape}
    check_function(z, forward, backward, shape=shape)