示例#1
0
 def metadata_parser_fn(
         _, file_path: RichPath) -> Iterable[Dict[str, Any]]:
     """Accumulate raw metadata over every sample in one file.

     The first positional argument is ignored. ``self`` and
     ``type_lattice_path`` are taken from the enclosing scope.

     Yields:
         Exactly one dict: the raw metadata after it has been initialised
         and then updated with each sample read from ``file_path``.
     """
     collected = {"type_lattice_path": type_lattice_path}
     owner_cls = type(self)
     owner_cls._init_metadata(self.hyperparameters, collected)
     for sample in file_path.read_by_file_suffix():
         owner_cls._load_metadata_from_sample(
             self.hyperparameters, raw_sample=sample, raw_metadata=collected)
     yield collected
def restore(path: RichPath,
            is_train: bool,
            hyper_overrides: Optional[Dict[str, Any]] = None,
            model_save_dir: Optional[str] = None,
            log_save_dir: Optional[str] = None) -> Model:
    """Rebuild a model from a saved file and load its stored weights.

    Args:
        path: Location of the saved model data. The loaded object must
            provide the keys 'hyperparameters', 'model_type', 'metadata'
            and 'weights'; 'run_name' is optional.
        is_train: Forwarded to ``model.make_model``.
        hyper_overrides: If given, merged into the saved hyperparameters
            before the model is constructed.
        model_save_dir: Forwarded to the model constructor.
        log_save_dir: Forwarded to the model constructor.

    Returns:
        The constructed model. Every global variable with a saved value
        has been assigned that value; the others were freshly initialised.
    """
    checkpoint = path.read_by_file_suffix()

    if hyper_overrides is not None:
        checkpoint['hyperparameters'].update(hyper_overrides)

    model_cls = get_model_class_from_name(checkpoint['model_type'])
    model = model_cls(hyperparameters=checkpoint['hyperparameters'],
                      run_name=checkpoint.get('run_name'),
                      model_save_dir=model_save_dir,
                      log_save_dir=log_save_dir)
    model.metadata.update(checkpoint['metadata'])
    model.make_model(is_train=is_train)

    # Saved entries with these names are skipped silently when the model
    # has no matching variable (presumably optimizer state - confirm).
    ignored_suffixes = ('Adam:0', 'Adam_1:0')
    ignored_names = ['beta1_power:0', 'beta2_power:0']

    fresh_variables = []
    graph = model.sess.graph
    with graph.as_default(), tf.name_scope("restore"):
        stored_weights = checkpoint['weights']
        ops_to_run = []
        seen_names = set()

        graph_variables = list(
            graph.get_collection(tf.GraphKeys.GLOBAL_VARIABLES))
        graph_variables.sort(key=lambda v: v.name)
        for var in graph_variables:
            seen_names.add(var.name)
            if var.name in stored_weights:
                ops_to_run.append(var.assign(stored_weights[var.name]))
                continue
            print('Freshly initializing %s since no saved value was found.'
                  % var.name)
            fresh_variables.append(var)

        # Report saved weights that no variable in the graph consumed.
        for saved_name in sorted(stored_weights):
            if saved_name in seen_names:
                continue
            if (saved_name.endswith(ignored_suffixes)
                    or saved_name in ignored_names):
                continue
            print('Saved weights for %s not used by model.' % saved_name)

        ops_to_run.append(tf.variables_initializer(fresh_variables))
        model.sess.run(ops_to_run)
    return model
    def load_existing_metadata(self, metadata_path: RichPath):
        """Adopt previously saved metadata and report hyperparameter drift.

        Reads the object stored at ``metadata_path`` (it must provide the
        keys 'hyperparameters' and 'metadata'). Every hyperparameter whose
        current value differs from the saved one is reported through
        ``self.train_log``; 'run_id' is exempt. The saved metadata then
        replaces this object's metadata.
        """
        stored = metadata_path.read_by_file_suffix()
        stored_hypers = stored['hyperparameters']

        # Union of the names known now and the names known when saving.
        names_to_check = set(self.hyperparameters.keys())
        names_to_check.update(stored_hypers.keys())

        if 'cg_node_type_vocab_size' in stored_hypers:
            # TODO: Should not be needed
            self.hyperparameters['cg_node_type_vocab_size'] = \
                stored_hypers['cg_node_type_vocab_size']

        for name in names_to_check:
            if name not in ['run_id']:  # 'run_id' is supposed to change
                previous_value = stored_hypers.get(name)
                current_value = self.hyperparameters.get(name)
                if previous_value != current_value:
                    self.train_log(
                        "I: Hyperparameter %s now has value '%s' but was '%s' when tensorising data."
                        % (name, current_value, previous_value))

        self.__metadata = stored['metadata']
