def batch_dataflow(df, batch_size):
    """Group samples into batches and keep only sample components 0 and 2.

    Each batch is mapped to the pair ``([batch[0]], [batch[2]])``, i.e. the
    first and third batched components, each wrapped in a one-element list.

    NOTE(review): a later ``def batch_dataflow`` in this file rebinds this
    name, so this version is not the one callers of the module reach.

    :param df: dataflow of samples
    :param batch_size: number of samples per batch
    :return: dataflow of batches, with ``reset_state()`` already called
    """
    def _select(batch):
        # Keep components 0 and 2 only, each wrapped in a single-element list.
        return [batch[0]], [batch[2]]

    batched = BatchData(df, batch_size, use_list=False)
    mapped = MapData(batched, _select)
    mapped.reset_state()
    return mapped
def batch_dataflow(df, batch_size, time_steps=4, num_stages=6,
                   format=('heatpaf', 'last')):
    """Build a dataflow of ``(inputs, outputs)`` batches for a staged model.

    Samples from ``df`` are grouped with ``BatchData`` and every batch is
    mapped to an ``(inputs, outputs)`` pair. Arrays that feed
    time-distributed layers are repeated ``time_steps`` times along a new
    axis 1 (via ``np.stack(..., axis=1)``).

    Batch components are addressed positionally: components 0, 1 and 2 are
    used as inputs, and components 3 and 4 as targets. (Their meaning is
    presumably image / PAF mask / heatmap mask / PAF target / heatmap
    target -- TODO confirm against the sample producer.)

    :param df: dataflow of samples
    :param batch_size: number of samples per batch
    :param time_steps: number of repetitions along the new time axis (axis 1)
    :param num_stages: number of target arrays emitted by the ``'heat'``
        input format. NOTE(review): the ``'heatpaf'`` format ignores this
        and always emits 12 targets.
    :param format: ``(input_format, output_format)`` pair. Supported pairs
        are ``('heat', 'last')`` and ``('heatpaf', 'last')``. Any
        two-element sequence (list or tuple) is accepted.
    :return: dataflow of ``(inputs, outputs)`` batches, with
        ``reset_state()`` already called
    :raises ValueError: if ``format`` is not a supported pair
    """
    # ``format`` shadows the builtin but is part of the public signature.
    # Its default is a tuple to avoid a shared mutable default argument.

    def tile_time(arr):
        # Repeat ``arr`` ``time_steps`` times along a new axis 1.
        return np.stack([arr] * time_steps, axis=1)

    def in_heat(x):
        return [tile_time(x[0]), tile_time(x[2])]

    def in_heatpaf(x):
        return [tile_time(x[0]), tile_time(x[1]), tile_time(x[2])]

    def out_heat_last(x):
        # One target per stage; the list repeats the same array object.
        return [tile_time(x[4])] * num_stages

    def out_heatpaf_last(x):
        td_paf = tile_time(x[3])
        td_heat = tile_time(x[4])
        # Time-distributed (TD) targets end after the first four entries.
        # The TD layers are then joined by an LSTM, so the remaining eight
        # targets collapse to a single timestep (un-tiled x[3], x[4] pairs).
        return [td_paf, td_heat, td_paf, td_heat] + [x[3], x[4]] * 4

    builders = {
        ('heat', 'last'): (in_heat, out_heat_last),
        ('heatpaf', 'last'): (in_heatpaf, out_heatpaf_last),
    }
    selected = builders.get(tuple(format))
    if selected is None:
        # Validate before building anything. ``(format,)`` keeps the message
        # working when ``format`` is itself a tuple.
        raise ValueError('Unknown format requested: %s' % (format,))
    make_inputs, make_outputs = selected

    df = BatchData(df, batch_size, use_list=False)
    df = MapData(df, lambda x: (make_inputs(x), make_outputs(x)))
    df.reset_state()
    return df