Example #1
0
    def call(self, space, args_w):
        """Apply this binary ufunc to two operands, with an optional output.

        ``args_w`` holds either ``[w_lhs, w_rhs]`` or ``[w_lhs, w_rhs, w_out]``.
        Both operands are converted to arrays; their dtypes are combined with
        ``find_binop_result_dtype`` into the calculation dtype (or the dtype of
        ``out`` when an output array is given) and checked against this
        ufunc's ``int_only`` / ``allow_bool`` / ``allow_complex`` flags.  When
        both operands are scalars the result is computed directly and
        returned (or stored into ``out``); otherwise the shapes are broadcast
        together and against ``out`` and the work is delegated to
        ``loop.call2``.

        String/flexible dtypes get special handling for comparison ufuncs:
        a string compared with a non-string (and no ``out``) returns a single
        wrapped boolean that is true only for ``not_equal``; ``equal`` /
        ``not_equal`` on unequal flexible dtypes do the same; any other
        comparison involving a flexible dtype returns ``space.w_NotImplemented``.

        Raises an app-level TypeError when a non-comparison ufunc receives a
        flexible dtype, when the input types are not supported by this ufunc,
        or when ``out`` is given but is not a ``W_NDimArray``.
        """
        if len(args_w) > 2:
            [w_lhs, w_rhs, w_out] = args_w
        else:
            [w_lhs, w_rhs] = args_w
            w_out = None
        w_lhs = convert_to_array(space, w_lhs)
        w_rhs = convert_to_array(space, w_rhs)
        w_ldtype = w_lhs.get_dtype()
        w_rdtype = w_rhs.get_dtype()
        if w_ldtype.is_str() and w_rdtype.is_str() and \
                self.comparison_func:
            # string vs string comparisons take the normal path below
            pass
        elif (w_ldtype.is_str() or w_rdtype.is_str()) and \
                self.comparison_func and space.is_none(w_out):
            # A string never compares equal to a non-string, so only
            # not_equal is true (mirrors the flexible-dtype case below).
            # space.is_none() also treats an explicit app-level out=None
            # like an omitted out, matching how w_out is handled later.
            return space.wrap(self.name == 'not_equal')
        elif w_ldtype.is_flexible() or w_rdtype.is_flexible():
            if self.comparison_func:
                if self.name == 'equal' or self.name == 'not_equal':
                    res = w_ldtype.eq(space, w_rdtype)
                    if not res:
                        # incomparable dtypes: only not_equal holds
                        return space.wrap(self.name == 'not_equal')
                else:
                    return space.w_NotImplemented
            else:
                # report the operands in call order: lhs first, then rhs
                raise oefmt(space.w_TypeError,
                            'unsupported operand dtypes %s and %s for "%s"',
                            w_ldtype.get_name(), w_rdtype.get_name(),
                            self.name)

        if self.are_common_types(w_ldtype, w_rdtype):
            # an array mixed with a scalar of a common kind keeps the
            # array's dtype
            if not w_lhs.is_scalar() and w_rhs.is_scalar():
                w_rdtype = w_ldtype
            elif w_lhs.is_scalar() and not w_rhs.is_scalar():
                w_ldtype = w_rdtype
        calc_dtype = find_binop_result_dtype(space,
            w_ldtype, w_rdtype,
            promote_to_float=self.promote_to_float,
            promote_bools=self.promote_bools)
        if (self.int_only and (not w_ldtype.is_int() or
                               not w_rdtype.is_int() or
                               not calc_dtype.is_int()) or
                not self.allow_bool and (w_ldtype.is_bool() or
                                         w_rdtype.is_bool()) or
                not self.allow_complex and (w_ldtype.is_complex() or
                                            w_rdtype.is_complex())):
            raise oefmt(space.w_TypeError,
                "ufunc '%s' not supported for the input types", self.name)
        if space.is_none(w_out):
            out = None
        elif not isinstance(w_out, W_NDimArray):
            raise oefmt(space.w_TypeError, 'output must be an array')
        else:
            out = w_out
            # an explicit output array dictates the calculation dtype
            calc_dtype = out.get_dtype()
        if self.comparison_func:
            res_dtype = descriptor.get_dtype_cache(space).w_booldtype
        else:
            res_dtype = calc_dtype
        if w_lhs.is_scalar() and w_rhs.is_scalar():
            arr = self.func(calc_dtype,
                w_lhs.get_scalar_value().convert_to(space, calc_dtype),
                w_rhs.get_scalar_value().convert_to(space, calc_dtype)
            )
            if isinstance(out, W_NDimArray):
                if out.is_scalar():
                    out.set_scalar_value(arr)
                else:
                    out.fill(space, arr)
            else:
                out = arr
            return out
        new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
        new_shape = shape_agreement(space, new_shape, out, broadcast_down=False)
        return loop.call2(space, new_shape, self.func, calc_dtype,
                          res_dtype, w_lhs, w_rhs, out)
