コード例 #1
0
def get_dataset(CONFIG, X, Y, z_r, *, sources=(0, 1, 2)):
    """Assemble a ``gprn.Dataset`` from a chosen subset of per-source arrays.

    For each selected source index ``i`` a source dict is registered with
    ``M = X[i].shape[1]``, the raw ``X[i]`` / ``Y[i]`` arrays, a batch size
    equal to ``Y[i].shape[0]`` and a fixed ``active_tasks`` list. The
    inducing points ``z_r`` are added afterwards.

    Args:
        CONFIG: Accepted for call-site compatibility; not read by this function.
        X: Indexable collection of per-source inputs. Each ``X[i]`` must have
            at least two dimensions, because ``X[i].shape[1]`` is used as ``M``.
        Y: Indexable collection of per-source targets; ``Y[i].shape[0]`` is
            used as that source's batch size.
        z_r: Inducing points, passed unchanged to ``add_inducing_points``.
        sources: Indices into ``X``/``Y`` of the sources to register, in order.
            Defaults to ``(0, 1, 2)``, the subset this function has always
            used. Pass e.g. ``range(len(X))`` to register every source.

    Returns:
        The populated ``gprn.Dataset``.

    Raises:
        IndexError: If an index in ``sources`` is out of range for ``X``/``Y``.
    """
    data = gprn.Dataset()

    for i in sources:
        x = X[i]
        y = Y[i]
        # Progress output kept from the original implementation.
        print('dataset: ', i, ' ', x.shape)

        data.add_source_dict({
            'M': x.shape[1],
            'x': x,
            'y': y,
            # Batch size equals the number of target rows for this source.
            'batch_size': y.shape[0],
            # NOTE(review): three ``[0]`` entries, presumably one per latent
            # task/output — kept exactly as before; TODO confirm semantics.
            'active_tasks': [[0], [0], [0]],
        })

    data.add_inducing_points(z_r)
    return data
コード例 #2
0
    def get_data(self, N):
        """Build a deterministic single-source toy dataset of ``N`` points.

        NumPy's *global* RNG is reseeded with 0 (a side effect visible to the
        rest of the process), then inputs and targets are drawn uniformly in
        [0, 1), inputs first. Each array gets a trailing axis of length 1,
        is cast to float32 and is stored on the instance as ``self.x`` /
        ``self.y``.

        Args:
            N: Size argument forwarded to ``np.random.random``.

        Returns:
            A new ``gprn.Dataset`` whose only source holds ``self.x`` and
            ``self.y`` with ``batch_size`` set to ``None``.
        """
        def _column(values):
            # Append a trailing length-1 axis and downcast to float32.
            return np.expand_dims(values, axis=-1).astype(np.float32)

        np.random.seed(0)
        # Tuple elements evaluate left to right, so inputs are drawn first.
        uniform_x, uniform_y = np.random.random(N), np.random.random(N)
        self.x = _column(uniform_x)
        self.y = _column(uniform_y)

        dataset = gprn.Dataset()
        source = {'x': self.x, 'y': self.y, 'batch_size': None}
        dataset.add_source_dict(source)
        return dataset
コード例 #3
0
ファイル: m_cmgp.py プロジェクト: ohamelijnck/multi_res_gps
def get_dataset(X, Y, z_r, *, sources=(0, 1, 2)):
    """Assemble a ``gprn.Dataset`` from a chosen subset of per-source arrays.

    For each selected source index ``i`` a source dict is registered with
    ``M = X[i].shape[1]``, the raw ``X[i]`` / ``Y[i]`` arrays and a batch
    size equal to ``Y[i].shape[0]``. The inducing points ``z_r`` are added
    afterwards.

    Args:
        X: Indexable collection of per-source inputs. Each ``X[i]`` must have
            at least two dimensions, because ``X[i].shape[1]`` is used as ``M``.
        Y: Indexable collection of per-source targets; ``Y[i].shape[0]`` is
            used as that source's batch size.
        z_r: Inducing points, passed unchanged to ``add_inducing_points``.
        sources: Indices into ``X``/``Y`` of the sources to register, in order.
            Defaults to ``(0, 1, 2)``, the subset this function has always
            used. Pass e.g. ``range(len(X))`` to register every source.

    Returns:
        The populated ``gprn.Dataset``.

    Raises:
        IndexError: If an index in ``sources`` is out of range for ``X``/``Y``.
    """
    data = gprn.Dataset()

    for i in sources:
        x = X[i]
        y = Y[i]
        # Progress output kept from the original implementation.
        print('dataset: ', i, ' ', x.shape)

        data.add_source_dict({
            'M': x.shape[1],
            'x': x,
            'y': y,
            # Batch size equals the number of target rows for this source.
            'batch_size': y.shape[0],
        })

    data.add_inducing_points(z_r)
    return data