def forward_pass(self, batch, training=True):
    """Computes softmax-regression predictions for a single batch.

    Args:
      batch: A structure indexable by `'x'` (features) and `'y'` (labels).
      training: Unused; accepted only to match the forward-pass signature.

    Returns:
      A `reconstruction_model.BatchOutput` carrying the softmax probabilities,
      the labels taken from `batch['y']`, and `tf.size(batch['y'])` as
      `num_examples`.
    """
    del training  # Nothing below depends on train vs. eval mode.

    features = batch['x']
    params = self._variables
    logits = tf.matmul(features, params.weights) + params.bias
    probabilities = tf.nn.softmax(logits)

    labels = batch['y']
    # NOTE(review): `tf.size` counts every label element, so this equals the
    # batch size only if each example has exactly one label entry — confirm
    # the label shape.
    example_count = tf.size(labels)
    return reconstruction_model.BatchOutput(
        predictions=probabilities,
        labels=labels,
        num_examples=example_count,
    )
    def forward_pass(self, batch_input, training=True):
        """Runs the wrapped Keras model on one batch and packages the output.

        NOTE(review): as laid out in this file, this `def` sits at 4-space
        indentation after the preceding top-level function's `return`, so
        Python parses it as a nested function that is never executed. It reads
        like a method of a Keras-backed model class — confirm where it belongs.

        Args:
          batch_input: The batch, as one of the following:
            - a mapping with key `'x'` (and optionally `'y'`);
            - an object with `_asdict()` (e.g. a namedtuple), which is
              converted to such a mapping;
            - an indexable sequence whose elements 0 and 1 are the inputs and
              labels.
          training: Forwarded unchanged to the Keras model call.

        Returns:
          A `reconstruction_model.BatchOutput` holding the model predictions,
          the labels (None if a mapping has no `'y'`), and the size of the
          leading dimension of the first flattened input as `num_examples`.

        Raises:
          KeyError: If a mapping batch has no `'x'` (or it is None), or if
            element 0 of a sequence batch is None.
        """
        if hasattr(batch_input, '_asdict'):
            # Namedtuple-style batches are normalized to a plain mapping.
            batch_input = batch_input._asdict()

        is_mapping = isinstance(batch_input, collections.abc.Mapping)
        inputs = batch_input.get('x') if is_mapping else batch_input[0]
        if inputs is None:
            if is_mapping:
                raise KeyError(
                    'Received a batch_input that is missing required key `x`. '
                    'Instead have keys {}'.format(list(batch_input.keys())))
            # Sequence batches have no `.keys()`; calling it here would turn
            # the intended KeyError into an AttributeError.
            raise KeyError(
                'Received a batch_input whose first element (expected `x`) '
                'is None.')

        predictions = self._keras_model(inputs, training=training)
        # Labels are read after the model call, matching the original order.
        y_true = batch_input.get('y') if is_mapping else batch_input[1]

        return reconstruction_model.BatchOutput(
            predictions=predictions,
            labels=y_true,
            # `inputs` may be a nested structure; its first flattened tensor
            # supplies the batch dimension.
            num_examples=tf.shape(tf.nest.flatten(inputs)[0])[0])