Example #2
0
    def call(self, space, args_w):
        """Apply this binary ufunc to two operands, with an optional output.

        ``args_w`` holds either ``[w_lhs, w_rhs]`` or ``[w_lhs, w_rhs, w_out]``.
        Both operands are converted to arrays; their dtypes are combined with
        ``find_binop_result_dtype`` into the calculation dtype (or the dtype of
        ``out`` when an output array is given) and checked against this
        ufunc's ``int_only`` / ``allow_bool`` / ``allow_complex`` flags.  When
        both operands are scalars the result is computed directly and
        returned (or stored into ``out``); otherwise the shapes are broadcast
        together and against ``out`` and the work is delegated to
        ``loop.call2``.

        For comparison ufuncs, a string compared with a non-string (and no
        ``out``) returns a single wrapped boolean that is true only for
        ``not_equal``.

        Raises an app-level TypeError for other flexible-dtype operands, for
        input types this ufunc does not support, and when ``out`` is given
        but is not a ``W_NDimArray``.
        """
        if len(args_w) > 2:
            [w_lhs, w_rhs, w_out] = args_w
        else:
            [w_lhs, w_rhs] = args_w
            w_out = None
        w_lhs = convert_to_array(space, w_lhs)
        w_rhs = convert_to_array(space, w_rhs)
        w_ldtype = w_lhs.get_dtype()
        w_rdtype = w_rhs.get_dtype()
        if w_ldtype.is_str_type() and w_rdtype.is_str_type() and \
           self.comparison_func:
            # string vs string comparisons take the normal path below
            pass
        elif (w_ldtype.is_str_type() or w_rdtype.is_str_type()) and \
            self.comparison_func and space.is_none(w_out):
            # A string never compares equal to a non-string, so only
            # not_equal is true. space.is_none() also treats an explicit
            # app-level out=None like an omitted out, matching how w_out
            # is handled later.
            return space.wrap(self.name == 'not_equal')
        elif (w_ldtype.is_flexible_type() or
                w_rdtype.is_flexible_type()):
            # report the operands in call order: lhs first, then rhs
            raise OperationError(space.w_TypeError, space.wrap(
                 'unsupported operand dtypes %s and %s for "%s"' %
                 (w_ldtype.get_name(), w_rdtype.get_name(),
                  self.name)))

        if self.are_common_types(w_ldtype, w_rdtype):
            # an array mixed with a scalar of a common kind keeps the
            # array's dtype
            if not w_lhs.is_scalar() and w_rhs.is_scalar():
                w_rdtype = w_ldtype
            elif w_lhs.is_scalar() and not w_rhs.is_scalar():
                w_ldtype = w_rdtype
        calc_dtype = find_binop_result_dtype(space,
            w_ldtype, w_rdtype,
            promote_to_float=self.promote_to_float,
            promote_bools=self.promote_bools)
        # int-only ufuncs must also reject a non-integer promoted dtype
        # (e.g. two integer inputs whose common type is not an integer)
        if (self.int_only and (not w_ldtype.is_int_type() or
                               not w_rdtype.is_int_type() or
                               not calc_dtype.is_int_type()) or
                not self.allow_bool and (w_ldtype.is_bool_type() or
                                         w_rdtype.is_bool_type()) or
                not self.allow_complex and (w_ldtype.is_complex_type() or
                                            w_rdtype.is_complex_type())):
            raise OperationError(space.w_TypeError, space.wrap("Unsupported types"))
        if space.is_none(w_out):
            out = None
        elif not isinstance(w_out, W_NDimArray):
            raise OperationError(space.w_TypeError, space.wrap(
                    'output must be an array'))
        else:
            out = w_out
            # an explicit output array dictates the calculation dtype
            calc_dtype = out.get_dtype()
        if self.comparison_func:
            res_dtype = interp_dtype.get_dtype_cache(space).w_booldtype
        else:
            res_dtype = calc_dtype
        if w_lhs.is_scalar() and w_rhs.is_scalar():
            arr = self.func(calc_dtype,
                w_lhs.get_scalar_value().convert_to(calc_dtype),
                w_rhs.get_scalar_value().convert_to(calc_dtype)
            )
            if isinstance(out, W_NDimArray):
                if out.is_scalar():
                    out.set_scalar_value(arr)
                else:
                    out.fill(arr)
            else:
                out = arr
            return out
        new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
        new_shape = shape_agreement(space, new_shape, out, broadcast_down=False)
        return loop.call2(space, new_shape, self.func, calc_dtype,
                          res_dtype, w_lhs, w_rhs, out)
