def get_graph_data(dataset, isModelDataset=False):
    """
    Load the graphs of a dataset, print summary statistics, and optionally split them.

    A dataset-specific ``Args`` object is created (via ``Args.change_dataset``)
    so that different datasets can be generated with their own settings.
    The graphs are shuffled after seeding the ``random`` module with 123, so
    the resulting order, and therefore any split, is reproducible.

    Args:
        dataset: Dataset name passed to ``Args.change_dataset``.
        isModelDataset: If True, also split the shuffled graphs into
            train / validation / test lists.

    Returns:
        If ``isModelDataset``: ``(args, train, val, test)``.
        Otherwise: ``(args, graphs)``.
        In both cases ``args.max_num_node`` is set to the largest node count
        found in the dataset.

    Raises:
        ValueError: If the dataset produces no graphs (the statistics below
            are undefined for an empty dataset).
    """
    args_data = Args()
    args_data.change_dataset(dataset)

    # Load the graph data - consider using presaved datasets (graph load list).
    graphs = create(args_data)
    graphs_len = len(graphs)
    if graphs_len == 0:
        # Previously this surfaced as an opaque ZeroDivisionError below.
        raise ValueError('dataset {!r} produced no graphs'.format(dataset))

    # Fixed seed so the shuffle (and the split below) is reproducible.
    # NOTE: this reseeds the global ``random`` module state as a side effect.
    random.seed(123)
    shuffle(graphs)

    # Gather per-graph sizes once instead of re-walking the list per statistic.
    node_counts = [graph.number_of_nodes() for graph in graphs]
    edge_counts = [graph.number_of_edges() for graph in graphs]

    # Display some graph stats
    print('Average num nodes', sum(node_counts) / graphs_len)

    args_data.max_num_node = max(node_counts)
    max_num_edge = max(edge_counts)
    min_num_edge = min(edge_counts)

    # show graphs statistics
    print('total graph num: {}'.format(graphs_len))
    print('max number node: {}'.format(args_data.max_num_node))
    print('max/min number edge: {}; {}'.format(max_num_edge, min_num_edge))
    print('max previous node: {}'.format(args_data.max_prev_node))

    if isModelDataset:
        # 80/20 train/test split of the shuffled graphs.
        train_end = int(0.8 * graphs_len)
        graphs_train = graphs[:train_end]
        graphs_test = graphs[train_end:]
        # NOTE(review): the validation set is the first 20% of the shuffled
        # graphs, i.e. a subset of the training set rather than held-out data.
        # Kept as-is for compatibility with existing callers; confirm whether
        # this overlap is intended.
        graphs_validate = graphs[:int(0.2 * graphs_len)]
        return args_data, graphs_train, graphs_validate, graphs_test

    return args_data, graphs
def create_name(name):
    """
    Build and return the graphs for the dataset called ``name``.

    A fresh ``Args`` object is configured for ``name`` via
    ``Args.change_dataset`` and passed to ``create``. Whatever ``create``
    returns is handed back unchanged.
    """
    dataset_args = Args()
    dataset_args.change_dataset(name)
    graphs = create(dataset_args)
    return graphs