Esempio n. 1
0
 def _build_request_ops(self, target, variables):
     """Create ops that request ``variables`` from ``target``.

     Without request fusion (``self._fuse_requests`` falsy) each variable is
     requested on its own via ``request_variable_with_template``. With fusion
     enabled, the variables are packed by ``fuse`` into a single tensor,
     requested once under ``self._fused_model_name`` with ``version=None``,
     and the response is split back by ``defuse`` using the original
     per-variable shapes.

     Args:
         target: forwarded unchanged as the request target (presumably a
             peer id tensor -- TODO confirm against callers).
         variables: the variables to request. In the fused branch this is
             iterated twice (by ``fuse`` and again to collect shapes), so it
             should be a re-iterable sequence such as a list.

     Returns:
         In the unfused case, a list with one request op per variable; in the
         fused case, whatever ``defuse`` returns (presumably one tensor per
         variable -- verify).
     """
     if not self._fuse_requests:
         # One independent request per variable, templated on the variable.
         return [
             request_variable_with_template(target, var) for var in variables
         ]

     # Pack everything into one tensor so only a single request is issued.
     packed = fuse(variables)
     fetched = request_variable(
         target,
         version=None,
         name=self._fused_model_name,
         shape=packed.shape,
         dtype=packed.dtype,
     )
     # Shapes are needed to undo the packing on the received tensor.
     original_shapes = [var.shape for var in variables]
     return defuse(fetched, original_shapes)
def test_save_and_request():
    """Save a variable each step and read that version back with request_variable.

    Each of three steps increments ``global_step`` and every element of ``x``,
    saves ``x`` with ``version=global_step``, runs a barrier, then requests
    the same version of ``x`` from ``target`` (0) and checks that the whole
    returned tensor equals the step count.
    """
    global_step = tf.Variable(tf.constant(0, dtype=tf.int64))
    target = tf.Variable(tf.constant(0, dtype=tf.int32))

    x = tf.Variable(tf.zeros([10], dtype=tf.int32))

    inc_op = tf.assign_add(global_step, 1)
    update_op = tf.assign(x, x + 1)
    save_op = save_variable(x, version=global_step)
    y = request_variable(target, global_step, x.name, x.shape, x.dtype)
    # Build the barrier op once. Calling barrier() inside the loop would add a
    # fresh op to the graph on every iteration.
    sync_op = barrier()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for i in range(3):
            sess.run([inc_op, update_op])
            sess.run(save_op)
            # Presumably ensures every peer has saved this version before any
            # peer requests it.
            sess.run(sync_op)
            v = sess.run(y)
            expected = i + 1
            # x is incremented element-wise, so every entry must match, not
            # just the first one.
            assert (v == expected).all(), (i, v)
        # Presumably keeps peers from exiting while others are still requesting.
        sess.run(sync_op)