def preprocessing_fn(inputs):
    """Emit per-row copies of several tft analyzer results over column 'a'.

    Args:
      inputs: mapping of feature name to tensor; only ``inputs['a']`` is read.
        Assumed to be batched along axis 0 (the helper below sizes its
        output from ``tf.shape(...)[0]``) -- TODO confirm against caller.

    Returns:
      A dict with keys 'min', 'max', 'sum', 'size' and 'mean', in that
      order. Each value is ``tft.map(repeat, inputs['a'], result)`` where
      ``result`` is the matching ``tft.min`` / ``tft.max`` / ``tft.sum`` /
      ``tft.size`` / ``tft.mean`` applied to ``inputs['a']``.
    """
    column = inputs['a']

    def repeat(in_tensor, value):
        # Build a 1-D tensor with one copy of `value` per entry along the
        # leading dimension of `in_tensor`. The ones are created in
        # `value`'s dtype so the product does not mix dtypes.
        rows = tf.shape(in_tensor)[0]
        return tf.ones([rows], value.dtype) * value

    # (output key, analyzer) pairs. The comprehension below evaluates them in
    # this order: analyzer, then tft.map, one key at a time.
    analyzers = (
        ('min', tft.min),
        ('max', tft.max),
        ('sum', tft.sum),
        ('size', tft.size),
        ('mean', tft.mean),
    )
    return {
        key: tft.map(repeat, column, analyzer(column))
        for key, analyzer in analyzers
    }
def size_fn(inputs):
    """Emit a per-row copy of ``tft.size(inputs['a'])``.

    Args:
      inputs: mapping of feature name to tensor; only ``inputs['a']`` is read.
        Assumed to be batched along axis 0 -- TODO confirm against caller.

    Returns:
      A dict with the single key 'size', whose value is
      ``tft.map(repeat, inputs['a'], tft.size(inputs['a']))``.
    """
    # Defined locally on purpose: a `repeat` nested inside another function
    # is not visible here, and referencing it would raise NameError.
    def repeat(in_tensor, value):
        # Build a 1-D tensor with one copy of `value` per entry along the
        # leading dimension of `in_tensor`, in `value`'s dtype.
        batch_size = tf.shape(in_tensor)[0]
        return tf.ones([batch_size], value.dtype) * value

    return {'size': tft.map(repeat, inputs['a'], tft.size(inputs['a']))}