def metadata_parser_fn( _, file_path: RichPath) -> Iterable[Dict[str, Any]]:
    """Parse one raw-sample file and yield a single aggregated metadata dict.

    NOTE(review): `self` and `type_lattice_path` are free variables here — this
    function appears to be a closure defined inside a method of a model class;
    confirm against the enclosing scope, which is not visible in this chunk.

    Args:
        _: Ignored (presumably a worker/shard index from a parallel map — TODO confirm).
        file_path: Path to a raw-sample file; `read_by_file_suffix()` picks the
            reader from the file extension.

    Yields:
        One `raw_metadata` dict accumulated over all samples in the file.
    """
    # Seed the metadata with the lattice path, then let the class hooks
    # initialize and fold in each sample.
    raw_metadata = {"type_lattice_path": type_lattice_path}
    type(self)._init_metadata(self.hyperparameters, raw_metadata)
    for raw_sample in file_path.read_by_file_suffix():
        type(self)._load_metadata_from_sample(
            self.hyperparameters, raw_sample=raw_sample, raw_metadata=raw_metadata)
    # Single yield: callers get one metadata dict per input file.
    yield raw_metadata
def restore(path: RichPath, is_train: bool, hyper_overrides: Optional[Dict[str, Any]] = None,
            model_save_dir: Optional[str] = None, log_save_dir: Optional[str] = None) -> Model:
    """Rebuild a model from a serialized checkpoint and restore its weights.

    Args:
        path: Checkpoint file containing model type, hyperparameters, metadata
            and a name -> value weights dict.
        is_train: Whether to build the training version of the graph.
        hyper_overrides: Optional hyperparameter values that replace the saved
            ones before the model is constructed.
        model_save_dir: Optional directory for future model saves.
        log_save_dir: Optional directory for logs.

    Returns:
        The reconstructed model with weights assigned (saved values where
        available, fresh initialization otherwise).
    """
    saved_data = path.read_by_file_suffix()
    if hyper_overrides is not None:
        saved_data['hyperparameters'].update(hyper_overrides)
    model_class = get_model_class_from_name(saved_data['model_type'])
    model = model_class(hyperparameters=saved_data['hyperparameters'],
                        run_name=saved_data.get('run_name'),
                        model_save_dir=model_save_dir,
                        log_save_dir=log_save_dir)
    model.metadata.update(saved_data['metadata'])
    model.make_model(is_train=is_train)
    variables_to_initialize = []
    with model.sess.graph.as_default():
        with tf.name_scope("restore"):
            restore_ops = []
            used_vars = set()
            # Sort by name so restore order is deterministic across runs.
            for variable in sorted(model.sess.graph.get_collection(
                    tf.GraphKeys.GLOBAL_VARIABLES), key=lambda v: v.name):
                used_vars.add(variable.name)
                if variable.name in saved_data['weights']:
                    # print('Initializing %s from saved value.' % variable.name)
                    restore_ops.append(
                        variable.assign(saved_data['weights'][variable.name]))
                else:
                    # No saved value: defer to a single variables_initializer below.
                    print(
                        'Freshly initializing %s since no saved value was found.'
                        % variable.name)
                    variables_to_initialize.append(variable)
            # Warn about saved weights the rebuilt graph does not use,
            # silently skipping optimizer (Adam) slot variables.
            for var_name in sorted(saved_data['weights']):
                if var_name not in used_vars:
                    if var_name.endswith('Adam:0') or var_name.endswith(
                            'Adam_1:0') or var_name in [
                                'beta1_power:0', 'beta2_power:0'
                            ]:
                        continue
                    print('Saved weights for %s not used by model.' % var_name)
            # One initializer op covers every variable without a saved value.
            restore_ops.append(
                tf.variables_initializer(variables_to_initialize))
            model.sess.run(restore_ops)
    return model
def load_existing_metadata(self, metadata_path: RichPath):
    """Load previously tensorised metadata and report hyperparameter drift.

    Reads saved hyperparameters + metadata from `metadata_path`, logs every
    hyperparameter whose current value differs from the value used when the
    data was tensorised, and installs the saved metadata on this model.
    """
    saved_data = metadata_path.read_by_file_suffix()
    saved_hypers = saved_data['hyperparameters']
    all_hyper_names = set(self.hyperparameters) | set(saved_hypers)
    # TODO: Should not be needed
    if 'cg_node_type_vocab_size' in saved_hypers:
        self.hyperparameters['cg_node_type_vocab_size'] = saved_hypers['cg_node_type_vocab_size']
    for name in all_hyper_names:
        if name == 'run_id':
            continue  # these are supposed to change
        previous_value = saved_hypers.get(name)
        current_value = self.hyperparameters.get(name)
        if previous_value != current_value:
            self.train_log(
                "I: Hyperparameter %s now has value '%s' but was '%s' when tensorising data."
                % (name, current_value, previous_value))
    self.__metadata = saved_data['metadata']