Example #3
0
    def call(self, space, args_w):
        """Apply this binary ufunc to two operands, with an optional output.

        ``args_w`` holds either ``[w_lhs, w_rhs]`` or ``[w_lhs, w_rhs, w_out]``.
        Both operands are converted to arrays; their dtypes are combined with
        ``find_binop_result_dtype`` into the calculation dtype (or the dtype of
        ``out`` when an output array is given) and checked against this
        ufunc's ``int_only`` / ``allow_bool`` / ``allow_complex`` flags.  When
        both operands are scalars the result is computed directly and
        returned (or stored into ``out``); otherwise the shapes are broadcast
        together and against ``out`` and the work is delegated to
        ``loop.call2``.

        For comparison ufuncs, a string compared with a non-string (and no
        ``out``) returns a single wrapped boolean that is true only for
        ``not_equal``.

        Raises an app-level TypeError for other flexible-dtype operands, for
        input types this ufunc does not support, and when ``out`` is given
        but is not a ``W_NDimArray``.
        """
        if len(args_w) > 2:
            [w_lhs, w_rhs, w_out] = args_w
        else:
            [w_lhs, w_rhs] = args_w
            w_out = None
        w_lhs = convert_to_array(space, w_lhs)
        w_rhs = convert_to_array(space, w_rhs)
        w_ldtype = w_lhs.get_dtype()
        w_rdtype = w_rhs.get_dtype()
        if w_ldtype.is_str_type() and w_rdtype.is_str_type() and \
           self.comparison_func:
            # string vs string comparisons take the normal path below
            pass
        elif (w_ldtype.is_str_type() or w_rdtype.is_str_type()) and \
            self.comparison_func and space.is_none(w_out):
            # A string never compares equal to a non-string, so only
            # not_equal is true. space.is_none() also treats an explicit
            # app-level out=None like an omitted out, matching how w_out
            # is handled later.
            return space.wrap(self.name == 'not_equal')
        elif (w_ldtype.is_flexible_type() or
                w_rdtype.is_flexible_type()):
            # report the operands in call order: lhs first, then rhs
            raise OperationError(space.w_TypeError, space.wrap(
                 'unsupported operand dtypes %s and %s for "%s"' %
                 (w_ldtype.get_name(), w_rdtype.get_name(),
                  self.name)))

        if self.are_common_types(w_ldtype, w_rdtype):
            # an array mixed with a scalar of a common kind keeps the
            # array's dtype
            if not w_lhs.is_scalar() and w_rhs.is_scalar():
                w_rdtype = w_ldtype
            elif w_lhs.is_scalar() and not w_rhs.is_scalar():
                w_ldtype = w_rdtype
        calc_dtype = find_binop_result_dtype(
            space,
            w_ldtype,
            w_rdtype,
            promote_to_float=self.promote_to_float,
            promote_bools=self.promote_bools)
        # int-only ufuncs must also reject a non-integer promoted dtype
        # (e.g. two integer inputs whose common type is not an integer)
        if (self.int_only and
            (not w_ldtype.is_int_type() or not w_rdtype.is_int_type() or
             not calc_dtype.is_int_type())
                or not self.allow_bool and
            (w_ldtype.is_bool_type() or w_rdtype.is_bool_type())
                or not self.allow_complex and
            (w_ldtype.is_complex_type() or w_rdtype.is_complex_type())):
            raise OperationError(space.w_TypeError,
                                 space.wrap("Unsupported types"))
        if space.is_none(w_out):
            out = None
        elif not isinstance(w_out, W_NDimArray):
            raise OperationError(space.w_TypeError,
                                 space.wrap('output must be an array'))
        else:
            out = w_out
            # an explicit output array dictates the calculation dtype
            calc_dtype = out.get_dtype()
        if self.comparison_func:
            res_dtype = interp_dtype.get_dtype_cache(space).w_booldtype
        else:
            res_dtype = calc_dtype
        if w_lhs.is_scalar() and w_rhs.is_scalar():
            arr = self.func(calc_dtype,
                            w_lhs.get_scalar_value().convert_to(calc_dtype),
                            w_rhs.get_scalar_value().convert_to(calc_dtype))
            if isinstance(out, W_NDimArray):
                if out.is_scalar():
                    out.set_scalar_value(arr)
                else:
                    out.fill(arr)
            else:
                out = arr
            return out
        new_shape = shape_agreement(space, w_lhs.get_shape(), w_rhs)
        new_shape = shape_agreement(space,
                                    new_shape,
                                    out,
                                    broadcast_down=False)
        return loop.call2(space, new_shape, self.func, calc_dtype, res_dtype,
                          w_lhs, w_rhs, out)