def flatten(g, input, start_dim, end_dim):
    """Emit graph ops that flatten ``input`` over ``[start_dim, end_dim]``.

    ``start_dim`` and ``end_dim`` are read as constant integers. When the
    requested range yields a 2-D output, a single ONNX ``Flatten`` node is
    emitted; every other case is delegated to ``sym_opset9.flatten``.

    Args:
        g: Graph object whose ``op`` method creates nodes.
        input: Value to flatten; its rank is read via ``input.type().dim()``.
        start_dim: First dimension of the flattened range (read as a constant).
        end_dim: Last dimension of the flattened range (read as a constant);
            a negative value is offset by the input rank.

    Returns:
        The ``Flatten`` node output, passed through ``_cast_to_type`` when the
        input's scalar type is known, or the result of ``sym_opset9.flatten``.
    """
    first = sym_help._get_const(start_dim, 'i', 'start_dim')
    last = sym_help._get_const(end_dim, 'i', 'end_dim')

    rank = input.type().dim()
    if last < 0:
        last = rank + last

    # ONNX Flatten is only used for the two layouts whose output is 2-D:
    # keep dim 0 and merge the rest, or merge everything but the last dim.
    if first == 1 and last == rank - 1:
        axis = first
    elif first == 0 and last == rank - 2:
        axis = last + 1
    else:
        return sym_opset9.flatten(g, input, start_dim, end_dim)

    if not _try_get_scalar_type(input):
        return g.op("Flatten", input, axis_i=axis)

    # NOTE(review): _try_cast_integer_to_float presumably casts integer inputs
    # to float and reports the original type so it can be restored - confirm.
    old_type, input = _try_cast_integer_to_float(g, input)
    flattened = g.op("Flatten", input, axis_i=axis)
    return _cast_to_type(g, flattened, old_type)
def mm(g, self, other):
    """Emit a ``Gemm`` node multiplying ``self`` by ``other``.

    ``Gemm`` takes a third operand C. A one-element zero constant is supplied
    for it and ``beta`` is set to 0.0, so C exists only to satisfy the
    operator's signature.

    Args:
        g: Graph object whose ``op`` method creates nodes.
        self: Left operand.
        other: Right operand.

    Returns:
        The ``Gemm`` node output, passed through ``_cast_to_type`` when the
        scalar type of ``self`` is known.

    Raises:
        errors.SymbolicValueError: If
            ``symbolic_helper._try_get_scalar_type(self, other)`` returns
            ``None``.
    """
    dtype_name = symbolic_helper._try_get_scalar_type(self, other)
    if dtype_name is None:
        raise errors.SymbolicValueError(
            "mm can only operate on tensors with known types", self
        )

    # Dummy C tensor: needed for the Gemm API only; its value is ignored
    # because beta = 0.
    zero_constant = g.op(
        "Constant",
        value_t=torch.tensor(
            [0], dtype=_type_utils.JitScalarType.from_name(dtype_name).dtype()
        ),
    )

    if not symbolic_helper._try_get_scalar_type(self):
        return g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)

    # NOTE(review): the cast helper presumably converts integer operands to
    # float and returns the original type for the final cast - confirm.
    old_type, self, other, zero_constant = _try_cast_integer_to_float(
        g, self, other, zero_constant
    )
    product = g.op("Gemm", self, other, zero_constant, beta_f=0.0, alpha_f=1.0)
    return _cast_to_type(g, product, old_type)
def mm(g, self, other):
    """Emit a ``Gemm`` node multiplying ``self`` by ``other``.

    ``Gemm`` takes a third operand C. A one-element zero constant is supplied
    for it and ``beta`` is set to 0.0, so C exists only to satisfy the
    operator's signature.

    Args:
        g: Graph object providing ``op`` and ``constant``.
        self: Left operand.
        other: Right operand.

    Returns:
        The ``Gemm`` node output, passed through ``_cast_to_type`` when the
        scalar type of ``self`` is known.

    Raises:
        ValueError: If ``sym_help._try_get_scalar_type(self, other)`` returns
            ``None``, i.e. no scalar type is available to build the constant.
    """
    scalar_type = sym_help._try_get_scalar_type(self, other)
    if scalar_type is None:
        # Fail with an explicit message instead of an AttributeError from
        # calling .lower() on None.
        raise ValueError("mm can only operate on tensors with known types")

    # Dummy C tensor: needed for the Gemm API only; its value is ignored
    # because beta = 0.
    ty = scalar_type.lower()
    C = g.constant(0, [1], ty)

    if not _try_get_scalar_type(self):
        return g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)

    # NOTE(review): the cast helper presumably converts integer operands to
    # float and returns the original type for the final cast - confirm.
    old_type, self, other, C = _try_cast_integer_to_float(g, self, other, C)
    product = g.op("Gemm", self, other, C, beta_f=0.0, alpha_f=1.0)
    return _cast_to_type(g, product, old_type)