def get_queries(query_file: RichPath,
                dfs: Dict[str, pd.DataFrame]) -> List[Dict[str, Any]]:
    """Build query dicts from the descriptions stored in ``query_file``.

    Each description provides 'frames' (names looked up in ``dfs``),
    'left' and 'right'. The first frame becomes the query's 'from' entry.
    Each remaining frame is zipped with the matching 'left' and 'right'
    entries into a ``MergeFrame`` under 'on'.

    Returns:
        One ``{'from': ..., 'on': [...]}`` dict per description, in file
        order.
    """

    def _to_query(spec: Dict[str, Any]) -> Dict[str, Any]:
        # Joins are built first, then the base frame is looked up.
        joins: List[MergeFrame] = [
            MergeFrame(frame=dfs[frame_name],
                       left_on=left_key,
                       right_on=right_key)
            for frame_name, left_key, right_key in zip(
                spec['frames'][1:], spec['left'], spec['right'])
        ]
        return {'from': dfs[spec['frames'][0]], 'on': joins}

    specs = query_file.read_by_file_suffix()
    return [_to_query(spec) for spec in specs]
示例#5
0
    def __load_data(self, data_file: RichPath) -> List[GraphSample]:
        """Load raw QM9 graphs, update dataset-wide sizes, and process them.

        Side effects: raises the stored edge-type count and annotation size
        to cover what is observed in ``data_file``. Both only ever grow.

        Returns:
            The result of ``__process_raw_graphs`` on the loaded graphs.
        """
        print(" Loading QM9 data from %s." % (data_file, ))
        # For .jsonl input the reader is only a generator, so materialise it.
        raw_graphs = list(data_file.read_by_file_suffix())

        # Largest edge-type id (second field of each edge) over all graphs.
        fwd_edge_type_count = 0
        for raw_graph in raw_graphs:
            edge_type_ids = [edge[1] for edge in raw_graph['graph']]
            fwd_edge_type_count = max(fwd_edge_type_count,
                                      max(edge_type_ids))
        if self.params['add_self_loop_edges']:
            fwd_edge_type_count += 1

        known_edge_types = self.num_edge_types
        # Untied forward/backward edges double the number of edge types.
        direction_factor = 1 if self.params['tie_fwd_bkwd_edges'] else 2
        self.__num_edge_types = max(known_edge_types,
                                    fwd_edge_type_count * direction_factor)

        first_node_features = raw_graphs[0]["node_features"][0]
        self.__annotation_size = max(self.__annotation_size,
                                     len(first_node_features))
        return self.__process_raw_graphs(raw_graphs)
示例#6
0
def split_file(input_path: RichPath, output_paths: Dict[str, RichPath],
               train_ratio: float, valid_ratio: float, test_ratio: float,
               test_only_projects: Set[str]) -> None:
    """Split one input file into per-fold output files.

    Each datapoint is assigned a fold by ``get_fold`` from its 'Filename'
    entry. The 'train', 'valid' and 'test' folds are written whenever
    their ratio is positive, even if empty. The 'test-only' fold is
    written only when it is non-empty. Output files reuse the input
    file's basename inside the matching ``output_paths`` directory.

    If reading the input raises ``EOFError``, a message is printed and
    nothing is written.
    """
    graphs_by_fold = {'train': [], 'valid': [], 'test': [], 'test-only': []}

    try:
        for sample in input_path.read_by_file_suffix():
            fold = get_fold(sample['Filename'], train_ratio, valid_ratio,
                            test_only_projects)
            # Datapoints assigned to any other fold name are dropped.
            if fold in graphs_by_fold:
                graphs_by_fold[fold].append(sample)
    except EOFError:
        print('Failed for file %s.' % input_path)
        return

    basename = input_path.basename()

    def _save_fold(fold_name: str) -> None:
        # Write one fold's datapoints next to its siblings.
        target = output_paths[fold_name].join(basename)
        print('Saving %s...' % (target, ))
        target.save_as_compressed_file(graphs_by_fold[fold_name])

    if train_ratio > 0:
        _save_fold('train')

    if valid_ratio > 0:
        _save_fold('valid')

    if test_ratio > 0:
        _save_fold('test')

    if graphs_by_fold['test-only']:
        _save_fold('test-only')
 def __load_data(self, data_file: RichPath) -> List[GraphSampleType]:
     """Read every datapoint from ``data_file`` and process each one.

     Returns:
         The results of ``_process_raw_datapoint``, in file order.
     """
     samples = []
     for raw_datapoint in data_file.read_by_file_suffix():
         samples.append(self._process_raw_datapoint(raw_datapoint))
     return samples
示例#8
0
 def __load_data(self, data_file: RichPath) -> List[QM9GraphSample]:
     """Load all raw graphs from ``data_file`` and process them."""
     # For .jsonl input the reader is only a generator, so it is
     # materialised into a list before processing.
     raw_graphs = [*data_file.read_by_file_suffix()]
     return self.__process_raw_graphs(raw_graphs)
示例#9
0
 def read_chunk(raw_data_chunk_path: RichPath):
     """Return what ``read_by_file_suffix`` yields for the given path."""
     chunk_contents = raw_data_chunk_path.read_by_file_suffix()
     return chunk_contents
示例#10
0
 def __load_data(self, data_file: RichPath) -> List[CodeGraphSample]:
     """Load all raw graphs from ``data_file`` and process them."""
     # read_by_file_suffix can read several data formats, including npy,
     # json, pkl and jsonl.
     loaded = data_file.read_by_file_suffix()
     # For .jsonl the reader is only a generator, so materialise it.
     raw_graphs = list(loaded)
     return self.__process_raw_graphs(raw_graphs)