Example #1
    def test_actnorm_bias_init_conv(self):
        images_np = np.random.rand(8, 32, 32, 3)
        images = tf.to_float(images_np)

        flow = fl.InputLayer(images)

        layer = fl.ActnormBiasLayer()

        with self.assertRaises(AssertionError):
            layer.get_ddi_init_ops()

        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow

        init_ops = layer.get_ddi_init_ops()
        self.assertEqual(z, None)

        with self.test_session() as sess:
            # initialize network
            sess.run(tf.global_variables_initializer())
            sess.run(init_ops)
            x, logdet = sess.run([x, logdet])
            bias = sess.run(layer._bias_t)

            self.assertEqual(x.shape, images_np.shape)
            self.assertEqual(bias.shape, (1, 1, 1, 3))
            self.assertGreater(np.sum(bias**2), 0)
            # check mean after passing act norm

            self.assertAllClose(np.mean(x.reshape([-1, 3]), axis=0),
                                [0.0, 0.0, 0.0])

        self.forward_inverse(layer, flow)
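The assertions pin down what data-dependent initialization (DDI) of the bias must achieve: zero per-channel mean after the shift. A NumPy sketch of the implied computation (the actual implementation inside fl.ActnormBiasLayer is an assumption here):

    import numpy as np

    # Hypothetical sketch of bias DDI: set the bias to the negative
    # per-channel mean of a batch, so (batch + bias) is zero-mean per
    # channel -- exactly what the test asserts.
    batch = np.random.rand(8, 32, 32, 3)
    bias = -batch.reshape([-1, 3]).mean(axis=0).reshape([1, 1, 1, 3])
    assert np.allclose((batch + bias).reshape([-1, 3]).mean(axis=0), 0.0)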
Example #2
    def test_actnorm_scale_init_conv_iter(self):
        np.random.seed(52321)
        images_ph = tf.placeholder(tf.float32, shape=[8, 32, 32, 3])

        flow = fl.InputLayer(images_ph)

        layer = fl.ActnormScaleLayer(scale=np.sqrt(np.pi))

        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow
        init_ops = layer.get_ddi_init_ops(num_init_iterations=50)

        with self.test_session() as sess:
            # initialize network
            sess.run(tf.global_variables_initializer())
            for i in range(200):
                sess.run(init_ops,
                         feed_dict={images_ph: np.random.rand(8, 32, 32, 3)})

            for i in range(5):
                x_np, logdet_np = sess.run(
                    [x, logdet],
                    feed_dict={images_ph: np.random.rand(8, 32, 32, 3)})

                self.assertEqual(x.shape, x_np.shape)
                self.assertAllClose(
                    np.var(x_np.reshape([-1, 3]), axis=0),
                    [np.pi, np.pi, np.pi],
                    atol=0.1,
                )

        self.forward_inverse(
            layer, flow, feed_dict={images_ph: np.random.rand(8, 32, 32, 3)})
Example #3
    def test_actnorm_scale_init_conv(self):
        np.random.seed(52321)
        images_np = np.random.rand(8, 32, 32, 3)
        images = tf.to_float(images_np)

        flow = fl.InputLayer(images)

        layer = fl.ActnormScaleLayer()

        with self.assertRaises(AssertionError):
            layer.get_ddi_init_ops()

        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow

        init_ops = layer.get_ddi_init_ops()
        self.assertEqual(z, None)

        with self.test_session() as sess:
            # initialize network
            sess.run(tf.global_variables_initializer())
            sess.run(init_ops)
            x, logdet = sess.run([x, logdet])
            log_scale = sess.run(layer._log_scale_t)

            self.assertEqual(x.shape, images_np.shape)
            self.assertEqual(log_scale.shape, (1, 1, 1, 3))
            self.assertGreater(np.sum(log_scale**2), 0)
            self.assertGreater(np.sum(logdet**2), 0)
            # check var after passing act norm
            self.assertAllClose(np.var(x.reshape([-1, 3]), axis=0),
                                [1.0, 1.0, 1.0],
                                atol=0.001)

        self.forward_inverse(layer, flow)
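Here the target is unit per-channel variance, which fixes the log-scale from the batch statistics. A NumPy sketch of the implied computation (assumed, not taken from fl):

    import numpy as np

    # Hypothetical sketch of scale DDI: with log_scale = -0.5*log(var),
    # x * exp(log_scale) has unit per-channel variance.
    batch = np.random.rand(8, 32, 32, 3)
    log_scale = -0.5 * np.log(batch.reshape([-1, 3]).var(axis=0))
    scaled = batch * np.exp(log_scale).reshape([1, 1, 1, 3])
    assert np.allclose(scaled.reshape([-1, 3]).var(axis=0), 1.0)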
Example #4
    def test_actnorm_init_conv(self):
        np.random.seed(52321)
        images_np = np.random.rand(8, 32, 32, 3)
        images = tf.to_float(images_np)

        flow = fl.InputLayer(images)

        layer = fl.ActnormLayer(scale=np.sqrt(np.pi))

        with self.assertRaises(AssertionError):
            layer.get_ddi_init_ops()

        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow

        init_ops = layer.get_ddi_init_ops()
        self.assertEqual(z, None)

        with self.test_session() as sess:
            # initialize network
            sess.run(tf.global_variables_initializer())
            sess.run(init_ops)
            x, logdet = sess.run([x, logdet])

            self.assertEqual(x.shape, images_np.shape)
            # check var after passing act norm
            self.assertAllClose(np.var(x.reshape([-1, 3]), axis=0),
                                [np.pi] * 3,
                                atol=0.01)
            self.assertAllClose(np.mean(x.reshape([-1, 3]), axis=0), [0.0] * 3,
                                atol=0.01)

        self.forward_inverse(layer, flow)
Example #5
    def test_logitify_layer_conv(self):
        np.random.seed(52321)
        images_np = np.random.rand(8, 32, 32, 1)
        images = tf.to_float(images_np)

        flow = fl.InputLayer(images)

        layer = fl.LogitifyImage()
        self.forward_inverse(layer, flow, atol=0.01)

        x, logdet, z = flow
        logdet += 10.0
        flow = x, logdet, z

        new_flow = layer(flow, forward=True)
        flow_rec = layer(new_flow, forward=False)
        x, logdet, z = new_flow
        x_rec, logdet_rec, z = flow_rec

        self.assertEqual(z, None)
        self.assertEqual(x.shape.as_list(), [8, 32, 32, 1])
        self.assertEqual(logdet.shape.as_list(), [8])
        self.assertEqual(x_rec.shape.as_list(), [8, 32, 32, 1])
        self.assertEqual(logdet_rec.shape.as_list(), [8])

        with self.test_session() as sess:
            _ = sess.run([x, logdet])
            x_rec, logdet_rec = sess.run([x_rec, logdet_rec])
            self.assertAllClose(logdet_rec, [10.0] * 8, atol=0.01)
            self.assertAllClose(images_np, x_rec, atol=0.01)
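For context, a common "logitify" preprocessing (as in Real NVP) maps pixels away from {0, 1} and applies the logit; whether fl.LogitifyImage uses these exact constants is an assumption, and the test above only checks invertibility to atol=0.01:

    import numpy as np

    alpha = 0.05
    x = np.random.rand(8, 32, 32, 1)
    p = alpha + (1 - 2 * alpha) * x
    y = np.log(p) - np.log(1 - p)  # forward: logit
    x_rec = (1 / (1 + np.exp(-y)) - alpha) / (1 - 2 * alpha)  # inverse: sigmoid
    assert np.allclose(x, x_rec)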
Example #6
    def test_actnorm_scale_conv(self):

        images_np = np.random.rand(8, 32, 32, 3)
        images = tf.to_float(images_np)

        flow = fl.InputLayer(images)

        layer = fl.ActnormScaleLayer()
        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow

        self.assertEqual(z, None)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            log_scale = sess.run(layer._log_scale_t)
            x, logdet = sess.run([x, logdet])

            self.assertEqual(x.shape, images_np.shape)
            self.assertEqual(log_scale.shape, (1, 1, 1, 3))
            self.assertEqual(np.sum(log_scale**2), 0)
            # zero initialization
            self.assertAllClose(logdet, [0.0] * 8)

        self.forward_inverse(layer, flow)
Example #7
    def test_invertible_conv1x1_learn_identity(self):

        images = tf.random_normal([8, 32, 32, 16])
        flow = fl.InputLayer(images)
        layer = fl.InvertibleConv1x1Layer(use_lu_decomposition=True)
        self.try_to_train_identity_layer(layer, flow)

        layer = fl.InvertibleConv1x1Layer(use_lu_decomposition=False)
        self.try_to_train_identity_layer(layer, flow)
Example #8
    def test_input_layer(self):
        images = np.random.rand(8, 32, 32, 1)
        images = tf.to_float(images)

        x, logdet, z = fl.InputLayer(images)

        self.assertEqual(images, x)
        self.assertEqual(logdet.shape.as_list(), [8])
        self.assertEqual(z, None)
Example #9
File: Glow.py Project: geosada/LVAT
    def encoder(self, x):

        #if x is None:
        #    x = self.x_image_ph
        #flow = fl.InputLayer(self.x_image_ph)

        flow = fl.InputLayer(x)
        output_flow = self.model_flow(flow, forward=True)

        # Prepare output tensors

        y, logdet, z = output_flow
        return y, logdet, z
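The inverse direction uses the same chain with forward=False (the pattern Example #21 below uses for sampling); a matching decoder would be a short sketch:

    def decoder(self, y, logdet, z):
        # Sketch: feed the encoder outputs back through the same chain
        # with forward=False to recover x.
        x, logdet, z = self.model_flow((y, logdet, z), forward=False)
        return x, logdet, z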
Example #10
    def test_squeezing_layer_conv(self):
        images = np.random.rand(8, 32, 32, 1)
        images = tf.to_float(images)

        flow = fl.InputLayer(images)

        layer = fl.SqueezingLayer()
        self.forward_inverse(layer, flow)

        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow

        self.assertEqual(logdet.shape.as_list(), [8])
        self.assertEqual([8, 16, 16, 4], x.shape.as_list())
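The shape assertions encode the squeeze contract: [N, H, W, C] -> [N, H/2, W/2, 4C]. A NumPy sketch of one such space-to-depth rearrangement (the sub-pixel ordering fl uses is an assumption):

    import numpy as np

    # Each 2x2 spatial patch becomes 4 channels: [8, 32, 32, 1] -> [8, 16, 16, 4].
    x = np.random.rand(8, 32, 32, 1)
    n, h, w, c = x.shape
    y = x.reshape(n, h // 2, 2, w // 2, 2, c)
    y = y.transpose(0, 1, 3, 2, 4, 5).reshape(n, h // 2, w // 2, 4 * c)
    assert y.shape == (8, 16, 16, 4)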
Example #11
    def test_quantize_image_layer_conv(self):
        np.random.seed(52321)
        images_np = np.random.rand(8, 32, 32, 3)
        images = tf.to_float(images_np)

        flow = fl.InputLayer(images)
        layer = fl.QuantizeImage(num_bits=8)
        self.forward_inverse(layer, flow, atol=1.5 / 256)

        new_flow = layer(flow, forward=True)
        flow_rec = layer(new_flow, forward=False)
        x, logdet, z = new_flow
        x_rec, logdet_rec, z = flow_rec

        self.assertEqual(z, None)
        self.assertEqual(x.shape.as_list(), [8, 32, 32, 3])
        self.assertEqual(x_rec.shape.as_list(), [8, 32, 32, 3])
        # less bits
        flow = fl.InputLayer(images)
        layer = fl.QuantizeImage(num_bits=5)
        self.forward_inverse(layer, flow, atol=1.5 / 32)

        layer = fl.QuantizeImage(num_bits=4)
        new_flow = layer(flow, forward=True)
        flow_rec = layer(new_flow, forward=False)

        with self.test_session() as sess:
            x_rec_uint8 = layer.to_uint8(flow_rec[0])
            self.assertEqual(x_rec_uint8.dtype, tf.uint8)
            x_rec_uint8 = sess.run(x_rec_uint8)
            self.assertAllGreaterEqual(x_rec_uint8, 0)
            self.assertAllLessEqual(x_rec_uint8, 255)
            self.assertEqual(np.unique(x_rec_uint8).shape, (2**4, ))

        with self.assertRaises(AssertionError):
            layer = fl.QuantizeImage(num_bits=4)
            self.forward_inverse(layer, flow, atol=1 / 32)
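The atol values track the quantization grid: with num_bits=n there are 2**n levels, so the round-trip error stays below one grid step, 1 / 2**n, which is why atol=1.5 / 2**n passes while atol=1 / 2**n raises. A NumPy sketch of the bound (the rounding scheme is an assumption, not taken from fl.QuantizeImage):

    import numpy as np

    n = 4
    x = np.random.rand(8, 32, 32, 3)
    x_q = np.floor(x * 2 ** n) / 2 ** n  # snap to the 2**n-level grid
    assert np.abs(x - x_q).max() < 1.0 / 2 ** n
    assert np.unique(x_q).size <= 2 ** n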
Example #12
    def test_actnorm_init_conv_iter(self):
        np.random.seed(52321)
        images_ph = tf.placeholder(tf.float32, shape=[8, 32, 32, 3])

        flow = fl.InputLayer(images_ph)

        layer = fl.ActnormLayer(scale=np.sqrt(np.pi))

        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow
        init_ops = layer.get_ddi_init_ops(num_init_iterations=50)

        with self.test_session() as sess:
            # initialize network
            sess.run(tf.global_variables_initializer())
            for i in range(200):
                sess.run(init_ops,
                         feed_dict={images_ph: np.random.rand(8, 32, 32, 3)})

            for i in range(5):
                x_np, logdet_np = sess.run(
                    [x, logdet],
                    feed_dict={images_ph: np.random.rand(8, 32, 32, 3)})
                self.assertEqual(x.shape, x_np.shape)
                self.assertAllClose(np.var(x_np.reshape([-1, 3]), axis=0),
                                    [np.pi] * 3,
                                    atol=0.1)
                self.assertAllClose(np.mean(x_np.reshape([-1, 3]), axis=0),
                                    [0.0] * 3,
                                    atol=0.1)

        self.forward_inverse(
            layer, flow, feed_dict={images_ph: np.random.rand(8, 32, 32, 3)})

        def feed_dict_fn():
            return {images_ph: np.random.rand(8, 32, 32, 3)}

        def post_init_fn(sess):
            init_ops = layer.get_ddi_init_ops()
            sess.run(init_ops, {images_ph: np.random.rand(8, 32, 32, 3)})

        with tf.variable_scope("TestTrain"):
            layer = fl.ActnormLayer(scale=np.sqrt(np.pi))
            self.try_to_train_identity_layer(layer,
                                             flow,
                                             feed_dict_fn=feed_dict_fn,
                                             post_init_fn=post_init_fn)
Example #13
    def test_create_simple_flow(self):
        np.random.seed(642201)

        images = tf.placeholder(tf.float32, [16, 32, 32, 1])
        layers, actnorm_layers = nets.create_simple_flow(num_steps=2,
                                                         num_scales=4,
                                                         num_bits=8)
        flow = fl.InputLayer(images)
        model_flow = fl.ChainLayer(layers)
        output_flow = model_flow(flow, forward=True)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            for actnorm_layer in actnorm_layers:
                init_op = actnorm_layer.get_ddi_init_ops(10)
                noise = np.random.rand(16, 32, 32, 1)
                # fit actnorms to certain noise
                for i in range(30):
                    sess.run(init_op, feed_dict={images: noise})

                actnorm_flow = actnorm_layer._forward_outputs[0]
                normed_x = sess.run(actnorm_flow[0], feed_dict={images: noise})
                nc = normed_x.shape[-1]

                self.assertAllClose(np.var(normed_x.reshape([-1, nc]), axis=0),
                                    [1.0] * nc,
                                    atol=0.1)
                self.assertAllClose(np.mean(normed_x.reshape([-1, nc]),
                                            axis=0), [0.0] * nc,
                                    atol=0.1)

            output_flow_np = sess.run(
                output_flow, feed_dict={images: np.random.rand(16, 32, 32, 1)})

            y, logdet, z = output_flow_np

            self.assertEqual(
                np.prod(y.shape) + np.prod(z.shape), np.prod([16, 32, 32, 1]))
            self.assertTrue(np.max(np.abs(y)) < 15.0)
            self.forward_inverse(
                sess,
                model_flow,
                flow,
                atol=0.01,
                feed_dict={images: np.random.rand(16, 32, 32, 1)},
            )
Example #14
    def test_invertible_conv1x1_lu_decomp(self):

        images_np = np.random.rand(8, 32, 32, 16)
        images = tf.to_float(images_np)

        flow = fl.InputLayer(images)

        layer = fl.InvertibleConv1x1Layer(use_lu_decomposition=True)
        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow

        self.assertEqual(z, None)
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            x, logdet = sess.run([x, logdet])
            self.assertEqual(x.shape, images_np.shape)

        self.forward_inverse(layer, flow)
Example #15
    def test_actnorm_conv(self):

        images_np = np.random.rand(8, 32, 32, 3)
        images = tf.to_float(images_np)

        flow = fl.InputLayer(images)

        layer = fl.ActnormLayer()
        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow

        self.assertEqual(z, None)

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            x, logdet = sess.run([x, logdet])
            self.assertEqual(x.shape, images_np.shape)

        self.forward_inverse(layer, flow)
Example #16
    def test_initialize_actnorms(self):

        np.random.seed(642201)
        images_ph = tf.placeholder(tf.float32, [16, 16, 16, 1])

        layers, actnorm_layers = nets.create_simple_flow(num_steps=1,
                                                         num_scales=3)
        flow = fl.InputLayer(images_ph)
        model_flow = fl.ChainLayer(layers)
        output_flow = model_flow(flow, forward=True)

        noise = np.random.rand(16, 16, 16, 1)

        def feed_dict_fn():
            return {images_ph: noise}

        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())

            nets.initialize_actnorms(
                sess,
                feed_dict_fn=feed_dict_fn,
                actnorm_layers=actnorm_layers,
                num_steps=50,
            )

            for actnorm_layer in actnorm_layers:

                actnorm_flow = actnorm_layer._forward_outputs[0]
                normed_x = sess.run(actnorm_flow[0],
                                    feed_dict={images_ph: noise})
                nc = normed_x.shape[-1]

                self.assertAllClose(np.var(normed_x.reshape([-1, nc]), axis=0),
                                    [1.0] * nc,
                                    atol=0.1)
                self.assertAllClose(np.mean(normed_x.reshape([-1, nc]),
                                            axis=0), [0.0] * nc,
                                    atol=0.1)
Example #17
    def test_simple_affine_coupling_layer(self):

        images_np = np.random.rand(8, 32, 32, 16)
        images = tf.to_float(images_np)
        flow = fl.InputLayer(images)
        layer = fl.AffineCouplingLayer(
            _shift_and_log_scale_fn_template("test"))
        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow

        self.assertEqual(z, None)
        with self.test_session() as sess:
            sess.run(tf.global_variables_initializer())
            x, logdet = sess.run([x, logdet])
            self.assertEqual(x.shape, images_np.shape)
            self.assertEqual(logdet.shape, (8, ))

        self.forward_inverse(layer, flow)

        layer = fl.AffineCouplingLayer(
            _shift_and_log_scale_fn_template("train"))
        self.try_to_train_identity_layer(layer, flow)
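For reference, the standard affine coupling transform this layer family implements (Real NVP / Glow style; whether fl.AffineCouplingLayer splits and orders channels exactly this way is an assumption). shift_and_log_scale_fn stands in for the template used above:

    import numpy as np

    def coupling_forward(x, shift_and_log_scale_fn):
        # Pass half the channels through untouched; shift and scale the
        # other half conditioned on the first, keeping the map invertible.
        x1, x2 = np.split(x, 2, axis=-1)
        shift, log_scale = shift_and_log_scale_fn(x1)
        return np.concatenate([x1, x2 * np.exp(log_scale) + shift], axis=-1)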
Example #18
    def test_factor_out_layer_conv(self):
        images_np = np.random.rand(8, 32, 32, 16)
        images = tf.to_float(images_np)

        flow = fl.InputLayer(images)
        layer = fl.FactorOutLayer()
        self.forward_inverse(layer, flow)
        with tf_framework.arg_scope([fl.FlowLayer.__call__], forward=True):
            new_flow = layer(flow)
        x, logdet, z = new_flow

        self.assertEqual([8, 32, 32, 8], x.shape.as_list())
        self.assertEqual([8, 32, 32, 8], z.shape.as_list())

        with self.test_session() as sess:
            new_flow_np = sess.run(new_flow)
            # x
            self.assertAllClose(new_flow_np[0], images_np[:, :, :, 8:])
            # z
            self.assertAllClose(new_flow_np[2], images_np[:, :, :, :8])
            # logdet
            self.assertAllClose(new_flow_np[1], np.zeros_like(new_flow_np[1]))
Example #19
    def test_actnorm_bias_init_conv_iter(self):

        images_ph = tf.placeholder(tf.float32, shape=[8, 32, 32, 3])

        flow = fl.InputLayer(images_ph)

        layer = fl.ActnormBiasLayer()

        new_flow = layer(flow, forward=True)
        x, logdet, z = new_flow
        init_ops = layer.get_ddi_init_ops(num_init_iterations=100)

        with self.test_session() as sess:
            # initialize network
            sess.run(tf.global_variables_initializer())
            for i in range(200):
                sess.run(init_ops,
                         feed_dict={images_ph: np.random.rand(8, 32, 32, 3)})
                # print(sess.run(layer._bias_t))
            x_np_values = []
            for i in range(20):
                x_np, logdet_np = sess.run(
                    [x, logdet],
                    feed_dict={images_ph: np.random.rand(8, 32, 32, 3)})

                x_np_values.append(x_np)

            x_np_values = np.array(x_np_values).mean(0)

            self.assertEqual(x.shape, x_np_values.shape)

            self.assertAllClose(
                np.mean(x_np_values.reshape([-1, 3]), axis=0),
                [0.0, 0.0, 0.0],
                atol=0.05,
            )

        self.forward_inverse(
            layer, flow, feed_dict={images_ph: np.random.rand(8, 32, 32, 3)})
Example #20
    def test_combine_squeeze_and_factor_layers_conv(self):

        images_np = np.random.rand(8, 32, 32, 1)
        images = tf.to_float(images_np)

        flow = fl.InputLayer(images)
        # output shapes after each layer are noted in the comments
        layers = [
            fl.SqueezingLayer(),  # x=[8, 16, 16, 4]
            fl.FactorOutLayer(),  # x=[8, 16, 16, 2]
            fl.SqueezingLayer(),  # x=[8, 8, 8, 8]
            fl.FactorOutLayer(),  # x=[8, 8, 8, 4] z=[8, 8, 8, 12]
        ]

        chain = fl.ChainLayer(layers=layers)
        with tf_framework.arg_scope([fl.FlowLayer.__call__], forward=True):
            new_flow = chain(flow)
            with self.test_session() as sess:
                x, logdet, z = sess.run(new_flow)
                self.assertEqual(x.shape, (8, 8, 8, 4))
                self.assertEqual(z.shape, (8, 8, 8, 12))

        self.forward_inverse(chain, flow)
Example #21
    def model_fn(features, labels, mode, params):

        nn_template_fn = nets.OpenAITemplate(
            width=args.width
        )

        layers, actnorm_layers = nets.create_simple_flow(
            num_steps=args.num_steps,
            num_scales=args.num_scales,
            num_bits=args.num_bits,
            template_fn=nn_template_fn
        )

        images = features["images"]
        flow = fl.InputLayer(images)
        model_flow = fl.ChainLayer(layers)
        output_flow = model_flow(flow, forward=True)
        y, logdet, z = output_flow

        for layer in actnorm_layers:
            init_op = layer.get_ddi_init_ops(30)
            tf.add_to_collection(ACTNORM_INIT_OPS, init_op)

        total_params = 0
        trainable_variables = tf.trainable_variables()
        for k, v in enumerate(trainable_variables):
            num_params = np.prod(v.shape.as_list())
            total_params += num_params

        print(f"TOTAL PARAMS: {total_params/1e6} [M]")

        if mode == tf.estimator.ModeKeys.PREDICT:
            raise NotImplementedError()

        tfd = tf.contrib.distributions
        y_flatten = tf.reshape(y, [batch_size, -1])
        z_flatten = tf.reshape(z, [batch_size, -1])

        prior_y = tfd.MultivariateNormalDiag(loc=tf.zeros_like(y_flatten),
                                             scale_diag=tf.ones_like(y_flatten))
        prior_z = tfd.MultivariateNormalDiag(loc=tf.zeros_like(z_flatten),
                                             scale_diag=tf.ones_like(z_flatten))

        log_prob_y = prior_y.log_prob(y_flatten)
        log_prob_z = prior_z.log_prob(z_flatten)

        loss = log_prob_y + log_prob_z + logdet
        # compute loss per pixel, the final loss should be same
        # for different input sizes
        loss = - tf.reduce_mean(loss) / image_size / image_size

        trainable_variables = tf.trainable_variables()
        l2_loss = l2_reg * tf.add_n(
            [tf.nn.l2_loss(v) for v in trainable_variables])

        total_loss = l2_loss + loss

        tf.summary.scalar('total_loss', total_loss)
        tf.summary.scalar('loss', loss)
        tf.summary.scalar('l2_loss', l2_loss)

        # Sampling during training and evaluation
        prior_y = tfd.MultivariateNormalDiag(loc=tf.zeros_like(y_flatten),
                                             scale_diag=sample_beta * tf.ones_like(y_flatten))
        prior_z = tfd.MultivariateNormalDiag(loc=tf.zeros_like(z_flatten),
                                             scale_diag=sample_beta * tf.ones_like(z_flatten))

        sample_y_flatten = prior_y.sample()
        sample_y = tf.reshape(sample_y_flatten, y.shape.as_list())
        sample_z = tf.reshape(prior_z.sample(), z.shape.as_list())
        sampled_logdet = prior_y.log_prob(sample_y_flatten)

        inverse_flow = sample_y, sampled_logdet, sample_z
        sampled_flow = model_flow(inverse_flow, forward=False)
        x_flow_sampled, _, _ = sampled_flow
        # convert to uint8
        quantize_image_layer = layers[0]
        x_flow_sampled_uint = quantize_image_layer.to_uint8(x_flow_sampled)

        grid_image = tf.contrib.gan.eval.image_grid(
            x_flow_sampled_uint,
            grid_shape=[4, batch_size // 4],
            image_shape=(image_size, image_size),
            num_channels=3
        )

        grid_summary = tf.summary.image(
            f'samples{sample_beta}', grid_image, max_outputs=10
        )

        if mode == tf.estimator.ModeKeys.EVAL:
            eval_summary_hook = tf.train.SummarySaverHook(
                save_steps=1,
                output_dir=args.model_dir + "/eval",
                summary_op=grid_summary
            )

            return tf.estimator.EstimatorSpec(
                mode,
                loss=total_loss,
                evaluation_hooks=[eval_summary_hook]
            )

        # Create training op.
        assert mode == tf.estimator.ModeKeys.TRAIN

        train_summary_hook = tf.train.SummarySaverHook(
            save_secs=args.save_secs,
            output_dir=args.model_dir,
            summary_op=grid_summary
        )

        global_step = tf.train.get_global_step()
        learning_rate = tf.train.inverse_time_decay(
            args.lr, global_step, args.decay_steps, args.decay_rate,
            staircase=True
        )

        tf.summary.scalar('learning_rate', learning_rate)

        optimizer = tf.train.AdagradOptimizer(learning_rate=learning_rate)
        if args.clip > 0.0:
            gvs = optimizer.compute_gradients(total_loss)
            capped_gvs = [
                (tf.clip_by_value(grad, -args.clip, args.clip), var) for grad, var in gvs
            ]
            train_op = optimizer.apply_gradients(capped_gvs, global_step=global_step)
        else:
            train_op = optimizer.minimize(total_loss, global_step=global_step)

        return tf.estimator.EstimatorSpec(
            mode, loss=total_loss,
            train_op=train_op, training_hooks=[train_summary_hook]
        )
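A minimal sketch of wiring this model_fn into the Estimator API (train_input_fn is hypothetical, and the ACTNORM_INIT_OPS collection still needs to be run once before training, e.g. from a session hook):

    estimator = tf.estimator.Estimator(
        model_fn=model_fn,
        model_dir=args.model_dir,
        params={},
    )
    # estimator.train(input_fn=train_input_fn)  # hypothetical input_fn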