def addmm(g, self, mat1, mat2, beta, alpha):
    """Emit a ``Gemm`` node for ``addmm``.

    ``mat1`` and ``mat2`` are passed as the two multiplied operands and
    ``self`` as the third (C) operand. ``beta`` and ``alpha`` are converted
    with ``sym_help._scalar`` and attached as the ``beta`` / ``alpha``
    attributes of the node.

    Args:
        g: Graph object whose ``op`` method creates nodes.
        self: Value used as the C operand of ``Gemm``.
        mat1: First multiplied operand.
        mat2: Second multiplied operand.
        beta: Scalar value for the ``beta`` attribute.
        alpha: Scalar value for the ``alpha`` attribute.

    Returns:
        The ``Gemm`` node output, passed through ``_cast_to_type`` when the
        scalar type of ``self`` is known.
    """
    if not _try_get_scalar_type(self):
        return g.op(
            "Gemm",
            mat1,
            mat2,
            self,
            beta_f=sym_help._scalar(beta),
            alpha_f=sym_help._scalar(alpha),
        )

    # NOTE(review): the cast helper presumably converts integer operands to
    # float and returns the original type for the final cast - confirm.
    old_type, self, mat1, mat2 = _try_cast_integer_to_float(g, self, mat1, mat2)
    gemm_out = g.op(
        "Gemm",
        mat1,
        mat2,
        self,
        beta_f=sym_help._scalar(beta),
        alpha_f=sym_help._scalar(alpha),
    )
    return _cast_to_type(g, gemm_out, old_type)
def prelu(g, self, weight):
    """Emit a ``PRelu`` node for ``self`` with slope ``weight``.

    When the rank of ``self`` is known and greater than 2, ``weight`` is
    first unsqueezed on axes ``1 .. rank - 2``.

    Args:
        g: Graph object whose ``op`` method creates nodes.
        self: Input value.
        weight: Slope value.

    Returns:
        The ``PRelu`` node output, passed through ``_cast_to_type`` when the
        scalar type of ``self`` is known.
    """
    rank = sym_help._get_tensor_rank(self)
    if rank is not None and rank > 2:
        # Presumably adds trailing singleton axes so a per-channel weight
        # broadcasts against the input - confirm against callers.
        extra_axes = list(range(1, rank - 1))
        weight = g.op("Unsqueeze", weight, axes_i=extra_axes)

    if not _try_get_scalar_type(self):
        return g.op("PRelu", self, weight)

    # NOTE(review): the cast helper presumably converts integer operands to
    # float and returns the original type for the final cast - confirm.
    old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
    activated = g.op("PRelu", self, weight)
    return _cast_to_type(g, activated, old_type)
def prelu(g, self, weight):
    """Emit a ``PRelu`` node for ``self`` with slope ``weight``.

    When ``self`` is a complete tensor with more than two dimensions,
    ``weight`` is first unsqueezed on axes ``1 .. rank - 2``.

    Args:
        g: Graph object whose ``op`` method creates nodes.
        self: Input value.
        weight: Slope value.

    Returns:
        The ``PRelu`` node output, passed through ``_cast_to_type`` when the
        scalar type of ``self`` is known.
    """
    # Sizes are only queried for complete tensors; otherwise treat the rank
    # as unknown and leave the weight untouched.
    sizes = self.type().sizes() if self.isCompleteTensor() else None
    rank = len(sizes) if sizes else 0
    if rank > 2:
        # Presumably adds trailing singleton axes so a per-channel weight
        # broadcasts against the input - confirm against callers.
        weight = g.op("Unsqueeze", weight, axes_i=list(range(1, rank - 1)))

    if not _try_get_scalar_type(self):
        return g.op("PRelu", self, weight)

    # NOTE(review): the cast helper presumably converts integer operands to
    # float and returns the original type for the final cast - confirm.
    old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
    activated = g.op("PRelu", self, weight)
    return _cast_to_type(g, activated, old_type)
def prelu(g, self, weight):
    """Emit a ``PRelu`` node for ``self`` with slope ``weight``.

    ``weight`` is reshaped first in two cases:

    * the rank of ``self`` is known and greater than 2: ``weight`` is
      unsqueezed on axes ``1 .. rank - 2``;
    * ``self`` has rank 0 and ``weight`` has sizes ``[1]``: axis 0 of
      ``weight`` is squeezed away.

    Args:
        g: Graph object whose ``op`` method creates nodes.
        self: Input value.
        weight: Slope value.

    Returns:
        The ``PRelu`` node output, passed through ``_cast_to_type`` when the
        scalar type of ``self`` is known.
    """
    rank = symbolic_helper._get_tensor_rank(self)
    weight_shape = symbolic_helper._get_tensor_sizes(weight)

    if rank is not None and rank > 2:
        extra_axes = list(range(1, rank - 1))
        weight = g.op("Unsqueeze", weight, axes_i=extra_axes)
    elif rank == 0 and weight_shape == [1]:
        # Both operands are scalars but the weight still carries one
        # dimension; drop it.
        weight = symbolic_helper._squeeze_helper(g, weight, [0])

    if not symbolic_helper._try_get_scalar_type(self):
        return g.op("PRelu", self, weight)

    # NOTE(review): the cast helper presumably converts integer operands to
    # float and returns the original type for the final cast - confirm.
    old_type, self, weight = _try_cast_integer_to_float(g, self, weight)
    activated = g.op("PRelu", self, weight)
    return _cast_to_type(g, activated, old_type)
def bmm(g, self, other):
    """Emit a ``MatMul`` node multiplying ``self`` by ``other``.

    Args:
        g: Graph object whose ``op`` method creates nodes.
        self: Left operand.
        other: Right operand.

    Returns:
        The ``MatMul`` node output, passed through ``_cast_to_type`` when the
        scalar type of ``self`` is known.
    """
    if not _try_get_scalar_type(self):
        return g.op("MatMul", self, other)

    # NOTE(review): the cast helper presumably converts integer operands to
    # float and returns the original type for the final cast - confirm.
    old_type, self, other = _try_cast_integer_to_float(g, self, other)
    product = g.op("MatMul", self, other)
    return _cast_to_type(g, product, old_type)