def divide_where(dividend: ComplexNumeric,
                 divisor: Union[ComplexNumeric, IntegralNumeric],
                 *,
                 where: Optional[BooleanNumeric] = None,
                 otherwise: Optional[ComplexNumeric] = None) -> ComplexNumeric:
    """
    Divide `dividend` by `divisor` only at the positions selected by `where`.

    Returns:
        `jnp.where(where, dividend / divisor, otherwise)`, except that positions where
        `where` is false are computed as `1.0 / 1.0` instead of the real operands, so a bad
        divisor there is never actually used. This prevents some exceptions. When `where`
        is None, this is plain `jnp.true_divide(dividend, divisor)`.

    NOTE(review): when `where` is given, `otherwise` is passed straight to the final
    `jnp.where`; confirm callers always supply it in that case.
    """
    if where is not None:
        # Replace masked-out operands with 1.0 so the division is harmless there.
        safe_dividend = jnp.where(where, dividend, 1.0)
        safe_divisor = jnp.where(where, divisor, 1.0)
        return jnp.where(where, jnp.true_divide(safe_dividend, safe_divisor), otherwise)
    return jnp.true_divide(dividend, divisor)
def _test_fn(hyperparams: ServerHyperParams, server_state: FFGBDistillServerState, batch: Batch):
    """Return the fraction of `batch` examples the server classifier labels correctly.

    The classifier built from `server_state.classifier` scores `batch.x`. The
    predicted class is the argmax over axis 1 (presumably a (batch, num_classes)
    output; confirm). Predictions are compared against `batch.y`, reshaped to
    match, and the number of matches is divided by the batch size.
    """
    classify = hyperparams.get_classifier_fn(server_state.classifier)
    scores = classify(batch.x)
    predicted = jnp.argmax(scores, axis=1)
    labels = jnp.reshape(batch.y, predicted.shape)
    num_correct = jnp.sum(jnp.equal(predicted, labels))
    return jnp.true_divide(num_correct, batch.y.shape[0])
def calc_delta_jax(areas_a, areas_b):
    """Compute the separation between two binned distributions.

        delta = 0.5 * sum_i (a_i - b_i)^2 / (a_i + b_i)

    Bins where a_i + b_i == 0 (both bins empty, for non-negative areas)
    contribute 0. The result is differentiable with jax: both the value and
    the gradient stay finite in empty bins.

    Args:
        areas_a: Bin areas (not densities) of the first histogram. The
            original author reasoned that areas are the right input because
            each histogram's areas sum to 1; TODO confirm with callers.
            Assumed non-negative.
        areas_b: Bin areas of the second histogram, same binning as `areas_a`.

    Returns:
        Scalar jax array with the separation value.
    """
    diff = areas_a - areas_b
    total = areas_a + areas_b
    # "Double where": give empty bins a harmless denominator *before* dividing,
    # then mask the result. Dividing first and cleaning up with nan_to_num
    # fixes the forward value, but jax.grad still yields NaN there because
    # the 0/0 derivative is NaN and NaN * 0 remains NaN.
    occupied = total != 0
    safe_total = jnp.where(occupied, total, 1.0)
    integrand = jnp.where(occupied, jnp.square(diff) / safe_total, 0.0)
    return 0.5 * jnp.sum(integrand)
def calc_rms_jax(bin_areas, bin_centers):
    """Calculate the RMS (spread) of a histogram from its bin arrays.

    Computes sqrt(E[X^2] - E[X]^2), where E[X^2] is the area-weighted mean of
    the squared bin centers. E[X] comes from `calc_mean_jax`, presumably the
    matching area-weighted mean; confirm.

    Must use jnp.X functions (e.g. sum(), square()) so that jax can
    differentiate it.

    Args:
        bin_areas: Per-bin weights (areas) of the histogram.
        bin_centers: Per-bin center positions, same shape as `bin_areas`.

    Returns:
        Scalar jax array with the RMS. It is never NaN from rounding, because
        the variance is clamped at zero.
    """
    mean = calc_mean_jax(bin_areas, bin_centers)
    mean_of_squares = jnp.true_divide(
        jnp.sum(bin_areas * bin_centers * bin_centers), jnp.sum(bin_areas))
    variance = mean_of_squares - jnp.square(mean)
    # E[X^2] - E[X]^2 suffers from cancellation: for a narrow distribution the
    # difference can round to a tiny negative number, and sqrt would give NaN.
    variance = jnp.maximum(variance, 0.0)
    return jnp.sqrt(variance)
