示例#1
0
def batch_dataflow(df, batch_size):
    """
    Turn a dataflow of single samples into a dataflow of batches.

    Samples are grouped with ``BatchData`` (``use_list=False``); each
    resulting batch is then reduced to its components 0 and 2, each wrapped
    in a one-element list, giving ``([batch[0]], [batch[2]])``.

    :param df: dataflow of samples
    :param batch_size: number of samples per batch
    :return: dataflow of batches, on which ``reset_state()`` has been called
    """
    def keep_first_and_third(batch):
        # Only components 0 and 2 of the batched datapoint are forwarded.
        return [batch[0]], [batch[2]]

    batched = BatchData(df, batch_size, use_list=False)
    selected = MapData(batched, keep_first_and_third)
    selected.reset_state()
    return selected
示例#2
0
def batch_dataflow(df,
                   batch_size,
                   time_steps=4,
                   num_stages=6,
                   format=('heatpaf', 'last')):
    """
    Build a batched dataflow whose datapoints are tiled along a time axis.

    Samples are grouped with ``BatchData`` (``use_list=False``). Each batch
    is then mapped to an ``(inputs, outputs)`` pair. Selected components are
    repeated ``time_steps`` times along a new axis 1 via ``np.stack``.

    :param df: dataflow of samples
    :param batch_size: batch size
    :param time_steps: number of copies stacked along the new axis 1
    :param num_stages: number of output entries produced for the
        ``('heat', 'last')`` format. The ``('heatpaf', 'last')`` output
        layout is fixed and does not use this value.
    :param format: pair ``(informat, outformat)``. Supported combinations are
        ``('heat', 'last')`` and ``('heatpaf', 'last')``. (The name shadows
        the builtin ``format``; it is kept for caller compatibility.)
    :return: dataflow of ``(inputs, outputs)`` batches, on which
        ``reset_state()`` has been called
    :raises ValueError: if ``format`` is not a two-element pair or is not a
        supported combination
    """
    # Unpacking raises ValueError if ``format`` does not have two elements.
    informat, outformat = format

    def tile(a):
        # Repeat one batched component ``time_steps`` times along axis 1.
        return np.stack([a] * time_steps, axis=1)

    def in_heat(x):
        return [tile(x[0]), tile(x[2])]

    def in_heatpaf(x):
        return [tile(x[0]), tile(x[1]), tile(x[2])]

    def out_heat_last(x):
        # The same tiled array object is repeated once per stage.
        return [tile(x[4])] * num_stages

    def out_heatpaf_last(x):
        # NOTE(review): this layout is hard-coded to 12 entries (6 stages of
        # components 3 and 4) and ignores ``num_stages``; confirm against the
        # model definition before generalizing.
        time_distributed = [
            tile(x[3]),
            tile(x[4]),
            tile(x[3]),
            tile(x[4]),  # TD layers end here
        ]
        # TD layers are joined here by LSTM; the remaining outputs collapse
        # to one timestep output, so they are not tiled.
        collapsed = [x[3], x[4]] * 4
        return time_distributed + collapsed

    # Validate the requested format before building the pipeline.
    if informat == 'heat' and outformat == 'last':
        # Previously this branch referenced an undefined ``heat_only``.
        make_inputs, make_outputs = in_heat, out_heat_last
    elif informat == 'heatpaf' and outformat == 'last':
        make_inputs, make_outputs = in_heatpaf, out_heatpaf_last
    else:
        # Wrap in a 1-tuple so a tuple ``format`` is printed rather than
        # consumed as %-arguments (which would raise TypeError instead).
        raise ValueError('Unknown format requested: %s' % (format,))

    df = BatchData(df, batch_size, use_list=False)
    df = MapData(df, lambda x: (make_inputs(x), make_outputs(x)))
    df.reset_state()
    return df