Beispiel #1
0
def divide_where(dividend: ComplexNumeric,
                 divisor: Union[ComplexNumeric, IntegralNumeric],
                 *,
                 where: Optional[BooleanNumeric] = None,
                 otherwise: Optional[ComplexNumeric] = None) -> ComplexNumeric:
    """
    Returns: `jnp.where(where, dividend / divisor, otherwise)`, but without evaluating
        `dividend / divisor` when `where` is false.  This prevents some exceptions.

    Wherever `where` is false, both operands are replaced by 1.0 before dividing, so
    masked-out lanes never divide by e.g. zero.  Because the discarded lanes are then
    finite, gradients taken through this function stay finite as well.  A single
    `jnp.where` around a raw division would still back-propagate NaN/inf from those
    lanes.

    Args:
        dividend: Numerator.
        divisor: Denominator.
        where: Boolean mask selecting the lanes where the quotient is wanted.  If None,
            this is a plain `jnp.true_divide(dividend, divisor)` and `otherwise` is
            ignored.
        otherwise: Value used in lanes where `where` is false.  Required whenever
            `where` is given.

    Raises:
        ValueError: If `where` is given but `otherwise` is None.
    """
    if where is None:
        return jnp.true_divide(dividend, divisor)
    if otherwise is None:
        # Fail before doing any work, with a message naming this function, instead of
        # letting the final jnp.where reject a missing third argument after dividing.
        raise ValueError(
            "divide_where: `otherwise` must be provided when `where` is given.")
    # Neutralize masked-out lanes so the division below is always well defined there.
    dividend = jnp.where(where, dividend, 1.0)
    divisor = jnp.where(where, divisor, 1.0)
    quotient = jnp.true_divide(dividend, divisor)
    return jnp.where(where, quotient, otherwise)
Beispiel #2
0
def _test_fn(hyperparams: ServerHyperParams, server_state: FFGBDistillServerState, batch: Batch):
    classifier_fn = hyperparams.get_classifier_fn(server_state.classifier)
    f_x_test = classifier_fn(batch.x)
    pred = jnp.argmax(f_x_test, axis=1)
    correct = jnp.true_divide(
        jnp.sum(jnp.equal(pred, jnp.reshape(batch.y, pred.shape))),
        batch.y.shape[0])
    return correct
Beispiel #3
0
 def calc_delta_jax(areas_a, areas_b):
     """Return 0.5 * sum((a - b)**2 / (a + b)) over the bins of two histograms.

     Bins where a + b == 0 contribute 0.  The division is taken against a denominator
     in which those bins are replaced by 1, so both the value and the JAX gradient stay
     finite there.  Only cleaning the quotient afterwards (nan_to_num) fixes the value
     but still back-propagates NaN through the 0/0 division.

     Args:
         areas_a: Bin areas of the first histogram.  Assumed non-negative and
             presumably normalized to sum to 1 (the original author's open question
             was whether areas or densities belong here) — TODO confirm with callers.
         areas_b: Bin areas of the second histogram, same shape as areas_a.

     Returns:
         Scalar delta.
     """
     numerator = jnp.square(areas_a - areas_b)
     denominator = areas_a + areas_b
     empty_bins = denominator == 0
     # Double-where: never divide by zero, even in lanes whose result is discarded.
     safe_denominator = jnp.where(empty_bins, 1.0, denominator)
     integrand = jnp.where(empty_bins, 0.0,
                           jnp.true_divide(numerator, safe_denominator))
     # nan_to_num still guards against nan/inf arriving in the inputs themselves.
     return 0.5 * jnp.sum(jnp.nan_to_num(integrand))
Beispiel #4
0
    def calc_rms_jax(bin_areas, bin_centers):
        """Calculate RMS of hist from value arrays.

        Computed as sqrt(E[X^2] - E[X]^2).  E[X^2] is the area-weighted mean of
        bin_centers**2, and E[X] comes from calc_mean_jax (presumably the matching
        area-weighted mean of bin_centers — TODO confirm).

        Uses jnp functions throughout so the computation stays traceable and
        differentiable by JAX.

        The variance is clamped at 0 before the square root.  E[X^2] - E[X]^2 is
        non-negative in exact arithmetic, but floating-point cancellation (e.g. all
        area in a single bin) can leave a tiny negative value, and sqrt of that is NaN.
        """
        mean = calc_mean_jax(bin_areas, bin_centers)
        mean_of_squares = jnp.true_divide(
            jnp.sum(bin_areas * bin_centers * bin_centers),
            jnp.sum(bin_areas))
        variance = mean_of_squares - jnp.square(mean)
        return jnp.sqrt(jnp.maximum(variance, 0.0))