def true_divide(x1, x2):
    """Element-wise true division that transparently unwraps ``JaxArray`` operands.

    Either operand may be a ``JaxArray`` (its ``.value`` is used) or anything
    ``jnp.true_divide`` accepts directly. The result is always wrapped in a
    new ``JaxArray``.
    """
    numerator = x1.value if isinstance(x1, JaxArray) else x1
    denominator = x2.value if isinstance(x2, JaxArray) else x2
    return JaxArray(jnp.true_divide(numerator, denominator))
value_and_grad = jax.value_and_grad(loss) # opt_def = Adam(learning_rate=hyperparams.oracle_lr) opt_def = Momentum(learning_rate=1e-3, weight_decay=1e-4, nesterov=True) opt = opt_def.create(target=params) def train_op(opt, x, y): v, g = value_and_grad(opt.target, x, y) return v, opt.apply_gradient(g) train_op = jax.jit(train_op) for step in range(400000): key, subkey = random.split(key) index = random.randint(subkey, shape=(hyperparams.oracle_batch_size, ), minval=0, maxval=x_train.shape[0]) v, opt = train_op(opt, x_train[index], y_train[index]) if step % 500 == 0: print("test sgd result") f_x_test = model.apply(opt.target, x_test) test_loss = v_ce(f_x_test, y_test) pred = jnp.argmax(f_x_test, axis=1) corrct = jnp.true_divide( jnp.sum(jnp.equal(pred, jnp.reshape(y_test, pred.shape))), y_test.shape[0]) print("step %5d, test accuracy % .4f" % (step, corrct))
lambda x, name=None: np.square(x))  # NOTE(review): tail of a copy_docstring(...) call that begins before this chunk (presumably `square`).
# Each binding below pairs a tf.math op with a NumPy-based implementation via
# utils.copy_docstring (presumably copying the TF op's docstring onto the
# implementation -- confirm in utils). The `name` parameter mirrors the TF
# signature and is ignored by every implementation here.
# squared_difference(x, y) == (x - y) ** 2, element-wise.
squared_difference = utils.copy_docstring(
    tf.math.squared_difference,
    lambda x, y, name=None: np.square(x - y))
subtract = utils.copy_docstring(tf.math.subtract,
                                lambda x, y, name=None: np.subtract(x, y))
tan = utils.copy_docstring(tf.math.tan, lambda x, name=None: np.tan(x))
tanh = utils.copy_docstring(tf.math.tanh, lambda x, name=None: np.tanh(x))
# `_top_k` is defined elsewhere in this module.
top_k = utils.copy_docstring(tf.math.top_k, _top_k)
truediv = utils.copy_docstring(tf.math.truediv,
                               lambda x, y, name=None: np.true_divide(x, y))
# NOTE(review): the bindings below are disabled; they reference
# np.unsorted_segment_*, which NumPy itself does not provide.
# unsorted_segment_max = utils.copy_docstring(
#     tf.math.unsorted_segment_max,
#     lambda data, segment_ids, num_segments, name=None: (
#         np.unsorted_segment_max))
# unsorted_segment_mean = utils.copy_docstring(
#     tf.math.unsorted_segment_mean,
#     lambda data, segment_ids, num_segments, name=None: (
#         np.unsorted_segment_mean))
# unsorted_segment_min = utils.copy_docstring(
#     tf.math.unsorted_segment_min,
#     lambda data, segment_ids, num_segments, name=None: (
#         np.unsorted_segment_min))