def __init__(self, config: MyConfig):
        """Build the training graph for an encode/decode reconstruction model.

        Creates the feed placeholders, runs ``self.encode`` / ``self.decode``
        on the reshaped input, registers running statistics of the latent
        vector through ``self.process_normal``, and wires up an Adam training
        op plus a scalar loss summary. Everything is placed on ``/gpu:0``.

        :param config: model configuration; ``config.vec_size`` is forwarded
            to ``self.encode`` and the object is kept as ``self.config``.
        """
        self.config = config
        with tf.device("/gpu:0"):
            # Feed points: flattened 784-value inputs, integer labels and the
            # learning rate (shape left unspecified).
            # NOTE(review): ``self.ys`` is exposed through ``self.inputs`` but
            # is not used by the loss built in this method.
            self.xs = tf.placeholder(tf.float32, [None, 784], name="x")
            self.ys = tf.placeholder(tf.int32, [None], name="y")
            self.lr = tf.placeholder(tf.float32, name="lr")
            self.inputs = [self.xs, self.ys, self.lr]

            # View the flat input as [-1, 28, 28, 1] before encoding.
            images = tf.reshape(self.xs, [-1, 28, 28, 1])
            self.vec = self.encode(images, config.vec_size)
            # Expected to come back as [-1, 28, 28, 1] so that it can be
            # compared element-wise with ``images`` below.
            reconstruction = self.decode(self.vec)
            # Track running statistics of the latent vector.
            self.process_normal(self.vec)

            # Mean squared reconstruction error drives the optimisation.
            mse = tf.reduce_mean(tf.square(reconstruction - images))
            optimizer = tf.train.AdamOptimizer(self.lr)
            # Ops registered under UPDATE_OPS must run with each train step.
            update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
            with tf.control_dependencies(update_ops):
                self.train_opt = optimizer.minimize(mse)

            # The summary reports the square root of the loss (RMSE).
            rmse = tf.sqrt(mse)
            self.loss_summary = tf.summary.scalar("loss", rmse)
            self.precise_summary = None
            # Drop the trailing channel axis for consumers of the output.
            self.y = tf.reshape(reconstruction, [-1, 28, 28])
    def process_normal(self, vec):
        """Maintain running first and second moments of ``vec``.

        Creates two non-trainable variables, ``mean`` and ``msd``, exposed as
        ``self.final_mean`` and ``self.final_msd``. For each one an assign op
        is built that blends the stored value with the current batch
        statistic using ``self.config.momentum``::

            stored = stored * momentum + batch_stat * (1 - momentum)

        Both assign ops are added to ``tf.GraphKeys.UPDATE_OPS``; they only
        run when something depends on that collection.

        The variables get explicit initializers (zeros for the mean, ones for
        the mean square). Without one, ``tf.get_variable`` falls back to a
        random Glorot-uniform initializer, which would start the running mean
        at noise and could start the mean square at a negative value.

        :param vec: rank-2 tensor whose statistics are reduced over axis 0
            (original note: [-1, 4]).
        """
        momentum = self.config.momentum
        vec_size = vec.shape[1]

        def track(name, batch_stat, initializer):
            # Create the running-statistic variable and register its
            # momentum update in UPDATE_OPS; returns the variable.
            running = tf.get_variable(name=name,
                                      shape=[vec_size],
                                      dtype=tf.float32,
                                      initializer=initializer,
                                      trainable=False)
            update = tf.assign(
                running, running * momentum + batch_stat * (1 - momentum))
            tf.add_to_collection(tf.GraphKeys.UPDATE_OPS, update)
            return running

        # Per-dimension mean of the current batch.
        mean = tf.reduce_mean(vec, axis=0)
        self.final_mean = track("mean", mean, tf.zeros_initializer())

        # Per-dimension mean of squared values of the current batch.
        msd = tf.reduce_mean(tf.square(vec), axis=0)
        self.final_msd = track("msd", msd, tf.ones_initializer())