# Module-level imports assumed by the functions below (tfops is this repo's
# TensorFlow ops helper; it is used both directly and under the alias Z).
import tensorflow as tf
import horovod.tensorflow as hvd
import tfops
import tfops as Z


def adam2_old(params, cost_or_grads, lr=3e-4, mom1=0.9, mom2=0.999,
              epsilon=1e-8):
    updates = []
    if type(cost_or_grads) is not list:
        gs = tf.gradients(cost_or_grads, params)
    else:
        gs = cost_or_grads

    # all-reduce: average the first and second gradient moments across workers
    grads1 = [Z.allreduce_mean(g) for g in gs]
    grads2 = [Z.allreduce_mean(tf.square(g)) for g in gs]
    # Rescale the second-moment decay for the number of Horovod workers.
    mom2 = tf.maximum(0., 1. - (hvd.size() * (1 - mom2)))

    # Bug fix: pass the name as a keyword (positionally the string would be
    # consumed by the `trainable` argument) and keep optimizer state
    # non-trainable.
    t = tf.Variable(1., name='adam_t', trainable=False)
    lr_t = lr * tf.sqrt((1. - tf.pow(mom2, t))) / (1. - tf.pow(mom1, t))
    updates.append(t.assign_add(1))

    for p, g1, g2 in zip(params, grads1, grads2):
        # Strip the ':0' suffix from p.name, which is illegal in variable names.
        mg = tf.Variable(tf.zeros(p.get_shape()),
                         name=p.name.split(':')[0] + '_adam_mg',
                         trainable=False)
        if mom1 > 0:
            v = tf.Variable(tf.zeros(p.get_shape()),
                            name=p.name.split(':')[0] + '_adam_v',
                            trainable=False)
            v_t = mom1 * v + (1. - mom1) * g1
            updates.append(v.assign(v_t))
        else:
            v_t = g1
        mg_t = mom2 * mg + (1. - mom2) * g2
        delta_t = v_t / (tf.sqrt(mg_t) + epsilon)
        p_t = p - lr_t * delta_t
        updates.append(mg.assign(mg_t))
        updates.append(p.assign(p_t))
    return tf.group(*updates)
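# Hedged illustration (not part of the original source): the tf.maximum line
# above shortens the second-moment decay when tf.square(g) is averaged over
# hvd.size() workers, clamping at 0 for very large worker counts.
def effective_mom2(mom2=0.999, num_workers=1):
    # Plain-Python mirror of: tf.maximum(0., 1. - (hvd.size() * (1 - mom2)))
    return max(0., 1. - num_workers * (1. - mom2))

assert abs(effective_mom2(0.999, 1) - 0.999) < 1e-12  # single worker: unchanged
assert abs(effective_mom2(0.999, 8) - 0.992) < 1e-12  # 8 workers: faster decay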
def f_loss(iterators, is_training, reuse=False, init=False):
    if hps.direct_iterator and iterators is not None:
        raise NotImplementedError()
    else:
        # X_A/Y_A/X_B/Y_B are placeholders from the enclosing scope.
        x_A, y_A, x_B, y_B = X_A, Y_A, X_B, Y_B

    (bits_x_A, bits_y_A, pred_loss_A, eps_flatten_A,
     bits_x_B, bits_y_B, pred_loss_B, eps_flatten_B,
     code_loss) = _f_loss(x_A, y_A, x_B, y_B, is_training, reuse, init)

    local_loss_A = hps.mle_loss_scale * bits_x_A + hps.weight_y * bits_y_A
    local_loss_B = hps.mle_loss_scale * bits_x_B + hps.weight_y * bits_y_B
    # Add code difference loss
    if hps.joint_train:
        local_loss_A += hps.code_loss_scale * code_loss
        local_loss_B += hps.code_loss_scale * code_loss

    stats_A = [local_loss_A, bits_x_A, bits_y_A, pred_loss_A, code_loss]
    stats_B = [local_loss_B, bits_x_B, bits_y_B, pred_loss_B, code_loss]
    global_stats_A = Z.allreduce_mean(
        tf.stack([tf.reduce_mean(i) for i in stats_A]))
    global_stats_B = Z.allreduce_mean(
        tf.stack([tf.reduce_mean(i) for i in stats_B]))

    if hps.joint_train and is_training:
        return (tf.reduce_mean(local_loss_A), global_stats_A, eps_flatten_A,
                tf.reduce_mean(local_loss_B), global_stats_B, eps_flatten_B)
    else:
        return (tf.reduce_mean(local_loss_A), global_stats_A,
                tf.reduce_mean(local_loss_B), global_stats_B)
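# Hedged summary of the objective above, read off the code rather than an
# external reference: for each domain d in {A, B},
#   local_loss_d = mle_loss_scale * bits_x_d + weight_y * bits_y_d
#                  (+ code_loss_scale * code_loss    if hps.joint_train)
# where the shared code_loss term ties the latent codes of the two domains
# together, and global_stats_* are per-worker means averaged with
# Z.allreduce_mean for logging.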
def adam(params, cost_or_grads, alpha=3e-4, hps=None, epsilon=1e-8):
    updates = []
    if type(cost_or_grads) is not list:
        gs = tf.gradients(cost_or_grads, params)
    else:
        gs = cost_or_grads

    # Tie the second-moment horizon to the Polyak averaging horizon.
    beta2 = 1. - 1. / (hps.train_its * hps.polyak_epochs)

    # all-reduce
    grads = [Z.allreduce_mean(g) for g in gs]

    # Bug fix: pass the name as a keyword (positionally the string would be
    # consumed by `trainable`), keep optimizer state non-trainable, and strip
    # the ':0' suffix from w.name, which is illegal in variable names.
    t = tf.Variable(1., name='adam_t', trainable=False)
    alpha_t = alpha * tf.sqrt((1. - tf.pow(beta2, t))) / \
        (1. - tf.pow(hps.beta1, t))
    updates.append(t.assign_add(1))

    for w, g in zip(params, grads):
        mom2 = tf.Variable(tf.zeros(w.get_shape()),
                           name=w.name.split(':')[0] + '_adam_m2',
                           trainable=False)
        if hps.beta1 > 0:
            mom1 = tf.Variable(tf.zeros(w.get_shape()),
                               name=w.name.split(':')[0] + '_adam_m1',
                               trainable=False)
            mom1_new = hps.beta1 * mom1 + (1. - hps.beta1) * g
            updates.append(mom1.assign(mom1_new))
        else:
            mom1_new = g
        m2_new = beta2 * mom2 + (1. - beta2) * tf.square(g)
        delta_t = mom1_new / (tf.sqrt(m2_new) + epsilon)
        # weight_decay multiplies the weights directly (1.0 disables decay).
        w_new = hps.weight_decay * w - alpha_t * delta_t
        updates.append(mom2.assign(m2_new))
        updates.append(w.assign(w_new))

    # Polyak averaging of the parameters, with the same horizon as beta2.
    polyak_avg_op, polyak_swap_op, ema = polyak(params, beta2)
    train_op = tf.group(polyak_avg_op, *updates)
    return train_op, polyak_swap_op, ema
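# Minimal usage sketch (assumptions, not from the original source: TF1 graph
# mode, Horovod already initialised, and the polyak() helper referenced above
# is defined elsewhere in this file; the hyperparameter object is faked with
# SimpleNamespace using only the fields adam() reads).
from types import SimpleNamespace

def _adam_demo():
    hps_demo = SimpleNamespace(train_its=1000, polyak_epochs=1.,
                               beta1=0.9, weight_decay=1.)
    w = tf.Variable(tf.zeros([4, 4]), name='w_demo')
    loss = tf.reduce_sum(tf.square(w - 1.))
    train_op, polyak_swap_op, ema = adam([w], loss, alpha=1e-3, hps=hps_demo)
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        sess.run(train_op)        # one Adam step plus the Polyak average update
        sess.run(polyak_swap_op)  # swap averaged weights in for evaluation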
def ops():
    # Smoke test for the collective ops in tfops: push one small tensor
    # through allreduce_sum / allreduce_mean and the actnorm layer.
    hvd.init()
    by2 = tf.random.uniform((2, 2))
    # Use as_list() so the divisor is a plain int rather than TF1 Dimensions.
    tallreduce_mean_sum = (tfops.allreduce_sum(by2) *
                           tfops.allreduce_mean(by2)) / sum(by2.shape.as_list())
    tfops.int_shape(by2)
    tfops.actnorm("by2", by2)
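# Hedged usage note: ops() only builds the test graph; under TF1 the tensors
# still need a session run, and the collectives need one process per worker,
# e.g. launched with Horovod's standard runner (assumed command line, not
# from the source):  horovodrun -np 2 python tfops_test.py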
def f_loss(iterator, is_training, reuse=False):
    if hps.direct_iterator and iterator is not None:
        x, y = iterator.get_next()
    else:
        x, y = X, Y

    bits_x, bits_y, pred_loss = _f_loss(x, y, is_training, reuse)
    local_loss = bits_x + hps.weight_y * bits_y
    stats = [local_loss, bits_x, bits_y, pred_loss]
    global_stats = Z.allreduce_mean(
        tf.stack([tf.reduce_mean(i) for i in stats]))

    return tf.reduce_mean(local_loss), global_stats
def f_loss_complete(data, label, is_training, reuse=False):
    x, y = data, label
    bits_x, bits_y, pred_loss, nobj = _f_loss(x, y, is_training, reuse)
    local_loss = bits_x + hps.weight_y * bits_y
    stats = [local_loss, bits_x, bits_y, pred_loss]
    global_stats = Z.allreduce_mean(
        tf.stack([tf.reduce_mean(i) for i in stats]))
    # global_stats is built for logging, but only the nobj loss is returned.
    nobj_loss = tf.reduce_mean(nobj)
    return nobj_loss