Ejemplo n.º 1
0
  def test_check_jaxpr_eqn_mismatch(self):
    """Retyping an equation's output var makes core.check_jaxpr fail.

    The jaxpr for ``sin(x) + cos(x)`` binds ``b`` as the output of ``sin`` and
    later consumes it in ``add``. Each case below overwrites the aval recorded
    at that binding site and checks that the raised ``JaxprTypeError`` message
    names the variable, both types, and the offending equation.
    """
    def f(x):
      return jnp.sin(x) + jnp.cos(x)

    # jaxpr is:
    #
    # { lambda  ; a.
    #   let b = sin a
    #       c = cos a
    #       d = add b c
    #   in (d,) }
    #
    # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

    def expect_type_error(bad_aval, pattern):
      # Trace a fresh jaxpr per case so one case's mutation cannot leak into
      # the next.
      traced = make_jaxpr(f)(jnp.float32(1.)).jaxpr
      traced.eqns[0].outvars[0].aval = bad_aval
      with self.assertRaisesRegex(core.JaxprTypeError, pattern):
        core.check_jaxpr(traced)

    # int, not float!
    expect_type_error(
        make_shaped_array(jnp.int32(2)),
        r"Variable 'b' inconsistently typed as f32\[\], "
        r"bound as i32\[\]\n\nin equation:\n\nb:i32\[\] = sin a")

    # Same dtype as the consumer expects, but a non-scalar shape.
    expect_type_error(
        make_shaped_array(np.ones((2, 3), dtype=jnp.float32)),
        r"Variable 'b' inconsistently typed as f32\[\], "
        r"bound as f32\[2,3\]\n\nin equation:\n\nb:f32\[2,3\] = sin a")
Ejemplo n.º 2
0
    def test_check_jaxpr_eqn_mismatch(self):
        """check_jaxpr rejects an equation whose output var was retyped.

        Builds the jaxpr for ``sin(x) + cos(x)``, overwrites the aval of the
        ``sin`` output variable, and checks that ``core.check_jaxpr`` raises a
        ``JaxprTypeError`` whose message names the variable, both avals, and
        the offending equation.
        """
        def f(x):
            return jnp.sin(x) + jnp.cos(x)

        def new_jaxpr():
            return make_jaxpr(f)(1.).jaxpr

        # jaxpr is:
        #
        # { lambda  ; a.
        #   let b = sin a
        #       c = cos a
        #       d = add b c
        #   in (d,) }
        #
        # NB: eqns[0].outvars[0] and eqns[2].invars[0] are both 'b'

        # The parentheses around each aval are literal text in the message
        # (e.g. "ShapedArray(float32[])"), so they must be escaped: unescaped,
        # "(.*)" is merely a capture group and the parens go unchecked.
        expected = (
            r"Variable '.' inconsistently typed as ShapedArray\(.*\), "
            r"bound as ShapedArray\(.*\)\n\nin equation:\n\n  . = sin .")

        def assert_mismatch(bad_aval):
            """Retype the ``sin`` outvar of a fresh jaxpr and expect an error."""
            # Fresh jaxpr per case so earlier mutations cannot affect later ones.
            jaxpr = new_jaxpr()
            jaxpr.eqns[0].outvars[0].aval = bad_aval
            self.assertRaisesRegex(
                core.JaxprTypeError,
                expected,
                lambda: core.check_jaxpr(jaxpr))

        assert_mismatch(make_shaped_array(2))  # int, not float!
        # Wrong shape: (2, 3) instead of a scalar.
        assert_mismatch(make_shaped_array(np.ones((2, 3))))