コード例 #1
0
 def _build_request_and_save_ops(self, target, variables):
     """Build a fused local save op and a fused request from peer ``target``.

     All ``variables`` are fused into a single tensor. That fused tensor is
     saved locally and is also passed as the template to
     ``request_variable_with_template``. The tensor received from ``target``
     is then split back into one piece per variable, using each variable's
     ``.shape``.

     Side effect: the save op is stored on ``self._save_model_op`` so that
     ``_get_initializer_op`` can pick it up later.

     Args:
         target: peer identifier forwarded to
             ``request_variable_with_template``.
         variables: the variables to fuse; each must expose ``.shape``.

     Returns:
         Tuple ``(other_peer_vars, save_model_op)``: the per-variable pieces
         of the requested tensor, and the local save op.
     """
     local_fused = fuse(variables)
     save_op = save_variable(local_fused)
     # NOTE(review): the local fused tensor is presumably used as the
     # template (shape/dtype) of the remote buffer — confirm in the op docs.
     remote_fused = request_variable_with_template(target, local_fused)
     split_shapes = [var.shape for var in variables]
     peer_pieces = defuse(remote_fused, split_shapes)
     self._save_model_op = save_op  # read back by _get_initializer_op
     return peer_pieces, save_op
コード例 #2
0
def test_save_and_request():
    """Save a versioned variable, request it back, and check every element.

    Each of three rounds does the following:
      1. Increments ``global_step`` and adds 1 to every element of ``x``.
      2. Saves ``x`` tagged with version ``global_step``.
      3. Runs a barrier.
      4. Requests the variable (by name, shape and dtype) at that version from
         the peer identified by ``target``.

    After round ``i``, every element of the received tensor must equal
    ``i + 1``.
    """
    size = 10
    global_step = tf.Variable(tf.constant(0, dtype=tf.int64))
    # NOTE(review): 0 presumably selects peer/rank 0 as the source — confirm.
    target = tf.Variable(tf.constant(0, dtype=tf.int32))

    x = tf.Variable(tf.zeros([size], dtype=tf.int32))

    inc_op = tf.assign_add(global_step, 1)
    update_op = tf.assign(x, x + 1)
    save_op = save_variable(x, version=global_step)
    y = request_variable(target, global_step, x.name, x.shape, x.dtype)
    # Build the barrier op once: calling barrier() inside the session loop
    # would add a fresh op to the graph on every iteration.
    barrier_op = barrier()

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())

        for i in range(3):
            sess.run([inc_op, update_op])
            sess.run(save_op)
            # Presumably waits for all peers to save this version before any
            # of them issues the request — confirm barrier() semantics.
            sess.run(barrier_op)
            v = sess.run(y)
            # Compare the whole tensor (values and length), not just v[0],
            # so a truncated or partially filled transfer is caught.
            assert list(v) == [i + 1] * size, (i, v)
        sess.run(barrier_op)
コード例 #3
0
 def _build_save_op(self, variables):
     """Return a single op that saves ``variables``.

     If ``self._fuse_requests`` is falsy, each variable is saved on its own
     and the individual save ops are combined with ``tf.group``.

     Otherwise, the variables are fused into one tensor, which is saved once
     under the name ``self._fused_model_name``.
     """
     if not self._fuse_requests:
         per_variable_saves = [save_variable(var) for var in variables]
         return tf.group(per_variable_saves)
     fused = fuse(variables)
     return save_variable(fused, name=self._fused_model_name)