def scan_reference(f, init, xs):
    """Apply `f` step by step over `xs`, threading a carry through each call.

    Args:
        f: Callable invoked as `f(carry, x)`; it must return a
            `(new_carry, y)` pair where `y` exposes a `.shape` attribute.
        init: Carry value passed to the first call of `f`.
        xs: Iterable whose elements are fed to `f` one at a time, in order.

    Returns:
        A `(carry, ys)` pair: the carry returned by the last call of `f`
        (or `init` itself if `xs` yields nothing), and the per-step outputs
        joined along a new leading axis.
    """
    state = init
    step_outputs = []
    for element in xs:
        state, out = f(state, element)
        # Give every output a leading axis of length 1 so the pieces can be
        # joined along axis 0 after the loop.
        with_leading_axis = (1,) + out.shape
        step_outputs.append(tf_np.reshape(out, with_leading_axis))
    # NOTE(review): an empty `xs` hands an empty list to concatenate --
    # presumably callers always supply at least one element; confirm.
    return state, tf_np.concatenate(step_outputs, 0)
def losses(scan, c, xs):
    """Run `scan` and pack everything it returns into one rank-1 array.

    The step function `f` is not a parameter; it is looked up in the
    surrounding scope when this function is called.

    Args:
        scan: Callable invoked as `scan(f, c, xs)`; it must return a
            `(carry, ys)` pair.
        c: Initial carry forwarded to `scan`.
        xs: Inputs forwarded to `scan`.

    Returns:
        The concatenation of every leaf of `(carry, ys)`, each reshaped to
        rank 1, in the order produced by `tf.nest.flatten`.
    """
    final_carry, outputs = scan(f, c, xs)

    def _as_vector(leaf):
        # A shape of [-1] collapses the leaf to a single dimension.
        return tf_np.reshape(leaf, [-1])

    vectorized = tf.nest.map_structure(_as_vector, (final_carry, outputs))
    pieces = tf.nest.flatten(vectorized)
    return tf_np.concatenate(pieces)