def create(
    config: Config,
    dataset: Dataset,
    configuration_key: Optional[str] = None,
    init_for_load_only=False,
    create_embedders=True,
    parameter_client=None,
    max_partition_entities=0,
) -> "KgeModel":
    """Factory method for model creation."""
    try:
        if configuration_key is not None:
            model_name = config.get(configuration_key + ".type")
        else:
            model_name = config.get("model")
        config._import(model_name)
        class_name = config.get(model_name + ".class_name")
    except:
        raise Exception("Can't find {}.type in config".format(configuration_key))

    try:
        model = init_from(
            class_name,
            config.get("modules"),
            config=config,
            dataset=dataset,
            configuration_key=configuration_key,
            init_for_load_only=init_for_load_only,
            create_embedders=create_embedders,
            parameter_client=parameter_client,
            max_partition_entities=max_partition_entities,
        )
        model.to(config.get("job.device"))
        return model
    except:
        config.log(f"Failed to create model {model_name} (class {class_name}).")
        raise
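
# Usage sketch for the factory above (not part of the library): build a model
# from a fresh config. The config path is the toy example referenced elsewhere
# in this file; the helper name _example_create_model is hypothetical.
def _example_create_model():
    from kge import Config, Dataset
    from kge.model import KgeModel

    config = Config()
    config.load("examples/toy-complex-train.yaml")
    dataset = Dataset.create(config)
    # dispatches on the "model" config key to the imported model class
    return KgeModel.create(config, dataset)
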
def __init__(
    self, config: Config, dataset: Dataset, parent_job: Job = None, model=None
) -> None:
    from kge.job import EvaluationJob

    super().__init__(config, dataset, parent_job)
    if model is None:
        self.model: KgeModel = KgeModel.create(config, dataset)
    else:
        self.model: KgeModel = model
    self.optimizer = KgeOptimizer.create(config, self.model)
    self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
    self.loss = KgeLoss.create(config)
    self.abort_on_nan: bool = config.get("train.abort_on_nan")
    self.batch_size: int = config.get("train.batch_size")
    self._subbatch_auto_tune: bool = config.get("train.subbatch_auto_tune")
    self._max_subbatch_size: int = config.get("train.subbatch_size")
    self.device: str = self.config.get("job.device")
    self.train_split = config.get("train.split")
    self.config.check("train.trace_level", ["batch", "epoch"])
    self.trace_batch: bool = self.config.get("train.trace_level") == "batch"
    self.epoch: int = 0
    self.valid_trace: List[Dict[str, Any]] = []
    valid_conf = config.clone()
    valid_conf.set("job.type", "eval")
    if self.config.get("valid.split") != "":
        valid_conf.set("eval.split", self.config.get("valid.split"))
    valid_conf.set("eval.trace_level", self.config.get("valid.trace_level"))
    self.valid_job = EvaluationJob.create(
        valid_conf, dataset, parent_job=self, model=self.model
    )

    # attributes filled in by implementing classes
    self.loader = None
    self.num_examples = None
    self.type_str: Optional[str] = None

    #: Hooks run after validation. The corresponding valid trace entry can be
    #: found in self.valid_trace[-1]. Signature: job
    self.post_valid_hooks: List[Callable[[Job], Any]] = []

    if self.__class__ == TrainingJob:
        for f in Job.job_created_hooks:
            f(self)

    self.model.train()
def create_from(
    cls,
    checkpoint: Dict,
    new_config: Config = None,
    dataset: Dataset = None,
    parent_job=None,
    parameter_client=None,
) -> Job:
    """
    Creates a Job based on a checkpoint

    Args:
        checkpoint: loaded checkpoint
        new_config: optional config object - overwrites options of config
                    stored in checkpoint
        dataset: dataset object
        parent_job: parent job (e.g. search job)

    Returns: Job based on checkpoint

    """
    from kge.model import KgeModel

    model: KgeModel = None
    # search jobs don't have a model
    if "model" in checkpoint and checkpoint["model"] is not None:
        model = KgeModel.create_from(
            checkpoint,
            new_config=new_config,
            dataset=dataset,
            parameter_client=parameter_client,
        )
        config = model.config
        dataset = model.dataset
    else:
        config = Config.create_from(checkpoint)
        if new_config:
            config.load_config(new_config)
        dataset = Dataset.create_from(checkpoint, config, dataset)
    job = Job.create(
        config,
        dataset,
        parent_job,
        model,
        parameter_client=parameter_client,
        init_for_load_only=True,
    )
    job._load(checkpoint)
    job.config.log("Loaded checkpoint from {}...".format(checkpoint["file"]))
    return job
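
# Resuming sketch (filename hypothetical): load a checkpoint from disk and
# rebuild the job it belonged to, mirroring the resume path in main() below.
def _example_resume_job(checkpoint_file="checkpoint_best.pt"):
    from kge.job import Job
    from kge.util.io import load_checkpoint

    checkpoint = load_checkpoint(checkpoint_file, device="cpu")
    job = Job.create_from(checkpoint)
    job.run()
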
def __init__(self, config: Config, dataset: Dataset, parent_job, model):
    super().__init__(config, dataset, parent_job)

    self.config = config
    self.dataset = dataset
    self.model = model
    self.batch_size = config.get("eval.batch_size")
    self.device = self.config.get("job.device")
    self.config.check("eval.trace_level", ["example", "batch", "epoch"])
    self.trace_examples = self.config.get("eval.trace_level") == "example"
    self.trace_batch = (
        self.trace_examples or self.config.get("eval.trace_level") == "batch"
    )
    self.eval_split = self.config.get("eval.split")
    self.epoch = -1

    # all done, run job_created_hooks if necessary
    if self.__class__ == EvaluationJob:
        for f in Job.job_created_hooks:
            f(self)
def __init__(self, config: Config, dataset: Dataset, parent_job, model):
    super().__init__(config, dataset, parent_job, model)

    training_loss_eval_config = config.clone()
    # TODO set train split to include validation data here once support is
    #   added; then reflect this change in the trace entries

    self._train_job = TrainingJob.create(
        config=training_loss_eval_config,
        parent_job=self,
        dataset=dataset,
        model=model,
        forward_only=True,
    )

    if self.__class__ == TrainingLossEvaluationJob:
        for f in Job.job_created_hooks:
            f(self)
def __init__(self, config: Config, dataset: Dataset, parent_job, model):
    super().__init__(config, dataset, parent_job, model)
    self.config.check(
        "entity_ranking.tie_handling",
        ["rounded_mean_rank", "best_rank", "worst_rank"],
    )
    self.tie_handling = self.config.get("entity_ranking.tie_handling")

    if self.__class__ == EntityRankingJob:
        for f in Job.job_created_hooks:
            f(self)

    max_k = min(
        self.dataset.num_entities(),
        max(self.config.get("entity_ranking.hits_at_k_s")),
    )
    self.hits_at_k_s = list(
        filter(lambda x: x <= max_k, self.config.get("entity_ranking.hits_at_k_s"))
    )
    self.filter_with_test = config.get("entity_ranking.filter_with_test")
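
# Toy illustration (plain Python, no library dependencies) of the hits@k
# truncation above: requested k values larger than the number of entities are
# dropped, since hits@k is constant beyond that point. Values are illustrative.
def _example_hits_at_k_truncation():
    num_entities = 100
    hits_at_k_s = [1, 3, 10, 1000]
    max_k = min(num_entities, max(hits_at_k_s))
    return list(filter(lambda k: k <= max_k, hits_at_k_s))  # -> [1, 3, 10]
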
def create_from(
    checkpoint: Dict,
    dataset: Optional[Dataset] = None,
    use_tmp_log_folder=True,
    new_config: Config = None,
) -> "KgeModel":
    """Loads a model from a checkpoint file of a training job or a packaged model.

    If dataset is specified, associates this dataset with the model. Otherwise
    uses the dataset used to train the model.

    If `use_tmp_log_folder` is set, the logs and traces are written to a
    temporary file. Otherwise, the files `kge.log` and `trace.yaml` will be
    created (or appended to) in the checkpoint's folder.

    """
    config = Config.create_from(checkpoint)
    if new_config:
        config.load_config(new_config)
    if use_tmp_log_folder:
        import tempfile

        config.log_folder = tempfile.mkdtemp(prefix="kge-")
    else:
        config.log_folder = checkpoint["folder"]
        if not config.log_folder or not os.path.exists(config.log_folder):
            config.log_folder = "."
    dataset = Dataset.create_from(checkpoint, config, dataset, preload_data=False)
    model = KgeModel.create(config, dataset, init_for_load_only=True)
    model.load(checkpoint["model"])
    model.eval()
    return model
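
# Scoring sketch for the loader above, following the pattern from the LibKGE
# README; the checkpoint filename is hypothetical.
def _example_score_with_pretrained(checkpoint_file="fb15k-237-rescal.pt"):
    import torch
    from kge.model import KgeModel
    from kge.util.io import load_checkpoint

    checkpoint = load_checkpoint(checkpoint_file)
    model = KgeModel.create_from(checkpoint)
    s = torch.tensor([0, 2], dtype=torch.long)  # subject indexes
    p = torch.tensor([0, 1], dtype=torch.long)  # relation indexes
    return model.score_sp(s, p)  # scores of all objects for (s, p, ?)
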
def __init__(
    self,
    config: Config,
    dataset: Dataset,
    configuration_key: str,
    vocab_size: int,
    init_for_load_only=False,
):
    super().__init__(
        config, dataset, configuration_key, init_for_load_only=init_for_load_only
    )

    # read config
    self.regularize = self.check_option("regularize", ["", "lp"])
    self.config.check("train.trace_level", ["batch", "epoch"])
    self.vocab_size = vocab_size
    self.filename = self.get_option("filename")
    self.num_layers = self.get_option("num_layers")

    # load numeric data
    with open(self.filename, "r") as f:
        data = list(map(lambda s: s.strip().split("\t"), f.readlines()))

    # returns entities in index order
    entities = self.dataset.entity_ids()
    ent_to_idx = {ent: idx for idx, ent in enumerate(entities)}
    numeric_data_ent_idx = []
    rel_to_idx = {}
    numeric_data_rel_idx = []
    numeric_data = []
    for t in data:
        ent = t[0]
        rel = t[1]
        value = float(t[2])
        if rel not in rel_to_idx:
            rel_to_idx[rel] = len(rel_to_idx)
        numeric_data_ent_idx.append(ent_to_idx[ent])
        numeric_data_rel_idx.append(rel_to_idx[rel])
        numeric_data.append(value)
    numeric_data_ent_idx = torch.tensor(numeric_data_ent_idx, dtype=torch.long)
    numeric_data_rel_idx = torch.tensor(numeric_data_rel_idx, dtype=torch.long)
    numeric_data = torch.tensor(numeric_data, dtype=torch.float32)

    # normalize numeric literals
    if self.get_option("normalization") == "min-max":
        for rel_idx in rel_to_idx.values():
            sel = rel_idx == numeric_data_rel_idx
            max_num = torch.max(numeric_data[sel])
            min_num = torch.min(numeric_data[sel])
            numeric_data[sel] = (numeric_data[sel] - min_num) / (
                max_num - min_num + 1e-8
            )
    elif self.get_option("normalization") == "z-score":
        for rel_idx in rel_to_idx.values():
            sel = rel_idx == numeric_data_rel_idx
            mean = torch.mean(numeric_data[sel])
            # account for the fact that there might only be a single value;
            # in that case torch.std would result in nan
            if torch.sum(sel) > 1:
                std = torch.std(numeric_data[sel])
            else:
                std = 0
            numeric_data[sel] = (numeric_data[sel] - mean) / (std + 1e-8)
    else:
        raise ValueError("Unknown normalization option")

    num_lit = torch.zeros([len(ent_to_idx), len(rel_to_idx)], dtype=torch.float32)
    num_lit[numeric_data_ent_idx, numeric_data_rel_idx] = numeric_data
    # includes all numeric literals for all entities, with the entities
    # being ordered by their index
    self.num_lit = num_lit.to(self.config.get("job.device"))

    if self.num_layers > 0:
        # initialize numeric MLP
        self.numeric_mlp = NumericMLP(
            input_dim=num_lit.shape[1],
            output_dim=self.dim,
            num_layers=self.num_layers,
            activation=self.get_option("activation"),
        )
        if not init_for_load_only:
            # initialize weights
            for name, weights in self.numeric_mlp.named_parameters():
                # set bias to zero
                # https://cs231n.github.io/neural-networks-2/#init
                if "bias" in name:
                    torch.nn.init.zeros_(weights)
                else:
                    self.initialize(weights)
    else:
        self.dim = num_lit.shape[1]

    # TODO handling negative dropout because using it with ax searches for now
    dropout = self.get_option("dropout")
    if dropout < 0:
        if config.get("train.auto_correct"):
            config.log(
                "Setting {}.dropout to 0, was set to {}.".format(
                    configuration_key, dropout
                )
            )
            dropout = 0
    self.dropout = torch.nn.Dropout(dropout)
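
# Self-contained toy demonstration of the grouped min-max normalization above:
# values sharing a relation index are rescaled to [0, 1] via a boolean mask;
# a group with a single value maps to 0 (up to the 1e-8 epsilon).
def _example_minmax_normalization():
    import torch

    numeric_data = torch.tensor([1.0, 3.0, 5.0, 10.0])
    numeric_data_rel_idx = torch.tensor([0, 0, 0, 1])
    for rel_idx in [0, 1]:
        sel = numeric_data_rel_idx == rel_idx
        min_num = torch.min(numeric_data[sel])
        max_num = torch.max(numeric_data[sel])
        numeric_data[sel] = (numeric_data[sel] - min_num) / (
            max_num - min_num + 1e-8
        )
    return numeric_data  # ~tensor([0.0, 0.5, 1.0, 0.0])
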
def __init__(
    self,
    config: Config,
    dataset: Dataset,
    configuration_key: str,
    vocab_size: int,
    parameter_client: "KgeParameterClient",
    complete_vocab_size,
    lapse_offset=0,
    init_for_load_only=False,
):
    super().__init__(
        config,
        dataset,
        configuration_key,
        vocab_size,
        init_for_load_only=init_for_load_only,
    )
    self.optimizer_dim = get_optimizer_dim(config, self.dim)
    self.optimizer_values = torch.zeros(
        (self.vocab_size, self.optimizer_dim),
        dtype=torch.float32,
        requires_grad=False,
    )
    self.complete_vocab_size = complete_vocab_size
    self.parameter_client = parameter_client
    self.lapse_offset = lapse_offset
    self.pulled_ids = None
    self.load_batch = self.config.get("job.distributed.load_batch")
    # global to local mapper, only used in sync level "partition"
    self.global_to_local_mapper = torch.full(
        (self.dataset.num_entities(),), -1, dtype=torch.long, device="cpu"
    )
    # maps the local embeddings to the embeddings in lapse; used in optimizer
    self.local_to_lapse_mapper = torch.full(
        (vocab_size,), -1, dtype=torch.long, requires_grad=False
    )
    self.pull_dim = self.dim + self.optimizer_dim
    self.unnecessary_dim = self.parameter_client.dim - self.pull_dim

    # 3 pull tensors to pre-pull up to 3 batches;
    # the first (boolean) element denotes whether the tensor is free
    self.pull_tensors = [
        [
            True,
            torch.empty(
                (self.vocab_size, self.parameter_client.dim),
                dtype=torch.float32,
                device="cpu",
                requires_grad=False,
            ),
        ]
        for _ in range(3)
    ]
    if "cuda" in config.get("job.device"):
        # only pin tensors if we are using a gpu;
        # otherwise gpu memory would be allocated for no reason
        with torch.cuda.device(config.get("job.device")):
            for i in range(len(self.pull_tensors)):
                self.pull_tensors[i][1] = self.pull_tensors[i][1].pin_memory()
    self.num_pulled = 0
    self.mapping_time = 0.0
    self.pre_pulled = deque()
def __init__(
    self,
    config: Config,
    dataset: Dataset,
    scorer: Union[RelationalScorer, type],
    create_embedders=True,
    configuration_key=None,
    init_for_load_only=False,
    parameter_client=None,
    max_partition_entities=0,
):
    super().__init__(config, dataset, configuration_key)

    # TODO support different embedders for subjects and objects

    #: Embedder used for entities (both subjects and objects)
    self._entity_embedder: KgeEmbedder

    #: Embedder used for relations
    self._relation_embedder: KgeEmbedder

    if create_embedders:
        self._create_embedders(init_for_load_only)
    elif False:  # if self.get_option("create_complete"):
        # embedding_layer_size = dataset.num_entities()
        if (
            config.get("job.distributed.entity_sync_level") == "partition"
            and max_partition_entities != 0
        ):
            embedding_layer_size = max_partition_entities
        else:
            embedding_layer_size = self._calc_embedding_layer_size(config, dataset)
        config.log(f"creating entity_embedder with {embedding_layer_size} keys")
        self._entity_embedder = KgeEmbedder.create(
            config=config,
            dataset=dataset,
            configuration_key=self.configuration_key + ".entity_embedder",
            # dataset.num_entities(),
            vocab_size=embedding_layer_size,
            init_for_load_only=init_for_load_only,
            parameter_client=parameter_client,
            lapse_offset=0,
            complete_vocab_size=dataset.num_entities(),
        )

        #: Embedder used for relations
        num_relations = dataset.num_relations()
        self._relation_embedder = KgeEmbedder.create(
            config,
            dataset,
            self.configuration_key + ".relation_embedder",
            num_relations,
            init_for_load_only=init_for_load_only,
            parameter_client=parameter_client,
            lapse_offset=dataset.num_entities(),
            complete_vocab_size=dataset.num_relations(),
        )

        if not init_for_load_only and parameter_client.rank == get_min_rank(config):
            # load pretrained embeddings
            pretrained_entities_filename = ""
            pretrained_relations_filename = ""
            if self.has_option("entity_embedder.pretrain.model_filename"):
                pretrained_entities_filename = self.get_option(
                    "entity_embedder.pretrain.model_filename"
                )
            if self.has_option("relation_embedder.pretrain.model_filename"):
                pretrained_relations_filename = self.get_option(
                    "relation_embedder.pretrain.model_filename"
                )

            def load_pretrained_model(
                pretrained_filename: str,
            ) -> Optional[KgeModel]:
                if pretrained_filename != "":
                    self.config.log(
                        f"Initializing with embeddings stored in "
                        f"{pretrained_filename}"
                    )
                    checkpoint = load_checkpoint(pretrained_filename)
                    return KgeModel.create_from(
                        checkpoint, parameter_client=parameter_client
                    )
                return None

            pretrained_entities_model = load_pretrained_model(
                pretrained_entities_filename
            )
            if pretrained_entities_filename == pretrained_relations_filename:
                pretrained_relations_model = pretrained_entities_model
            else:
                pretrained_relations_model = load_pretrained_model(
                    pretrained_relations_filename
                )
            if pretrained_entities_model is not None:
                if (
                    pretrained_entities_model.get_s_embedder()
                    != pretrained_entities_model.get_o_embedder()
                ):
                    raise ValueError(
                        "Can only initialize with pre-trained models having "
                        "identical subject and object embeddings."
                    )
                self._entity_embedder.init_pretrained(
                    pretrained_entities_model.get_s_embedder()
                )
            if pretrained_relations_model is not None:
                self._relation_embedder.init_pretrained(
                    pretrained_relations_model.get_p_embedder()
                )

    #: Scorer
    self._scorer: RelationalScorer
    if type(scorer) == type:
        # scorer is type of the scorer to use; call its constructor
        self._scorer = scorer(
            config=config, dataset=dataset, configuration_key=self.configuration_key
        )
    else:
        self._scorer = scorer
def create_default(
    model: Optional[str] = None,
    dataset: Optional[Union[Dataset, str]] = None,
    options: Dict[str, Any] = {},
    folder: Optional[str] = None,
) -> "KgeModel":
    """Utility method to create a model, including configuration and dataset.

    `model` is the name of the model (takes precedence over
    ``options["model"]``), `dataset` a dataset name or `Dataset` instance (takes
    precedence over ``options["dataset.name"]``), and options arbitrary other
    configuration options.

    If `folder` is ``None``, creates a temporary folder. Otherwise uses the
    specified folder.

    """
    # load default model config
    if model is None:
        model = options["model"]
    default_config_file = filename_in_module(kge.model, "{}.yaml".format(model))
    config = Config()
    config.load(default_config_file, create=True)

    # apply specified options
    config.set("model", model)
    if isinstance(dataset, Dataset):
        config.set("dataset.name", dataset.config.get("dataset.name"))
    elif isinstance(dataset, str):
        config.set("dataset.name", dataset)
    config.set_all(new_options=options)

    # create output folder
    if folder is None:
        config.folder = tempfile.mkdtemp(
            "{}-{}-".format(config.get("dataset.name"), config.get("model"))
        )
    else:
        config.folder = folder

    # create dataset and model
    if not isinstance(dataset, Dataset):
        dataset = Dataset.create(config)
    model = KgeModel.create(config, dataset)
    return model
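
# Sketch: create_default bundles config loading, folder setup, and dataset
# creation into one call. Model and dataset names below are illustrative (the
# "toy" dataset ships with the repository).
def _example_create_default():
    from kge.model import KgeModel

    return KgeModel.create_default(
        model="complex", dataset="toy", options={"lookup_embedder.dim": 128}
    )
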
def _dump_config(args):
    """Execute the 'dump config' command."""
    if not (args.raw or args.full or args.minimal):
        args.minimal = True

    if args.raw + args.full + args.minimal != 1:
        raise ValueError("Exactly one of --raw, --full, or --minimal must be set")

    if args.raw and (args.include or args.exclude):
        raise ValueError(
            "--include and --exclude cannot be used with --raw "
            "(use --full or --minimal instead)."
        )

    config = Config()
    config_file = None
    if os.path.isdir(args.source):
        config_file = os.path.join(args.source, "config.yaml")
        config.load(config_file)
    elif ".yaml" in os.path.split(args.source)[-1]:
        config_file = args.source
        config.load(config_file)
    else:  # a checkpoint
        checkpoint = torch.load(args.source, map_location="cpu")
        if args.raw:
            config = checkpoint["config"]
        else:
            config.load_config(checkpoint["config"])

    def print_options(options):
        # drop all arguments that are not included
        if args.include:
            args.include = set(args.include)
            options_copy = copy.deepcopy(options)
            for key in options_copy.keys():
                prefix = key
                keep = False
                while True:
                    if prefix in args.include:
                        keep = True
                        break
                    else:
                        last_dot_index = prefix.rfind(".")
                        if last_dot_index < 0:
                            break
                        else:
                            prefix = prefix[:last_dot_index]
                if not keep:
                    del options[key]

        # remove all arguments that are excluded
        if args.exclude:
            args.exclude = set(args.exclude)
            options_copy = copy.deepcopy(options)
            for key in options_copy.keys():
                prefix = key
                while True:
                    if prefix in args.exclude:
                        del options[key]
                        break
                    else:
                        last_dot_index = prefix.rfind(".")
                        if last_dot_index < 0:
                            break
                        else:
                            prefix = prefix[:last_dot_index]

        # convert the remaining options to a Config and print it
        config = Config(load_default=False)
        config.set_all(options, create=True)
        print(yaml.dump(config.options))

    if args.raw:
        if config_file:
            with open(config_file, "r") as f:
                print(f.read())
        else:
            print_options(config.options)
    elif args.full:
        print_options(config.options)
    else:  # minimal
        default_config = Config()
        imports = config.get("import")
        if imports is not None:
            if not isinstance(imports, list):
                imports = [imports]
            for module_name in imports:
                default_config._import(module_name)
        default_options = Config.flatten(default_config.options)
        new_options = Config.flatten(config.options)
        minimal_options = {}

        for option, value in new_options.items():
            if option not in default_options or default_options[option] != value:
                minimal_options[option] = value

        # always retain all imports
        if imports is not None:
            minimal_options["import"] = list(set(imports))

        print_options(minimal_options)
def create_config(test_dataset_name: str, model: str = "complex") -> Config:
    config = Config()
    config.folder = None
    config.set("console.quiet", True)
    config.set("model", model)
    config._import(model)
    config.set("dataset.name", test_dataset_name)
    config.set("job.device", "cpu")
    return config
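
# Typical test-suite usage of the helper above (the dataset name is whatever
# test dataset the suite provides; "dataset_test" is an assumption here).
def _example_test_config():
    config = create_config("dataset_test", model="transe")
    config.set("train.max_epochs", 1)
    return config
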
def _dump_trace(args): """ Executes the 'dump trace' command.""" start = time.time() if (args.train or args.valid or args.test) and args.search: print( "--search and --train, --valid, --test are mutually exclusive", file=sys.stderr, ) exit(1) entry_type_specified = True if not (args.train or args.valid or args.test or args.search): entry_type_specified = False args.train = True args.valid = True args.test = True checkpoint_path = None if ".pt" in os.path.split(args.source)[-1]: checkpoint_path = args.source folder_path = os.path.split(args.source)[0] else: # determine job_id and epoch from last/best checkpoint automatically if args.checkpoint: checkpoint_path = Config.get_best_or_last_checkpoint(args.source) folder_path = args.source if not args.checkpoint and args.truncate: raise ValueError( "You can only use --truncate when a checkpoint is specified." "Consider using --checkpoint or provide a checkpoint file as source" ) trace = os.path.join(folder_path, "trace.yaml") if not os.path.isfile(trace): sys.stderr.write("No trace found at {}\n".format(trace)) exit(1) keymap = OrderedDict() additional_keys = [] if args.keysfile: with open(args.keysfile, "r") as keyfile: additional_keys = keyfile.readlines() if args.keys: additional_keys += args.keys for line in additional_keys: line = line.rstrip("\n").replace(" ", "") name_key = line.split("=") if len(name_key) == 1: name_key += name_key keymap[name_key[0]] = name_key[1] job_id = None epoch = int(args.max_epoch) # use job_id and epoch from checkpoint if checkpoint_path and args.truncate: checkpoint = torch.load(f=checkpoint_path, map_location="cpu") job_id = checkpoint["job_id"] epoch = checkpoint["epoch"] # only use job_id from checkpoint elif checkpoint_path: checkpoint = torch.load(f=checkpoint_path, map_location="cpu") job_id = checkpoint["job_id"] # override job_id and epoch with user arguments if args.job_id: job_id = args.job_id if not epoch: epoch = float("inf") entries, job_epochs = [], {} if not args.search: entries, job_epochs = Trace.grep_training_trace_entries( tracefile=trace, train=args.train, test=args.test, valid=args.valid, example=args.example, batch=args.batch, job_id=job_id, epoch_of_last=epoch, ) if not entries and (args.search or not entry_type_specified): entries = Trace.grep_entries(tracefile=trace, conjunctions=[f"scope: train"]) epoch = None if entries: args.search = True if not entries: print("No relevant trace entries found.", file=sys.stderr) exit(1) middle = time.time() if not args.yaml: csv_writer = csv.writer(sys.stdout) # dict[new_name] = (lookup_name, where) # if where=="config"/"trace" it will be looked up automatically # if where=="sep" it must be added in in the write loop separately if args.no_default_keys: default_attributes = OrderedDict() else: default_attributes = OrderedDict([ ("job_id", ("job_id", "sep")), ("dataset", ("dataset.name", "config")), ("model", ("model", "sep")), ("reciprocal", ("reciprocal", "sep")), ("job", ("job", "sep")), ("job_type", ("type", "trace")), ("split", ("split", "sep")), ("epoch", ("epoch", "trace")), ("avg_loss", ("avg_loss", "trace")), ("avg_penalty", ("avg_penalty", "trace")), ("avg_cost", ("avg_cost", "trace")), ("metric_name", ("valid.metric", "config")), ("metric", ("metric", "sep")), ]) if args.search: default_attributes["child_folder"] = ("folder", "trace") default_attributes["child_job_id"] = ("child_job_id", "sep") if not args.no_header: csv_writer.writerow( list(default_attributes.keys()) + [key for key in keymap.keys()]) # store configs for job_id's s.t. 
they need to be loaded only once configs = {} warning_shown = False for entry in entries: if epoch and not entry.get("epoch") <= float(epoch): continue # filter out not needed entries from a previous job when # a job was resumed from the middle if entry.get("job") == "train": job_id = entry.get("job_id") if entry.get("epoch") > job_epochs[job_id]: continue # find relevant config file child_job_id = entry.get( "child_job_id") if "child_job_id" in entry else None config_key = (entry.get("folder") + "/" + str(child_job_id) if args.search else entry.get("job_id")) if config_key in configs.keys(): config = configs[config_key] else: if args.search: if not child_job_id and not warning_shown: # This warning is from Dec 19, 2019. TODO remove print( "Warning: You are dumping the trace of an older search job. " "This is fine only if " "the config.yaml files in each subfolder have not been modified " "after running the corresponding training job.", file=sys.stderr, ) warning_shown = True config = get_config_for_job_id( child_job_id, os.path.join(folder_path, entry.get("folder"))) entry["type"] = config.get("train.type") else: config = get_config_for_job_id(entry.get("job_id"), folder_path) configs[config_key] = config new_attributes = OrderedDict() if config.get_default("model") == "reciprocal_relations_model": model = config.get_default( "reciprocal_relations_model.base_model.type") # the string that substitutes $base_model in keymap if it exists subs_model = "reciprocal_relations_model.base_model" reciprocal = 1 else: model = config.get_default("model") subs_model = model reciprocal = 0 for new_key in keymap.keys(): lookup = keymap[new_key] if "$base_model" in lookup: lookup = lookup.replace("$base_model", subs_model) try: if lookup == "$folder": val = os.path.abspath(folder_path) elif lookup == "$checkpoint": val = os.path.abspath(checkpoint_path) elif lookup == "$machine": val = socket.gethostname() else: val = config.get_default(lookup) except: # creates empty field if key is not existing val = entry.get(lookup) if type(val) == bool and val: val = 1 elif type(val) == bool and not val: val = 0 new_attributes[new_key] = val if not args.yaml: # find the actual values for the default attributes actual_default = default_attributes.copy() for new_key in default_attributes.keys(): lookup, where = default_attributes[new_key] if where == "config": actual_default[new_key] = config.get(lookup) elif where == "trace": actual_default[new_key] = entry.get(lookup) # keys with separate treatment # "split" in {train,test,valid} for the datatype # "job" in {train,eval,valid,search} if entry.get("job") == "train": actual_default["split"] = "train" actual_default["job"] = "train" elif entry.get("job") == "eval": actual_default["split"] = entry.get("data") # test or valid if entry.get("resumed_from_job_id"): actual_default["job"] = "eval" # from "kge eval" else: actual_default["job"] = "valid" # child of training job else: actual_default["job"] = entry.get("job") actual_default["split"] = entry.get("data") actual_default["job_id"] = entry.get("job_id").split("-")[0] actual_default["model"] = model actual_default["reciprocal"] = reciprocal # lookup name is in config value is in trace actual_default["metric"] = entry.get( config.get_default("valid.metric")) if args.search: actual_default["child_job_id"] = entry.get( "child_job_id").split("-")[0] for key in list(actual_default.keys()): if key not in default_attributes: del actual_default[key] csv_writer.writerow( [actual_default[new_key] for new_key in 
actual_default.keys()] + [new_attributes[new_key] for new_key in new_attributes.keys()]) else: entry.update({"reciprocal": reciprocal, "model": model}) if keymap: entry.update(new_attributes) sys.stdout.write(re.sub("[{}']", "", str(entry))) sys.stdout.write("\n") end = time.time() if args.timeit: sys.stdout.write("Grep + processing took {} \n".format(middle - start)) sys.stdout.write("Writing took {}".format(end - middle))
def run(self):
    # read search configurations and expand them to full configs
    search_configs = copy.deepcopy(self.config.get("manual_search.configurations"))
    all_keys = set()
    for i in range(len(search_configs)):
        search_config = search_configs[i]
        folder = search_config["folder"]
        del search_config["folder"]
        config = self.config.clone(folder)
        config.set("job.type", "train")
        config.options.pop("manual_search", None)  # could be large, don't copy
        flattened_search_config = Config.flatten(search_config)
        config.set_all(flattened_search_config)
        all_keys.update(flattened_search_config.keys())
        search_configs[i] = config

    # create folders for search configs (existing folders remain unmodified)
    for config in search_configs:
        config.init_folder()

    # TODO find a way to create all indexes before running the jobs. The quick hack
    # below does not work because pytorch then throws a "too many open files" error
    # self.dataset.index("train_sp_to_o")
    # self.dataset.index("train_po_to_s")
    # self.dataset.index("valid_sp_to_o")
    # self.dataset.index("valid_po_to_s")
    # self.dataset.index("test_sp_to_o")
    # self.dataset.index("test_po_to_s")

    # now start running/resuming
    for i, config in enumerate(search_configs):
        task_arg = (self, i, config, len(search_configs), all_keys)
        self.submit_task(kge.job.search._run_train_job, task_arg)
    self.wait_task(concurrent.futures.ALL_COMPLETED)

    # if not running the jobs, stop here
    if not self.config.get("manual_search.run"):
        self.config.log("Skipping evaluation of results as requested by user.")
        return

    # collect results
    best_per_job = [None] * len(search_configs)
    best_metric_per_job = [None] * len(search_configs)
    for ibm in self.ready_task_results:
        i, best, best_metric = ibm
        best_per_job[i] = best
        best_metric_per_job[i] = best_metric

    # produce an overall summary
    self.config.log("Result summary:")
    metric_name = self.config.get("valid.metric")
    overall_best = None
    overall_best_metric = None
    for i in range(len(search_configs)):
        best = best_per_job[i]
        best_metric = best_metric_per_job[i]
        if not overall_best or overall_best_metric < best_metric:
            overall_best = best
            overall_best_metric = best_metric
        self.config.log(
            "{}={:.3f} after {} epochs in folder {}".format(
                metric_name, best_metric, best["epoch"], best["folder"]
            ),
            prefix=" ",
        )
    self.config.log("And the winner is:")
    self.config.log(
        "{}={:.3f} after {} epochs in folder {}".format(
            metric_name,
            overall_best_metric,
            overall_best["epoch"],
            overall_best["folder"],
        ),
        prefix=" ",
    )
    self.config.log("Best overall result:")
    self.trace(
        event="search_completed",
        echo=True,
        echo_prefix=" ",
        log=True,
        scope="search",
        **overall_best,
    )
def create(config: Config):
    """Factory method for loss function instantiation."""
    # perhaps TODO: try class with specified name -> extensibility
    config.check(
        "train.loss",
        [
            "bce",
            "bce_mean",
            "bce_self_adversarial",
            "margin_ranking",
            "ce",
            "kl",
            "soft_margin",
        ],
    )
    if config.get("train.loss") == "bce":
        offset = config.get("train.loss_arg")
        if math.isnan(offset):
            offset = 0.0
            config.set("train.loss_arg", offset, log=True)
        return BCEWithLogitsKgeLoss(config, offset=offset, bce_type=None)
    elif config.get("train.loss") == "bce_mean":
        offset = config.get("train.loss_arg")
        if math.isnan(offset):
            offset = 0.0
            config.set("train.loss_arg", offset, log=True)
        return BCEWithLogitsKgeLoss(config, offset=offset, bce_type="mean")
    elif config.get("train.loss") == "bce_self_adversarial":
        offset = config.get("train.loss_arg")
        if math.isnan(offset):
            offset = 0.0
            config.set("train.loss_arg", offset, log=True)
        try:
            temperature = float(config.get("user.bce_self_adversarial_temperature"))
        except KeyError:
            temperature = 1.0
        config.log(f"Using adversarial temperature {temperature}")
        return BCEWithLogitsKgeLoss(
            config,
            offset=offset,
            bce_type="self_adversarial",
            temperature=temperature,
        )
    elif config.get("train.loss") in ["kl", "ce"]:
        return KLDivWithSoftmaxKgeLoss(config)
    elif config.get("train.loss") == "margin_ranking":
        margin = config.get("train.loss_arg")
        if math.isnan(margin):
            margin = 1.0
            config.set("train.loss_arg", margin, log=True)
        return MarginRankingKgeLoss(config, margin=margin)
    elif config.get("train.loss") == "soft_margin":
        return SoftMarginKgeLoss(config)
    else:
        raise ValueError(
            "invalid value train.loss={}".format(config.get("train.loss"))
        )
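
# Usage sketch: the loss is selected purely from config, with train.loss_arg
# carrying the loss-specific argument (offset for bce, margin for
# margin_ranking). Values are illustrative; assumes the enclosing KgeLoss
# class is in scope.
def _example_create_loss(config: Config):
    config.set("train.loss", "margin_ranking")
    config.set("train.loss_arg", 4.0)
    return KgeLoss.create(config)
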
def __init__(
    self,
    config: Config,
    dataset: Dataset,
    parent_job: Job = None,
    model=None,
    optimizer=None,
    forward_only=False,
    parameter_client=None,
) -> None:
    from kge.job import EvaluationJob

    super().__init__(config, dataset, parent_job, parameter_client=parameter_client)
    if model is None:
        self.model: KgeModel = KgeModel.create(config, dataset)
    else:
        self.model: KgeModel = model
    self.loss = KgeLoss.create(config)
    self.abort_on_nan: bool = config.get("train.abort_on_nan")
    self.batch_size: int = config.get("train.batch_size")
    self._subbatch_auto_tune: bool = config.get("train.subbatch_auto_tune")
    self._max_subbatch_size: int = config.get("train.subbatch_size")
    self.device: str = self.config.get("job.device")
    self.train_split = config.get("train.split")
    self.config.check("train.trace_level", ["batch", "epoch"])
    self.trace_batch: bool = self.config.get("train.trace_level") == "batch"
    self.epoch: int = 0
    self.is_forward_only = forward_only

    if not self.is_forward_only:
        self.model.train()
        if optimizer is None:
            self.optimizer = KgeOptimizer.create(config, self.model)
        else:
            self.optimizer = optimizer
        self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
        self._lr_warmup = self.config.get("train.lr_warmup")
        for group in self.optimizer.param_groups:
            group["initial_lr"] = group["lr"]

        self.valid_trace: List[Dict[str, Any]] = []
        valid_conf = config.clone()
        valid_conf.set("job.type", "eval")
        if self.config.get("valid.split") != "":
            valid_conf.set("eval.split", self.config.get("valid.split"))
        valid_conf.set("eval.trace_level", self.config.get("valid.trace_level"))
        self.valid_job = EvaluationJob.create(
            valid_conf, dataset, parent_job=self, model=self.model
        )

    # attributes filled in by implementing classes
    self.loader = None
    self.num_examples = None
    self.type_str: Optional[str] = None

    #: Hooks run after validation. The corresponding valid trace entry can be
    #: found in self.valid_trace[-1]. Signature: job
    self.post_valid_hooks: List[Callable[[Job], Any]] = []

    #: Hooks run on early stopping. Signature: job
    self.early_stop_hooks: List[Callable[[Job], Any]] = []

    #: Hooks that add conditions for stopping early.
    #: The hooked function needs to return a boolean. Signature: job
    self.early_stop_conditions: List[Callable[[Job], Any]] = []

    if self.__class__ == TrainingJob:
        for f in Job.job_created_hooks:
            f(self)
def _dump_trace(args): """Execute the 'dump trace' command.""" if (args.train or args.valid or args.test or args.truncate or args.job_id or args.checkpoint or args.batch or args.example) and args.search: sys.exit( "--search and any of --train, --valid, --test, --truncate, --job_id," " --checkpoint, --batch, --example are mutually exclusive") entry_type_specified = True if not (args.train or args.valid or args.test or args.search): entry_type_specified = False args.train = True args.valid = True args.test = True truncate_flag = False truncate_epoch = None if isinstance(args.truncate, bool) and args.truncate: truncate_flag = True elif not isinstance(args.truncate, bool): if not args.truncate.isdigit(): sys.exit( "Integer argument or no argument for --truncate must be used") truncate_epoch = int(args.truncate) checkpoint_path = None if ".pt" in os.path.split(args.source)[-1]: checkpoint_path = args.source folder_path = os.path.split(args.source)[0] else: # determine job_id and epoch from last/best checkpoint automatically if args.checkpoint: checkpoint_path = Config.best_or_last_checkpoint_file(args.source) folder_path = args.source if not checkpoint_path and truncate_flag: sys.exit( "--truncate can only be used as a flag when a checkpoint is specified." " Consider specifying a checkpoint or use an integer argument for the" " --truncate option") if checkpoint_path and args.job_id: sys.exit( "--job_id cannot be used together with a checkpoint as the checkpoint" " already specifies the job_id") trace = os.path.join(folder_path, "trace.yaml") if not os.path.isfile(trace): sys.exit( f"No file 'trace.yaml' found at {os.path.abspath(folder_path)}") # process additional keys from --keys and --keysfile keymap = OrderedDict() additional_keys = [] if args.keysfile: with open(args.keysfile, "r") as keyfile: additional_keys = keyfile.readlines() if args.keys: additional_keys += args.keys for line in additional_keys: line = line.rstrip("\n").replace(" ", "") name_key = line.split("=") if len(name_key) == 1: name_key += name_key keymap[name_key[0]] = name_key[1] job_id = None # use job_id and truncate_epoch from checkpoint if checkpoint_path and truncate_flag: checkpoint = torch.load(f=checkpoint_path, map_location="cpu") job_id = checkpoint["job_id"] truncate_epoch = checkpoint["epoch"] # only use job_id from checkpoint elif checkpoint_path: checkpoint = torch.load(f=checkpoint_path, map_location="cpu") job_id = checkpoint["job_id"] # no checkpoint specified job_id might have been set manually elif args.job_id: job_id = args.job_id # don't restrict epoch number in case it has not been specified yet if not truncate_epoch: truncate_epoch = float("inf") entries, job_epochs = [], {} if not args.search: entries, job_epochs = Trace.grep_training_trace_entries( tracefile=trace, train=args.train, test=args.test, valid=args.valid, example=args.example, batch=args.batch, job_id=job_id, epoch_of_last=truncate_epoch, ) if not entries and (args.search or not entry_type_specified): entries = Trace.grep_entries(tracefile=trace, conjunctions=[f"scope: train"]) truncate_epoch = None if entries: args.search = True if not entries and entry_type_specified: sys.exit( "No relevant trace entries found. 
If this was a trace from a search" " job, dont use any of --train --valid --test.") elif not entries: sys.exit("No relevant trace entries found.") if args.list_keys: all_trace_keys = set() if not args.yaml: csv_writer = csv.writer(sys.stdout) # dict[new_name] = (lookup_name, where) # if where=="config"/"trace" it will be looked up automatically # if where=="sep" it must be added in in the write loop separately if args.no_default_keys: default_attributes = OrderedDict() else: default_attributes = OrderedDict([ ("job_id", ("job_id", "sep")), ("dataset", ("dataset.name", "config")), ("model", ("model", "sep")), ("reciprocal", ("reciprocal", "sep")), ("job", ("job", "sep")), ("job_type", ("type", "trace")), ("split", ("split", "sep")), ("epoch", ("epoch", "trace")), ("avg_loss", ("avg_loss", "trace")), ("avg_penalty", ("avg_penalty", "trace")), ("avg_cost", ("avg_cost", "trace")), ("metric_name", ("valid.metric", "config")), ("metric", ("metric", "sep")), ]) if args.search: default_attributes["child_folder"] = ("folder", "trace") default_attributes["child_job_id"] = ("child_job_id", "sep") if not (args.no_header or args.list_keys): csv_writer.writerow( list(default_attributes.keys()) + [key for key in keymap.keys()]) # store configs for job_id's s.t. they need to be loaded only once configs = {} warning_shown = False for entry in entries: current_epoch = entry.get("epoch") job_type = entry.get("job") job_id = entry.get("job_id") if truncate_epoch and not current_epoch <= float(truncate_epoch): continue # filter out entries not relevant to the unique training sequence determined # by the options; not relevant for search if job_type == "train": if current_epoch > job_epochs[job_id]: continue elif job_type == "eval": if "resumed_from_job_id" in entry: if current_epoch > job_epochs[entry.get( "resumed_from_job_id")]: continue elif "parent_job_id" in entry: if current_epoch > job_epochs[entry.get("parent_job_id")]: continue # find relevant config file child_job_id = entry.get( "child_job_id") if "child_job_id" in entry else None config_key = (entry.get("folder") + "/" + str(child_job_id) if args.search else job_id) if config_key in configs.keys(): config = configs[config_key] else: if args.search: if not child_job_id and not warning_shown: # This warning is from Dec 19, 2019. TODO remove print( "Warning: You are dumping the trace of an older search job. 
" "This is fine only if " "the config.yaml files in each subfolder have not been modified " "after running the corresponding training job.", file=sys.stderr, ) warning_shown = True config = get_config_for_job_id( child_job_id, os.path.join(folder_path, entry.get("folder"))) entry["type"] = config.get("train.type") else: config = get_config_for_job_id(job_id, folder_path) configs[config_key] = config if args.list_keys: all_trace_keys.update(entry.keys()) continue new_attributes = OrderedDict() # when training was reciprocal, use the base_model as model if config.get_default("model") == "reciprocal_relations_model": model = config.get_default( "reciprocal_relations_model.base_model.type") # the string that substitutes $base_model in keymap if it exists subs_model = "reciprocal_relations_model.base_model" reciprocal = 1 else: model = config.get_default("model") subs_model = model reciprocal = 0 # search for the additional keys from --keys and --keysfile for new_key in keymap.keys(): lookup = keymap[new_key] # search for special keys value = None if lookup == "$folder": value = os.path.abspath(folder_path) elif lookup == "$checkpoint" and checkpoint_path: value = os.path.abspath(checkpoint_path) elif lookup == "$machine": value = socket.gethostname() if "$base_model" in lookup: lookup = lookup.replace("$base_model", subs_model) # search for ordinary keys; start searching in trace entry then config if not value: value = entry.get(lookup) if not value: try: value = config.get_default(lookup) except: pass # value stays None; creates empty field in csv if value and isinstance(value, bool): value = 1 elif not value and isinstance(value, bool): value = 0 new_attributes[new_key] = value if not args.yaml: # find the actual values for the default attributes actual_default = default_attributes.copy() for new_key in default_attributes.keys(): lookup, where = default_attributes[new_key] if where == "config": actual_default[new_key] = config.get(lookup) elif where == "trace": actual_default[new_key] = entry.get(lookup) # keys with separate treatment # "split" in {train,test,valid} for the datatype # "job" in {train,eval,valid,search} if job_type == "train": if "split" in entry: actual_default["split"] = entry.get("split") else: actual_default["split"] = "train" actual_default["job"] = "train" elif job_type == "eval": if "split" in entry: actual_default["split"] = entry.get( "split") # test or valid else: # deprecated actual_default["split"] = entry.get( "data") # test or valid if entry.get("resumed_from_job_id"): actual_default["job"] = "eval" # from "kge eval" else: actual_default["job"] = "valid" # child of training job else: actual_default["job"] = job_type if "split" in entry: actual_default["split"] = entry.get("split") else: # deprecated actual_default["split"] = entry.get( "data") # test or valid actual_default["job_id"] = job_id.split("-")[0] actual_default["model"] = model actual_default["reciprocal"] = reciprocal # lookup name is in config value is in trace actual_default["metric"] = entry.get( config.get_default("valid.metric")) if args.search: actual_default["child_job_id"] = entry.get( "child_job_id").split("-")[0] for key in list(actual_default.keys()): if key not in default_attributes: del actual_default[key] csv_writer.writerow( [actual_default[new_key] for new_key in actual_default.keys()] + [new_attributes[new_key] for new_key in new_attributes.keys()]) else: entry.update({"reciprocal": reciprocal, "model": model}) if keymap: entry.update(new_attributes) print(entry) if args.list_keys: # 
only one config needed config = configs[list(configs.keys())[0]] options = Config.flatten(config.options) options = sorted(filter(lambda opt: "+++" not in opt, options), key=lambda opt: opt.lower()) if isinstance(args.list_keys, bool): sep = ", " else: sep = args.list_keys print("Default keys for CSV: ") print(*default_attributes.keys(), sep=sep) print("") print("Special keys: ") print(*["$folder", "$checkpoint", "$machine", "$base_model"], sep=sep) print("") print("Keys found in trace: ") print(*sorted(all_trace_keys), sep=sep) print("") print("Keys found in config: ") print(*options, sep=sep)
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # package command
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )

        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix=" ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix=" ")

        # set random seeds
        def get_seed(what):
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get(f"random_seed.default") >= 0:
                import hashlib

                # we add an md5 hash to the default seed so that different PRNGs
                # get a different seed
                seed = (
                    config.get(f"random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # stay within 16 bits
            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise
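
# Self-contained illustration of the seed derivation in get_seed above: the md5
# digest of the PRNG name offsets the shared default seed, so "torch", "numpy",
# etc. receive distinct but reproducible seeds. The helper name is hypothetical.
def _example_derive_seeds(default_seed=42):
    import hashlib

    return {
        what: (default_seed + int(hashlib.md5(what.encode()).hexdigest(), 16)) % 0xFFFF
        for what in ("python", "torch", "numpy", "numba")
    }
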
def __init__(
    self,
    config: Config,
    dataset: Dataset,
    configuration_key: str,
    vocab_size: int,
    init_for_load_only=False,
):
    super().__init__(
        config, dataset, configuration_key, init_for_load_only=init_for_load_only
    )

    # read config
    self.config.check("train.trace_level", ["batch", "epoch"])
    self.vocab_size = vocab_size

    if self.get_option("modalities")[0] != "struct":
        raise ValueError("DKRL assumes that struct is the first modality")

    # Set the relation embedder dim. This works around the problem that, for a
    # search, the relation and entity embedder dims have to be set via a single
    # config key.
    # CAREFUL: THIS ASSUMES THAT THE ENTITY EMBEDDER IS CREATED FIRST
    rel_emb_conf_key = configuration_key.replace(
        "entity_embedder", "relation_embedder"
    )
    if configuration_key == rel_emb_conf_key:
        raise ValueError("Cannot set the relation embedding size")
    config.set(f"{rel_emb_conf_key}.dim", self.dim)

    # create embedder for each modality
    self.embedder = torch.nn.ModuleDict()
    for modality in self.get_option("modalities"):
        # if the dim of a modality embedder is < 0, set it to the parent
        # embedder dim: e.g., with dkrl the text embedding dim should equal
        # the embedding dim, whereas with literale it can vary
        if self.get_option(f"{modality}.dim") < 0:
            config.set(f"{self.configuration_key}.{modality}.dim", self.dim)
        embedder = KgeEmbedder.create(
            config,
            dataset,
            f"{self.configuration_key}.{modality}",
            vocab_size=self.vocab_size,
            init_for_load_only=init_for_load_only,
        )
        self.embedder[modality] = embedder

    # HACK
    # kwargs["indexes"] is set to None if the dkrl embedder has
    # regularize_args.weighted set to False. If the child embedder has
    # regularize_args.weighted set to True, it tries to access
    # kwargs["indexes"], which leads to an error. Therefore, set
    # regularize_args.weighted to True if it is set for the struct embedder.
    if self.embedder["struct"].get_option("regularize_args.weighted"):
        config.set(self.configuration_key + ".regularize_args.weighted", True)

    # TODO handling negative dropout because using it with ax searches for now
    dropout = self.get_option("dropout")
    if dropout < 0:
        if config.get("train.auto_correct"):
            config.log(
                "Setting {}.dropout to 0, was set to {}.".format(
                    configuration_key, dropout
                )
            )
            dropout = 0
    self.dropout = torch.nn.Dropout(dropout)
def create_parser(config, additional_args=[]):
    # define short option names
    short_options = {
        "dataset.name": "-d",
        "job.type": "-j",
        "train.max_epochs": "-e",
        "model": "-m",
    }

    # create parser for config
    parser_conf = argparse.ArgumentParser(add_help=False)
    for key, value in Config.flatten(config.options).items():
        short = short_options.get(key)
        argtype = type(value)
        if argtype == bool:
            argtype = argparse_bool_type
        if short:
            parser_conf.add_argument("--" + key, short, type=argtype)
        else:
            parser_conf.add_argument("--" + key, type=argtype)
    # add additional arguments
    for key in additional_args:
        parser_conf.add_argument(key)

    # add argument to abort on outdated data
    parser_conf.add_argument(
        "--abort-when-cache-outdated",
        action="store_const",
        const=True,
        default=False,
        help="Abort processing when an outdated cached dataset file is found "
        "(see description of `dataset.pickle` configuration key). "
        "Default is to recompute such cache files.",
    )

    # create main parsers and subparsers
    parser = argparse.ArgumentParser("kge")
    subparsers = parser.add_subparsers(title="command", dest="command")
    subparsers.required = True

    # start and its meta-commands
    parser_start = subparsers.add_parser(
        "start", help="Start a new job (create and run it)", parents=[parser_conf]
    )
    parser_create = subparsers.add_parser(
        "create", help="Create a new job (but do not run it)", parents=[parser_conf]
    )
    for p in [parser_start, parser_create]:
        p.add_argument("config", type=str, nargs="?")
        p.add_argument("--folder", "-f", type=str, help="Output folder to use")
        p.add_argument(
            "--run",
            default=p is parser_start,
            type=argparse_bool_type,
            help="Whether to immediately run the created job",
        )

    # resume and its meta-commands
    parser_resume = subparsers.add_parser(
        "resume", help="Resume a prior job", parents=[parser_conf]
    )
    parser_eval = subparsers.add_parser(
        "eval", help="Evaluate the result of a prior job", parents=[parser_conf]
    )
    parser_valid = subparsers.add_parser(
        "valid",
        help="Evaluate the result of a prior job using validation data",
        parents=[parser_conf],
    )
    parser_test = subparsers.add_parser(
        "test",
        help="Evaluate the result of a prior job using test data",
        parents=[parser_conf],
    )
    for p in [parser_resume, parser_eval, parser_valid, parser_test]:
        p.add_argument("config", type=str)
        p.add_argument(
            "--checkpoint",
            type=str,
            help=(
                "Which checkpoint to use: 'default', 'last', 'best', a number "
                "or a file name"
            ),
            default="default",
        )

    add_dump_parsers(subparsers)
    add_package_parser(subparsers)
    return parser
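
# Typical invocations of the resulting command line (folder names hypothetical;
# -e is the short option for train.max_epochs defined above):
#
#   kge start examples/toy-complex-train.yaml --job.device cuda:0 -e 20
#   kge resume local/experiments/20240101-120000-toy-complex-train
#   kge test local/experiments/20240101-120000-toy-complex-train --checkpoint best
#   kge dump trace local/experiments/20240101-120000-toy-complex-train
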
def __init__(self, config: Config, dataset: Dataset, parent_job: Job = None) -> None:
    from kge.job import EvaluationJob

    super().__init__(config, dataset, parent_job)
    self.model: KgeModel = KgeModel.create(config, dataset)
    self.optimizer = KgeOptimizer.create(config, self.model)
    self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
    self.loss = KgeLoss.create(config)
    self.abort_on_nan: bool = config.get("train.abort_on_nan")
    self.batch_size: int = config.get("train.batch_size")
    self.device: str = self.config.get("job.device")
    self.train_split = config.get("train.split")
    valid_conf = config.clone()
    valid_conf.set("job.type", "eval")
    if self.config.get("valid.split") != "":
        valid_conf.set("eval.split", self.config.get("valid.split"))
    valid_conf.set("eval.trace_level", self.config.get("valid.trace_level"))
    self.valid_job = EvaluationJob.create(
        valid_conf, dataset, parent_job=self, model=self.model
    )
    self.config.check("train.trace_level", ["batch", "epoch"])
    self.trace_batch: bool = self.config.get("train.trace_level") == "batch"
    self.epoch: int = 0
    self.valid_trace: List[Dict[str, Any]] = []
    self.is_prepared = False
    self.model.train()

    # attributes filled in by implementing classes
    self.loader = None
    self.num_examples = None
    self.type_str: Optional[str] = None

    #: Hooks run after training for an epoch.
    #: Signature: job, trace_entry
    self.post_epoch_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run before starting a batch.
    #: Signature: job
    self.pre_batch_hooks: List[Callable[[Job], Any]] = []

    #: Hooks run before outputting the trace of a batch. Can modify trace entry.
    #: Signature: job, trace_entry
    self.post_batch_trace_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run before outputting the trace of an epoch. Can modify trace entry.
    #: Signature: job, trace_entry
    self.post_epoch_trace_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run after a validation job.
    #: Signature: job, trace_entry
    self.post_valid_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    #: Hooks run after training.
    #: Signature: job, trace_entry
    self.post_train_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

    if self.__class__ == TrainingJob:
        for f in Job.job_created_hooks:
            f(self)
def main():
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later.
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print("WARNING: No configuration specified; using " + args.config)
        print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                value = "".join(value).split(",")
            try:
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            if args.checkpoint == "default":
                if config.get("job.type") in ["eval", "valid"]:
                    checkpoint_file = config.checkpoint_file("best")
                else:
                    checkpoint_file = None  # means last
            elif is_number(args.checkpoint, int) or args.checkpoint == "best":
                checkpoint_file = config.checkpoint_file(args.checkpoint)
            else:
                # otherwise, treat it as a filename
                checkpoint_file = args.checkpoint

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix=" ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix=" ")

        # set random seeds
        if config.get("random_seed.python") > -1:
            import random

            random.seed(config.get("random_seed.python"))
        if config.get("random_seed.torch") > -1:
            import torch

            torch.manual_seed(config.get("random_seed.torch"))
        if config.get("random_seed.numpy") > -1:
            import numpy.random

            numpy.random.seed(config.get("random_seed.numpy"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.load(config)

            # let's go
            job = Job.create(config, dataset)
            if args.command == "resume":
                job.resume(checkpoint_file)
            job.run()
    except BaseException as e:
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise e from None