def _sparse_mm(context, node):
    """Convert a sparse matrix multiply into ``custom_torch_sparse_matmul``.

    Exactly two operands are fetched; both are flagged as sparse for the
    custom op. The op is named after the torch node and registered with the
    conversion context.
    """
    # NOTE(review): both operands are unconditionally marked sparse; confirm
    # this matches every torch call site routed here.
    operands = _get_inputs(context, node, expected=2)
    lhs, rhs = operands[0], operands[1]
    context.add(
        mb.custom_torch_sparse_matmul(
            x=lhs,
            y=rhs,
            x_is_sparse=True,
            y_is_sparse=True,
            name=node.name,
        )
    )
def constant_pad_nd(context, node):
    """Convert ``constant_pad_nd(x, pad, value)`` into a constant-mode MIL ``pad``.

    ``pad`` uses torch's layout: (before, after) pairs starting from the
    LAST dimension. The pairs are reversed into first-dimension-first order,
    and zero pairs are prepended for the leading dimensions that torch does
    not mention, giving one pair per dimension of ``x``.

    The fill value comes from the third input. It was previously ignored
    (always 0), which mis-converted ``F.pad(..., value=v)`` for ``v != 0``.
    """
    inputs = _get_inputs(context, node, expected=3)
    x = inputs[0]

    # Reverse the (before, after) pairs: torch lists the last dimension first.
    torch_pad = inputs[1].val.reshape((-1, 2))[::-1].reshape(-1).tolist()
    # Dimensions torch leaves out receive zero padding on both sides.
    mil_pad = [0] * (2 * len(x.shape) - len(torch_pad)) + torch_pad

    # .val is None when the fill value is not a compile-time constant; keep the
    # historical 0.0 in that case rather than failing the conversion.
    fill = inputs[2].val
    constant_val = 0.0 if fill is None else float(fill)

    padded = mb.pad(
        x=x,
        pad=np.array(mil_pad),
        mode="constant",
        constant_val=constant_val,
        name=node.name,
    )
    context.add(padded)
def cosine_similarity(context, node):
    """Convert ``cosine_similarity(x1, x2, dim, eps)``.

    Emits ``sum(x1 * x2, dim) / max(||x1|| * ||x2||, |eps|)``, with norms
    taken over ``dim`` via reduce_sum + sqrt.

    The denominator multiplies the two norms, not the two squared norms.
    ``||x1||^2 * ||x2||^2`` overflows much earlier in reduced precision: with
    fp16 compute (max ~65504) it becomes inf once ``||x1|| * ||x2|| > ~256``,
    which silently drives the similarity to 0. For non-negative squared norms
    ``a``, ``b``:

        sqrt(max(a * b, eps**2)) == max(sqrt(a) * sqrt(b), |eps|)

    so the value matches the previous formulation exactly in exact arithmetic.
    """
    inputs = _get_inputs(context, node, expected=4)
    x1, x2 = inputs[0], inputs[1]
    dim = inputs[-2].val
    eps = inputs[-1].val
    axes = [dim]

    dot = mb.reduce_sum(x=mb.mul(x=x1, y=x2), axes=axes)
    norm_x1 = mb.sqrt(x=mb.reduce_sum(x=mb.mul(x=x1, y=x1), axes=axes))
    norm_x2 = mb.sqrt(x=mb.reduce_sum(x=mb.mul(x=x2, y=x2), axes=axes))

    # abs() keeps exact parity with the old ``eps * eps`` form for any eps sign.
    denom = mb.maximum(x=mb.mul(x=norm_x1, y=norm_x2), y=abs(eps))
    cs = mb.real_div(x=dot, y=denom, name=node.name)
    context.add(cs)
def silu(context, node):
    """Convert SiLU (swish) as the element-wise product ``x * sigmoid(x)``.

    An unnamed ``sigmoid`` op is built first; the final ``mul`` carries the
    torch node's name and is registered with the context.
    """
    operand = _get_inputs(context, node, expected=1)[0]
    gate = mb.sigmoid(x=operand)
    context.add(mb.mul(x=operand, y=gate, name=node.name))
def type_as(context, node):
    """Convert ``Tensor.type_as(other)``: give ``inputs[0]`` the dtype of ``inputs[1]``.

    The previous version always cast to ``'fp32'`` and ignored ``other``, so
    int/bool/fp16 targets were converted to the wrong type.

    Raises:
        NotImplementedError: if ``other``'s dtype has no MIL dtype string.
    """
    # Function-scope import: the MIL type helpers come from the same package as mb.
    from coremltools.converters.mil.mil import types

    inputs = _get_inputs(context, node)
    x, other = inputs[0], inputs[1]

    if x.dtype == other.dtype:
        # Already the requested type; avoid emitting a redundant cast.
        result = mb.identity(x=x, name=node.name)
    else:
        target = types.builtin_to_string(other.dtype)
        if target is None:
            raise NotImplementedError(
                "type_as: unsupported target dtype {}".format(other.dtype)
            )
        result = mb.cast(x=x, dtype=target, name=node.name)

    context.add(result, node.name)