def forward_pass(self, batch, training=True):
    """Compute softmax predictions for one batch and bundle them with labels.

    Args:
      batch: Indexable by the keys 'x' (model input) and 'y' (labels).
        Presumably a dict of tensors -- confirm against the caller.
      training: Accepted for interface compatibility; this implementation
        discards it without reading it.

    Returns:
      A `reconstruction_model.BatchOutput` holding the softmax predictions,
      `batch['y']` as the labels, and `tf.size(batch['y'])` as `num_examples`.
    """
    del training  # Unused by this implementation.

    # Affine transform of the inputs using the model's stored variables.
    logits = (
        tf.matmul(batch['x'], self._variables.weights) + self._variables.bias
    )
    probabilities = tf.nn.softmax(logits)

    labels = batch['y']
    return reconstruction_model.BatchOutput(
        predictions=probabilities,
        labels=labels,
        num_examples=tf.size(labels),
    )
def forward_pass(self, batch_input, training=True):
    """Run the wrapped Keras model on one batch and package the result.

    Args:
      batch_input: One of:
        * a mapping holding the model input under key 'x' and, optionally,
          the labels under key 'y';
        * an object exposing `_asdict()` (e.g. a namedtuple), which is
          converted to a mapping first;
        * any other indexable, where element 0 is the model input and
          element 1 is the labels.
      training: Forwarded to the Keras model call as its `training` argument.

    Returns:
      A `reconstruction_model.BatchOutput` with the model predictions, the
      labels taken from the batch, and `num_examples` read from the leading
      dimension of the first tensor in the flattened input structure.

    Raises:
      KeyError: If the model input is absent or `None`, whether the batch is
        a mapping lacking 'x' or a sequence whose first element is `None`.
    """
    if hasattr(batch_input, '_asdict'):
        batch_input = batch_input._asdict()

    is_mapping = isinstance(batch_input, collections.abc.Mapping)
    inputs = batch_input.get('x') if is_mapping else batch_input[0]

    if inputs is None:
        if is_mapping:
            raise KeyError(
                'Received a batch_input that is missing required key `x`. '
                'Instead have keys {}'.format(list(batch_input.keys())))
        # Non-mapping batches have no `.keys()`; listing keys here used to
        # mask the real problem behind an AttributeError.
        raise KeyError(
            'Received a batch_input whose first element (the model input '
            '`x`) is None.')

    predictions = self._keras_model(inputs, training=training)

    # Labels are looked up only after the model call, mirroring the input
    # lookup: by key for mappings, by position otherwise.
    y_true = batch_input.get('y') if is_mapping else batch_input[1]

    return reconstruction_model.BatchOutput(
        predictions=predictions,
        labels=y_true,
        num_examples=tf.shape(tf.nest.flatten(inputs)[0])[0])