def get_queries(query_file: RichPath, dfs: Dict[str, pd.DataFrame]) -> List[Dict[str, Any]]:
    """Build merge-query specifications from a query description file.

    Each description names a sequence of frames plus the left/right join
    columns between consecutive frames; the first frame is the merge root
    ('from') and every following frame becomes a MergeFrame step ('on').
    """
    queries: List[Dict[str, Any]] = []
    for description in query_file.read_by_file_suffix():
        merge_steps = [
            MergeFrame(frame=dfs[frame_name], left_on=left_col, right_on=right_col)
            for frame_name, left_col, right_col in zip(
                description['frames'][1:], description['left'], description['right'])
        ]
        queries.append({'from': dfs[description['frames'][0]], 'on': merge_steps})
    return queries
def __load_data(self, data_file: RichPath) -> List[GraphSample]:
    """Load raw QM9 graphs, update edge-type/annotation statistics, and process them.

    Side effects: grows `self.__num_edge_types` and `self.__annotation_size`
    to cover this file's graphs.
    """
    print(" Loading QM9 data from %s." % (data_file, ))
    # Materialize up front: the .jsonl reader returns a generator and the
    # data is traversed more than once below.
    data = list(data_file.read_by_file_suffix())

    # Largest forward edge type id seen anywhere in this file (0 if empty).
    num_fwd_edge_types = max(
        [0] + [max(edge[1] for edge in graph['graph']) for graph in data])
    if self.params['add_self_loop_edges']:
        num_fwd_edge_types += 1
    # Backward edges double the edge-type count unless fwd/bkwd are tied.
    bkwd_factor = 1 if self.params['tie_fwd_bkwd_edges'] else 2
    self.__num_edge_types = max(self.num_edge_types,
                                num_fwd_edge_types * bkwd_factor)
    self.__annotation_size = max(self.__annotation_size,
                                 len(data[0]["node_features"][0]))
    return self.__process_raw_graphs(data)
def split_file(input_path: RichPath, output_paths: Dict[str, RichPath], train_ratio: float,
               valid_ratio: float, test_ratio: float, test_only_projects: Set[str]) -> None:
    """Split one datapoint file into train/valid/test/test-only output files.

    Each datapoint is routed by `get_fold` based on its 'Filename' provenance.
    A fold's output file (named after the input shard) is only written when
    that fold is active: ratio > 0 for train/valid/test, non-empty for
    test-only. A truncated input file (EOFError) is reported and skipped.

    Args:
        input_path: Shard of datapoints readable via `read_by_file_suffix()`.
        output_paths: Fold name ('train'/'valid'/'test'/'test-only') -> output directory.
        train_ratio, valid_ratio, test_ratio: Fold proportions passed to `get_fold`.
        test_only_projects: Projects forced into the 'test-only' fold.
    """
    fold_to_graphs: Dict[str, List[Any]] = {
        'train': [], 'valid': [], 'test': [], 'test-only': []}
    try:
        for datapoint in input_path.read_by_file_suffix():
            fold = get_fold(datapoint['Filename'], train_ratio, valid_ratio,
                            test_only_projects)
            if fold in fold_to_graphs:
                fold_to_graphs[fold].append(datapoint)
    except EOFError:
        print('Failed for file %s.' % input_path)
        return

    input_file_basename = input_path.basename()

    def save_fold(fold: str) -> None:
        # Write one compressed output file per fold, named after the input shard.
        output_path = output_paths[fold].join(input_file_basename)
        print('Saving %s...' % (output_path, ))
        output_path.save_as_compressed_file(fold_to_graphs[fold])

    if train_ratio > 0:
        save_fold('train')
    if valid_ratio > 0:
        save_fold('valid')
    if test_ratio > 0:
        save_fold('test')
    if len(fold_to_graphs['test-only']) > 0:
        save_fold('test-only')
def __load_data(self, data_file: RichPath) -> List[GraphSampleType]:
    """Read every raw datapoint in `data_file` and convert each to a graph sample."""
    return list(map(self._process_raw_datapoint,
                    data_file.read_by_file_suffix()))
def __load_data(self, data_file: RichPath) -> List[QM9GraphSample]:
    """Load raw QM9 graphs from `data_file` and convert them to graph samples."""
    # Materialize eagerly: for .jsonl files, read_by_file_suffix() only
    # yields a generator, but downstream processing expects a list.
    raw_graphs = list(data_file.read_by_file_suffix())
    return self.__process_raw_graphs(raw_graphs)
def read_chunk(raw_data_chunk_path: RichPath):
    """Read one raw data chunk, with the reader chosen by the path's file suffix."""
    chunk_contents = raw_data_chunk_path.read_by_file_suffix()
    return chunk_contents
def __load_data(self, data_file: RichPath) -> List[CodeGraphSample]:
    """Load raw code graphs from `data_file` and convert them to graph samples."""
    # read_by_file_suffix() dispatches on the file extension and can read
    # several formats (.npy, .json, .pkl, .jsonl, ...). Wrap in list()
    # because the .jsonl reader returns a generator.
    raw_graphs = list(data_file.read_by_file_suffix())
    return self.__process_raw_graphs(raw_graphs)