コード例 #1
0
ファイル: train.py プロジェクト: zhhhzhang/ASTGCN
    raise SystemExit("params folder exists! select a new params path please")

class MyInit(mx.init.Initializer):
    """Rank-based initializer: Uniform for 0-/1-D arrays, Xavier otherwise.

    The actual filling is delegated to one ``mx.init.Uniform()`` and one
    ``mx.init.Xavier()`` instance (both built with default arguments), and
    every initialization is reported on stdout.
    """

    # Delegates are class attributes, so every MyInit instance shares them.
    xavier = mx.init.Xavier()
    uniform = mx.init.Uniform()

    def _init_weight(self, name, data):
        """Fill ``data`` in place using the delegate that matches its rank.

        Args:
            name: Parameter name/descriptor; forwarded unchanged to the
                delegate and included in the printed message.
            data: Array to initialize; its number of axes (``len(shape)``)
                decides which delegate is used.
        """
        # Arrays with two or more axes go to Xavier; lower-rank arrays use
        # Uniform instead (presumably because Xavier needs separate
        # fan-in/fan-out axes — TODO confirm for the mxnet version in use).
        if len(data.shape) >= 2:
            delegate, message = self.xavier, 'with Xavier'
        else:
            delegate, message = self.uniform, 'with Uniform'
        delegate._init_weight(name, data)
        print('Init', name, data.shape, message)

if __name__ == "__main__":
    # Build the dataset from the graph signal matrix file. The result is
    # indexed first by split (e.g. 'train', 'test') and then by component
    # ('week', 'day', 'recent', 'target').
    all_data = read_and_generate_dataset(
        graph_signal_matrix_filename,
        num_of_vertices,
        num_of_features,
        num_of_weeks,
        num_of_days,
        num_of_hours,
        points_per_hour,
        num_for_predict,
    )

    # Test-split ground truth: swap axes 1 and 2, then collapse every
    # non-leading axis so that each sample becomes a single row.
    true_value = (
        all_data['test']['target']
        .transpose((0, 2, 1))
        .reshape(all_data['test']['target'].shape[0], -1)
    )

    # Shuffled mini-batch loader over the training split. Each component is
    # turned into an NDArray on context `ctx` and handed to ArrayDataset in
    # the order week, day, recent, target.
    train_loader = gluon.data.DataLoader(
        gluon.data.ArrayDataset(
            *(nd.array(all_data['train'][part], ctx=ctx)
              for part in ('week', 'day', 'recent', 'target'))
        ),
        batch_size=batch_size,
        shuffle=True,
    )
コード例 #2
0
ファイル: Gated_STGCN.py プロジェクト: Orchid0/DGCN
# Make sure params_path is a fresh, empty directory. An existing directory
# aborts the run unless FLAGS.force is set, in which case it is deleted
# (shutil.rmtree) and recreated.
if os.path.exists(params_path):
    if not FLAGS.force:
        raise SystemExit("Params folder exists! Select a new params path please!")
    shutil.rmtree(params_path)
os.makedirs(params_path)
print('Create params directory %s' % (params_path))

if __name__ == "__main__":
    print("Reading data...")
    # Load the dataset from the graph signal matrix file. Per the original
    # author's note, the train / valid / test arrays are shaped
    # length x 3 x NUM_POINT x 12 (not verified here).
    all_data = read_and_generate_dataset(
        graph_signal_matrix_filename,
        num_of_weeks, num_of_days, num_of_hours,
        num_for_predict, points_per_hour, merge,
    )

    # Test-split ground truth, used as-is; its shape is echoed as a quick
    # sanity check.
    true_value = all_data['test']['target']
    print(true_value.shape)

    # Shuffled mini-batch loader over the training split; one tensor per
    # component, built in the order week, day, recent, target.
    train_loader = DataLoader(
        TensorDataset(*[torch.Tensor(all_data['train'][part])
                        for part in ('week', 'day', 'recent', 'target')]),
        batch_size=batch_size,
        shuffle=True,
    )