Beispiel #5
0
def true_divide(x1, x2):
  """Element-wise true division accepting JaxArray wrappers or raw values.

  JaxArray operands are unwrapped through `.value`; the jnp result is
  wrapped back into a JaxArray.
  """
  lhs = x1.value if isinstance(x1, JaxArray) else x1
  rhs = x2.value if isinstance(x2, JaxArray) else x2
  return JaxArray(jnp.true_divide(lhs, rhs))
Beispiel #6
0
# Optimizer: Nesterov momentum with weight decay.
# Alternative: Adam(learning_rate=hyperparams.oracle_lr).
opt_def = Momentum(learning_rate=1e-3, weight_decay=1e-4, nesterov=True)
opt = opt_def.create(target=params)

# Returns (loss, gradient w.r.t. the first argument of `loss`).
value_and_grad = jax.value_and_grad(loss)


def train_op(opt, x, y):
    """Take one optimization step on the minibatch (x, y).

    Returns the loss value at `opt.target` and the result of
    `opt.apply_gradient` for the computed gradient.
    """
    loss_value, grads = value_and_grad(opt.target, x, y)
    updated_opt = opt.apply_gradient(grads)
    return loss_value, updated_opt


train_op = jax.jit(train_op)
for step in range(400000):
    # Draw a minibatch of training indices (independent draws, so repeats are possible).
    key, subkey = random.split(key)
    index = random.randint(subkey,
                           shape=(hyperparams.oracle_batch_size, ),
                           minval=0,
                           maxval=x_train.shape[0])
    v, opt = train_op(opt, x_train[index], y_train[index])
    if step % 500 == 0:
        # Periodic evaluation on the full test set.
        print("test sgd result")
        f_x_test = model.apply(opt.target, x_test)
        # v_ce may return per-example losses or a scalar; the mean handles both.
        test_loss = jnp.mean(v_ce(f_x_test, y_test))
        pred = jnp.argmax(f_x_test, axis=1)
        test_accuracy = jnp.true_divide(
            jnp.sum(jnp.equal(pred, jnp.reshape(y_test, pred.shape))),
            y_test.shape[0])
        print("step %5d, test loss % .4f, test accuracy % .4f"
              % (step, test_loss, test_accuracy))
Beispiel #7
0
                              lambda x, name=None: np.square(x))

# tf.math-style bindings backed by `np`.  Each lambda mirrors the TF op's
# argument list, including an optional `name` argument that is accepted and
# ignored (presumably so call sites written against the TF signature keep
# working).  utils.copy_docstring(tf_op, fn) presumably attaches the TF op's
# docstring to `fn` — TODO confirm against utils.

# Element-wise (x - y) ** 2.
squared_difference = utils.copy_docstring(
    tf.math.squared_difference, lambda x, y, name=None: np.square(x - y))

# Element-wise x - y.
subtract = utils.copy_docstring(tf.math.subtract,
                                lambda x, y, name=None: np.subtract(x, y))

# Element-wise tangent.
tan = utils.copy_docstring(tf.math.tan, lambda x, name=None: np.tan(x))

# Element-wise hyperbolic tangent.
tanh = utils.copy_docstring(tf.math.tanh, lambda x, name=None: np.tanh(x))

# Delegates to _top_k, which is defined outside this view.
top_k = utils.copy_docstring(tf.math.top_k, _top_k)

# Element-wise true (floating-point) division.
truediv = utils.copy_docstring(tf.math.truediv,
                               lambda x, y, name=None: np.true_divide(x, y))

# NOTE(review): the unsorted_segment_{max,mean,min} bindings below are disabled.
# They reference np.unsorted_segment_*, which `np` presumably does not provide —
# TODO confirm before re-enabling.

# unsorted_segment_max = utils.copy_docstring(
#     tf.math.unsorted_segment_max,
#     lambda data, segment_ids, num_segments, name=None: (
#         np.unsorted_segment_max))

# unsorted_segment_mean = utils.copy_docstring(
#     tf.math.unsorted_segment_mean,
#     lambda data, segment_ids, num_segments, name=None: (
#         np.unsorted_segment_mean))

# unsorted_segment_min = utils.copy_docstring(
#     tf.math.unsorted_segment_min,
#     lambda data, segment_ids, num_segments, name=None: (
#         np.unsorted_segment_min))