Exemplo n.º 1
0
    def create(
        config: Config,
        dataset: Dataset,
        configuration_key: Optional[str] = None,
        init_for_load_only=False,
        create_embedders=True,
        parameter_client=None,
        max_partition_entities=0,
    ) -> "KgeModel":
        """Factory method for model creation.

        Resolves the model name (from `<configuration_key>.type` when a key is
        given, else from the top-level `model` option), imports its module,
        instantiates the configured class, and moves it to `job.device`.

        Raises:
            Exception: if the model type cannot be resolved from the config
                (the underlying lookup error is chained as the cause).
        """
        try:
            if configuration_key is not None:
                model_name = config.get(configuration_key + ".type")
            else:
                model_name = config.get("model")
            config._import(model_name)
            class_name = config.get(model_name + ".class_name")
        except Exception as e:
            # was a bare `except:` that discarded the cause; keep the original
            # message but chain the underlying error for debuggability
            raise Exception(
                "Can't find {}.type in config".format(configuration_key)
            ) from e

        try:
            model = init_from(
                class_name,
                config.get("modules"),
                config=config,
                dataset=dataset,
                configuration_key=configuration_key,
                init_for_load_only=init_for_load_only,
                create_embedders=create_embedders,
                parameter_client=parameter_client,
                max_partition_entities=max_partition_entities,
            )
            model.to(config.get("job.device"))
            return model
        except Exception:
            # log which model failed, then re-raise the original error untouched
            config.log(f"Failed to create model {model_name} (class {class_name}).")
            raise
Exemplo n.º 2
0
    def __init__(self,
                 config: Config,
                 dataset: Dataset,
                 parent_job: Job = None,
                 model=None) -> None:
        """Initialize a training job.

        If `model` is None, a fresh model is created from `config`; otherwise
        the given pre-built model is used. Also creates the optimizer, LR
        scheduler, loss, and an embedded evaluation job used for validation.
        """
        from kge.job import EvaluationJob

        super().__init__(config, dataset, parent_job)
        # use the supplied model, or build one from the configuration
        if model is None:
            self.model: KgeModel = KgeModel.create(config, dataset)
        else:
            self.model: KgeModel = model
        self.optimizer = KgeOptimizer.create(config, self.model)
        self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
        self.loss = KgeLoss.create(config)
        self.abort_on_nan: bool = config.get("train.abort_on_nan")
        self.batch_size: int = config.get("train.batch_size")
        self._subbatch_auto_tune: bool = config.get("train.subbatch_auto_tune")
        self._max_subbatch_size: int = config.get("train.subbatch_size")
        self.device: str = self.config.get("job.device")
        self.train_split = config.get("train.split")

        self.config.check("train.trace_level", ["batch", "epoch"])
        self.trace_batch: bool = self.config.get(
            "train.trace_level") == "batch"
        self.epoch: int = 0
        self.valid_trace: List[Dict[str, Any]] = []
        # validation runs on a cloned config switched into eval mode so the
        # training config itself is never mutated
        valid_conf = config.clone()
        valid_conf.set("job.type", "eval")
        if self.config.get("valid.split") != "":
            valid_conf.set("eval.split", self.config.get("valid.split"))
        valid_conf.set("eval.trace_level",
                       self.config.get("valid.trace_level"))
        self.valid_job = EvaluationJob.create(valid_conf,
                                              dataset,
                                              parent_job=self,
                                              model=self.model)

        # attributes filled in by implementing classes
        self.loader = None
        self.num_examples = None
        self.type_str: Optional[str] = None

        # Hooks run after validation. The corresponding valid trace entry can be found
        # in self.valid_trace[-1] Signature: job
        self.post_valid_hooks: List[Callable[[Job], Any]] = []

        # run creation hooks only for the concrete base class; subclasses run
        # their own hooks at the end of their __init__
        if self.__class__ == TrainingJob:
            for f in Job.job_created_hooks:
                f(self)

        # training jobs leave the model in train mode
        self.model.train()
Exemplo n.º 3
0
    def create_from(cls,
                    checkpoint: Dict,
                    new_config: Config = None,
                    dataset: Dataset = None,
                    parent_job=None,
                    parameter_client=None) -> Job:
        """
        Creates a Job based on a checkpoint

        Args:
            checkpoint: loaded checkpoint
            new_config: optional config object - overwrites options of config
                              stored in checkpoint
            dataset: dataset object
            parent_job: parent job (e.g. search job)
            parameter_client: optional client for distributed parameter access,
                forwarded to model and job creation

        Returns: Job based on checkpoint

        """
        from kge.model import KgeModel

        model: KgeModel = None
        # search jobs don't have a model
        if "model" in checkpoint and checkpoint["model"] is not None:
            model = KgeModel.create_from(checkpoint,
                                         new_config=new_config,
                                         dataset=dataset,
                                         parameter_client=parameter_client)
            # reuse the config/dataset resolved during model creation so job
            # and model are guaranteed to agree
            config = model.config
            dataset = model.dataset
        else:
            config = Config.create_from(checkpoint)
            if new_config:
                config.load_config(new_config)
            dataset = Dataset.create_from(checkpoint, config, dataset)
        # init_for_load_only skips fresh parameter initialization; the actual
        # state is restored from the checkpoint by _load below
        job = Job.create(config,
                         dataset,
                         parent_job,
                         model,
                         parameter_client=parameter_client,
                         init_for_load_only=True)
        job._load(checkpoint)
        job.config.log("Loaded checkpoint from {}...".format(
            checkpoint["file"]))
        return job
Exemplo n.º 4
0
Arquivo: eval.py Projeto: uma-pi1/kge
    def __init__(self, config: Config, dataset: Dataset, parent_job, model):
        """Base initialization shared by all evaluation jobs."""
        super().__init__(config, dataset, parent_job)

        self.config = config
        self.dataset = dataset
        self.model = model
        self.batch_size = config.get("eval.batch_size")
        self.device = self.config.get("job.device")
        # NOTE(review): this validates "train.trace_level", but the values read
        # below come from "eval.trace_level" — looks like the check should be
        # on "eval.trace_level"; confirm against upstream before changing.
        self.config.check("train.trace_level", ["example", "batch", "epoch"])
        self.trace_examples = self.config.get("eval.trace_level") == "example"
        self.trace_batch = (self.trace_examples
                            or self.config.get("train.trace_level") == "batch")
        self.eval_split = self.config.get("eval.split")
        # -1 means "no epoch evaluated yet"
        self.epoch = -1

        # all done, run job_created_hooks if necessary
        if self.__class__ == EvaluationJob:
            for f in Job.job_created_hooks:
                f(self)
Exemplo n.º 5
0
    def __init__(self, config: Config, dataset: Dataset, parent_job, model):
        """Evaluation job that reports the loss on the training data."""
        super().__init__(config, dataset, parent_job, model)

        # clone so the embedded training job cannot mutate the caller's config
        eval_train_config = config.clone()
        # TODO set train split to include validation data here
        #   once support is added
        #   Then reflect this change in the trace entries

        # forward-only training job: computes loss without updating parameters
        self._train_job = TrainingJob.create(
            config=eval_train_config,
            dataset=dataset,
            model=model,
            parent_job=self,
            forward_only=True,
        )

        # run creation hooks only for this concrete class
        if self.__class__ == TrainingLossEvaluationJob:
            for hook in Job.job_created_hooks:
                hook(self)
Exemplo n.º 6
0
    def __init__(self, config: Config, dataset: Dataset, parent_job, model):
        """Evaluation job that ranks entities for link prediction."""
        super().__init__(config, dataset, parent_job, model)

        # validate and read the tie-handling policy
        self.config.check(
            "entity_ranking.tie_handling",
            ["rounded_mean_rank", "best_rank", "worst_rank"],
        )
        self.tie_handling = self.config.get("entity_ranking.tie_handling")

        # run creation hooks only for this concrete class
        if self.__class__ == EntityRankingJob:
            for hook in Job.job_created_hooks:
                hook(self)

        # hits@k beyond the number of entities is meaningless, so cap k there
        ks = self.config.get("entity_ranking.hits_at_k_s")
        max_k = min(self.dataset.num_entities(), max(ks))
        self.hits_at_k_s = [k for k in ks if k <= max_k]
        self.filter_with_test = config.get("entity_ranking.filter_with_test")
Exemplo n.º 7
0
    def create_from(
        checkpoint: Dict,
        dataset: Optional[Dataset] = None,
        use_tmp_log_folder=True,
        new_config: Config = None,
    ) -> "KgeModel":
        """Loads a model from a checkpoint file of a training job or a packaged model.

        When `dataset` is given, it is associated with the model; otherwise the
        dataset the model was trained on is restored from the checkpoint.

        With `use_tmp_log_folder`, logs and traces go to a temporary directory;
        otherwise `kge.log` and `trace.yaml` are created (or appended to) in
        the checkpoint's folder.
        """
        config = Config.create_from(checkpoint)
        if new_config:
            config.load_config(new_config)

        if use_tmp_log_folder:
            import tempfile

            config.log_folder = tempfile.mkdtemp(prefix="kge-")
        else:
            folder = checkpoint["folder"]
            # fall back to the working directory when the stored folder is unusable
            config.log_folder = folder if folder and os.path.exists(folder) else "."

        dataset = Dataset.create_from(
            checkpoint, config, dataset, preload_data=False
        )
        model = KgeModel.create(config, dataset, init_for_load_only=True)
        model.load(checkpoint["model"])
        model.eval()
        return model
Exemplo n.º 8
0
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        configuration_key: str,
        vocab_size: int,
        init_for_load_only=False,
    ):
        """Embedder backed by numeric literals attached to entities.

        Reads tab-separated (entity, relation, value) triples from `filename`,
        normalizes the values per relation ("min-max" or "z-score"), and stores
        them in a dense [num_entities x num_literal_relations] tensor
        (`self.num_lit`). When `num_layers > 0`, a small MLP projects the
        literal vector to the embedding dimension; otherwise the literal
        vector itself is used and `self.dim` is set accordingly.

        Raises:
            ValueError: if the configured normalization option is unknown.
        """
        super().__init__(config,
                         dataset,
                         configuration_key,
                         init_for_load_only=init_for_load_only)

        # read config
        self.regularize = self.check_option("regularize", ["", "lp"])
        self.config.check("train.trace_level", ["batch", "epoch"])
        self.vocab_size = vocab_size
        self.filename = self.get_option("filename")
        self.num_layers = self.get_option("num_layers")

        # load numeric data: one "entity\trelation\tvalue" triple per line
        # (iterate the file directly instead of readlines+map)
        with open(self.filename, "r") as f:
            data = [line.strip().split("\t") for line in f]

        # returns entities in index order
        entities = self.dataset.entity_ids()

        ent_to_idx = {ent: idx for idx, ent in enumerate(entities)}
        numeric_data_ent_idx = []

        # literal relations are indexed in order of first appearance
        rel_to_idx = {}
        numeric_data_rel_idx = []
        numeric_data = []
        for t in data:
            ent = t[0]
            rel = t[1]
            value = float(t[2])

            if rel not in rel_to_idx:
                rel_to_idx[rel] = len(rel_to_idx)

            numeric_data_ent_idx.append(ent_to_idx[ent])
            numeric_data_rel_idx.append(rel_to_idx[rel])
            numeric_data.append(value)

        numeric_data_ent_idx = torch.tensor(numeric_data_ent_idx,
                                            dtype=torch.long)
        numeric_data_rel_idx = torch.tensor(numeric_data_rel_idx,
                                            dtype=torch.long)
        numeric_data = torch.tensor(numeric_data, dtype=torch.float32)

        # normalize numeric literals per relation
        normalization = self.get_option("normalization")
        if normalization == "min-max":
            for rel_idx in rel_to_idx.values():
                sel = (rel_idx == numeric_data_rel_idx)
                max_num = torch.max(numeric_data[sel])
                min_num = torch.min(numeric_data[sel])
                # epsilon guards against division by zero for constant columns
                numeric_data[sel] = ((numeric_data[sel] - min_num) /
                                     (max_num - min_num + 1e-8))
        elif normalization == "z-score":
            for rel_idx in rel_to_idx.values():
                sel = (rel_idx == numeric_data_rel_idx)
                mean = torch.mean(numeric_data[sel])

                # account for the fact that there might only be a single value
                # in that case torch.std would result in nan
                if torch.sum(sel) > 1:
                    std = torch.std(numeric_data[sel])
                else:
                    std = 0

                numeric_data[sel] = ((numeric_data[sel] - mean) / (std + 1e-8))
        else:
            # fixed typo ("Unkown") and report the offending value
            raise ValueError(
                f"Unknown normalization option: {normalization}")

        num_lit = torch.zeros(
            [len(ent_to_idx), len(rel_to_idx)], dtype=torch.float32)

        num_lit[numeric_data_ent_idx, numeric_data_rel_idx] = numeric_data
        # includes all numeric literals for all entities, with the entities
        # being ordered by their index
        self.num_lit = num_lit.to(self.config.get("job.device"))

        if self.num_layers > 0:
            # initialize numeric MLP projecting literal vectors to self.dim
            self.numeric_mlp = NumericMLP(
                input_dim=num_lit.shape[1],
                output_dim=self.dim,
                num_layers=self.num_layers,
                activation=self.get_option("activation"))

            if not init_for_load_only:
                # initialize weights
                for name, weights in self.numeric_mlp.named_parameters():
                    # set bias to zero
                    # https://cs231n.github.io/neural-networks-2/#init
                    if "bias" in name:
                        torch.nn.init.zeros_(weights)
                    else:
                        self.initialize(weights)
        else:
            # no MLP: the raw literal vector is the embedding
            self.dim = num_lit.shape[1]

        # TODO handling negative dropout because using it with ax searches for now
        dropout = self.get_option("dropout")
        if dropout < 0:
            if config.get("train.auto_correct"):
                config.log("Setting {}.dropout to 0, "
                           "was set to {}.".format(configuration_key, dropout))
                dropout = 0
        self.dropout = torch.nn.Dropout(dropout)
Exemplo n.º 9
0
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        configuration_key: str,
        vocab_size: int,
        parameter_client: "KgeParameterClient",
        complete_vocab_size,
        lapse_offset=0,
        init_for_load_only=False,
    ):
        """Embedder whose parameters are pulled from a parameter server."""
        super().__init__(
            config,
            dataset,
            configuration_key,
            vocab_size,
            init_for_load_only=init_for_load_only,
        )
        # optimizer state is kept alongside the embeddings so both can be
        # transferred together
        self.optimizer_dim = get_optimizer_dim(config, self.dim)
        self.optimizer_values = torch.zeros(
            (self.vocab_size, self.optimizer_dim),
            dtype=torch.float32,
            requires_grad=False,
        )

        self.complete_vocab_size = complete_vocab_size
        self.parameter_client = parameter_client
        self.lapse_offset = lapse_offset
        self.pulled_ids = None
        self.load_batch = self.config.get("job.distributed.load_batch")

        # global-to-local id mapper; only used at "partition" entity sync level
        self.global_to_local_mapper = torch.full(
            (self.dataset.num_entities(),), -1, dtype=torch.long, device="cpu"
        )

        # maps the local embeddings to the embeddings in lapse (used in optimizer)
        self.local_to_lapse_mapper = torch.full(
            (vocab_size,), -1, dtype=torch.long, requires_grad=False
        )
        self.pull_dim = self.dim + self.optimizer_dim
        self.unnecessary_dim = self.parameter_client.dim - self.pull_dim

        # three host-side staging buffers so up to three batches can be
        # pre-pulled; the leading boolean marks whether the buffer is free
        def _make_buffer():
            return torch.empty(
                (self.vocab_size, self.parameter_client.dim),
                dtype=torch.float32,
                device="cpu",
                requires_grad=False,
            )

        self.pull_tensors = [[True, _make_buffer()] for _ in range(3)]

        device = config.get("job.device")
        if "cuda" in device:
            # only pin tensors when running on GPU; otherwise pinning would
            # allocate GPU memory for no reason
            with torch.cuda.device(device):
                for slot in self.pull_tensors:
                    slot[1] = slot[1].pin_memory()

        self.num_pulled = 0
        self.mapping_time = 0.0
        self.pre_pulled = deque()
Exemplo n.º 10
0
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        scorer: Union[RelationalScorer, type],
        create_embedders=True,
        configuration_key=None,
        init_for_load_only=False,
        parameter_client=None,
        max_partition_entities=0,
    ):
        """Initialize a KGE model: entity/relation embedders and a scorer.

        When `create_embedders` is set, embedders are built via
        `_create_embedders`; the alternative branch below is currently
        disabled (`elif False:`).
        """
        super().__init__(config, dataset, configuration_key)

        # TODO support different embedders for subjects and objects

        #: Embedder used for entities (both subject and objects)
        self._entity_embedder: KgeEmbedder

        #: Embedder used for relations
        self._relation_embedder: KgeEmbedder

        if create_embedders:
            self._create_embedders(init_for_load_only)
        # NOTE(review): dead branch — appears to be a deliberately disabled
        # distributed-embedder code path; confirm whether it should be removed
        # or gated behind a config option.
        elif False:
            #if self.get_option("create_complete"):
            #    embedding_layer_size = dataset.num_entities()
            if config.get("job.distributed.entity_sync_level") == "partition" and max_partition_entities != 0:
                embedding_layer_size =max_partition_entities
            else:
                embedding_layer_size = self._calc_embedding_layer_size(config, dataset)
            config.log(f"creating entity_embedder with {embedding_layer_size} keys")
            self._entity_embedder = KgeEmbedder.create(
                config=config,
                dataset=dataset,
                configuration_key=self.configuration_key + ".entity_embedder",
                #dataset.num_entities(),
                vocab_size=embedding_layer_size,
                init_for_load_only=init_for_load_only,
                parameter_client=parameter_client,
                lapse_offset=0,
                complete_vocab_size=dataset.num_entities()
            )

            #: Embedder used for relations
            num_relations = dataset.num_relations()
            self._relation_embedder = KgeEmbedder.create(
                config,
                dataset,
                self.configuration_key + ".relation_embedder",
                num_relations,
                init_for_load_only=init_for_load_only,
                parameter_client=parameter_client,
                lapse_offset=dataset.num_entities(),
                complete_vocab_size=dataset.num_relations(),
            )

            # pretrained initialization only runs on the lowest-ranked worker
            if not init_for_load_only and parameter_client.rank == get_min_rank(config):
                # load pretrained embeddings
                pretrained_entities_filename = ""
                pretrained_relations_filename = ""
                if self.has_option("entity_embedder.pretrain.model_filename"):
                    pretrained_entities_filename = self.get_option(
                        "entity_embedder.pretrain.model_filename"
                    )
                if self.has_option("relation_embedder.pretrain.model_filename"):
                    pretrained_relations_filename = self.get_option(
                        "relation_embedder.pretrain.model_filename"
                    )

                # loads a checkpoint to initialize from; returns None when no
                # filename is configured
                def load_pretrained_model(
                    pretrained_filename: str,
                ) -> Optional[KgeModel]:
                    if pretrained_filename != "":
                        self.config.log(
                            f"Initializing with embeddings stored in "
                            f"{pretrained_filename}"
                        )
                        checkpoint = load_checkpoint(pretrained_filename)
                        return KgeModel.create_from(checkpoint, parameter_client=parameter_client)
                    return None

                pretrained_entities_model = load_pretrained_model(
                    pretrained_entities_filename
                )
                # reuse the same model object when both point at one file
                if pretrained_entities_filename == pretrained_relations_filename:
                    pretrained_relations_model = pretrained_entities_model
                else:
                    pretrained_relations_model = load_pretrained_model(
                        pretrained_relations_filename
                    )
                if pretrained_entities_model is not None:
                    if (
                        pretrained_entities_model.get_s_embedder()
                        != pretrained_entities_model.get_o_embedder()
                    ):
                        raise ValueError(
                            "Can only initialize with pre-trained models having "
                            "identical subject and object embeddings."
                        )
                    self._entity_embedder.init_pretrained(
                        pretrained_entities_model.get_s_embedder()
                    )
                if pretrained_relations_model is not None:
                    self._relation_embedder.init_pretrained(
                        pretrained_relations_model.get_p_embedder()
                    )

        #: Scorer
        self._scorer: RelationalScorer
        if type(scorer) == type:
            # scorer is type of the scorer to use; call its constructor
            self._scorer = scorer(
                config=config, dataset=dataset, configuration_key=self.configuration_key
            )
        else:
            self._scorer = scorer
Exemplo n.º 11
0
    def create_default(
        model: Optional[str] = None,
        dataset: Optional[Union[Dataset, str]] = None,
        options: Optional[Dict[str, Any]] = None,
        folder: Optional[str] = None,
    ) -> "KgeModel":
        """Utility method to create a model, including configuration and dataset.

        `model` is the name of the model (takes precedence over
        ``options["model"]``), `dataset` a dataset name or `Dataset` instance (takes
        precedence over ``options["dataset.name"]``), and options arbitrary other
        configuration options.

        If `folder` is ``None``, creates a temporary folder. Otherwise uses the
        specified folder.

        """
        # fixed mutable-default pitfall: `options={}` would be a single dict
        # shared across all calls; use None as sentinel instead
        if options is None:
            options = {}

        # load default model config
        if model is None:
            model = options["model"]
        default_config_file = filename_in_module(kge.model, "{}.yaml".format(model))
        config = Config()
        config.load(default_config_file, create=True)

        # apply specified options
        config.set("model", model)
        if isinstance(dataset, Dataset):
            config.set("dataset.name", dataset.config.get("dataset.name"))
        elif isinstance(dataset, str):
            config.set("dataset.name", dataset)
        config.set_all(new_options=options)

        # create output folder
        if folder is None:
            config.folder = tempfile.mkdtemp(
                "{}-{}-".format(config.get("dataset.name"), config.get("model"))
            )
        else:
            config.folder = folder

        # create dataset and model
        if not isinstance(dataset, Dataset):
            dataset = Dataset.create(config)
        model = KgeModel.create(config, dataset)
        return model
Exemplo n.º 12
0
def _dump_config(args):
    """Execute the 'dump config' command.

    Prints the configuration of a job folder, a config file, or a checkpoint
    in one of three modes: --raw (verbatim file/checkpoint config), --full
    (all options), or --minimal (only options differing from the defaults;
    this is the default mode). --include/--exclude filter by option key or
    any dotted prefix thereof.
    """
    if not (args.raw or args.full or args.minimal):
        args.minimal = True

    if args.raw + args.full + args.minimal != 1:
        raise ValueError(
            "Exactly one of --raw, --full, or --minimal must be set")

    if args.raw and (args.include or args.exclude):
        raise ValueError("--include and --exclude cannot be used with --raw "
                         "(use --full or --minimal instead).")

    # load the config from a folder, a yaml file, or a checkpoint
    config = Config()
    config_file = None
    if os.path.isdir(args.source):
        config_file = os.path.join(args.source, "config.yaml")
        config.load(config_file)
    elif ".yaml" in os.path.split(args.source)[-1]:
        config_file = args.source
        config.load(config_file)
    else:  # a checkpoint
        checkpoint = torch.load(args.source, map_location="cpu")
        if args.raw:
            config = checkpoint["config"]
        else:
            config.load_config(checkpoint["config"])

    def key_matches(key, selected):
        # True iff `key` or any dotted prefix of it is contained in `selected`
        # (deduplicates the prefix-walk previously repeated for include/exclude)
        prefix = key
        while True:
            if prefix in selected:
                return True
            last_dot_index = prefix.rfind(".")
            if last_dot_index < 0:
                return False
            prefix = prefix[:last_dot_index]

    def print_options(options):
        # drop all arguments that are not included
        # (snapshot the keys with list() instead of deep-copying the dict)
        if args.include:
            args.include = set(args.include)
            for key in list(options.keys()):
                if not key_matches(key, args.include):
                    del options[key]

        # remove all arguments that are excluded
        if args.exclude:
            args.exclude = set(args.exclude)
            for key in list(options.keys()):
                if key_matches(key, args.exclude):
                    del options[key]

        # convert the remaining options to a Config and print it
        config = Config(load_default=False)
        config.set_all(options, create=True)
        print(yaml.dump(config.options))

    if args.raw:
        if config_file:
            with open(config_file, "r") as f:
                print(f.read())
        else:
            print_options(config.options)
    elif args.full:
        print_options(config.options)
    else:  # minimal
        default_config = Config()
        imports = config.get("import")
        if imports is not None:
            if not isinstance(imports, list):
                imports = [imports]
            for module_name in imports:
                default_config._import(module_name)
        default_options = Config.flatten(default_config.options)
        new_options = Config.flatten(config.options)

        # keep only options whose value differs from the (imported) defaults
        minimal_options = {
            option: value
            for option, value in new_options.items()
            if option not in default_options or default_options[option] != value
        }

        # always retain all imports
        if imports is not None:
            minimal_options["import"] = list(set(imports))

        print_options(minimal_options)
Exemplo n.º 13
0
def create_config(test_dataset_name: str, model: str = "complex") -> Config:
    """Build a quiet, CPU-only test Config for the given dataset and model."""
    cfg = Config()
    cfg.folder = None
    cfg.set("console.quiet", True)
    cfg.set("model", model)
    cfg._import(model)
    cfg.set("dataset.name", test_dataset_name)
    cfg.set("job.device", "cpu")
    return cfg
Exemplo n.º 14
0
def _dump_trace(args):
    """Execute the 'dump trace' command.

    Greps the ``trace.yaml`` of a training or search job folder (or of the
    folder containing a given checkpoint) and writes the selected entries to
    stdout, either as CSV (default) or as stripped YAML-ish lines (--yaml).
    """
    start = time.time()
    # --search is incompatible with explicitly selecting entry types
    if (args.train or args.valid or args.test) and args.search:
        print(
            "--search and --train, --valid, --test are mutually exclusive",
            file=sys.stderr,
        )
        exit(1)
    entry_type_specified = True
    if not (args.train or args.valid or args.test or args.search):
        # no explicit selection: default to dumping all training entry types
        entry_type_specified = False
        args.train = True
        args.valid = True
        args.test = True

    # resolve the job folder and (optionally) a checkpoint file from args.source
    checkpoint_path = None
    if ".pt" in os.path.split(args.source)[-1]:
        # source is a checkpoint file itself
        checkpoint_path = args.source
        folder_path = os.path.split(args.source)[0]
    else:
        # determine job_id and epoch from last/best checkpoint automatically
        if args.checkpoint:
            checkpoint_path = Config.get_best_or_last_checkpoint(args.source)
        folder_path = args.source
        if not args.checkpoint and args.truncate:
            raise ValueError(
                "You can only use --truncate when a checkpoint is specified."
                "Consider using --checkpoint or provide a checkpoint file as source"
            )
    trace = os.path.join(folder_path, "trace.yaml")
    if not os.path.isfile(trace):
        sys.stderr.write("No trace found at {}\n".format(trace))
        exit(1)

    # parse additional output columns from --keysfile and --keys;
    # each line/argument has the form "name=lookup" (or just "lookup")
    keymap = OrderedDict()
    additional_keys = []
    if args.keysfile:
        with open(args.keysfile, "r") as keyfile:
            additional_keys = keyfile.readlines()
    if args.keys:
        additional_keys += args.keys
    for line in additional_keys:
        line = line.rstrip("\n").replace(" ", "")
        name_key = line.split("=")
        if len(name_key) == 1:
            # no explicit column name given: reuse the lookup key as the name
            name_key += name_key
        keymap[name_key[0]] = name_key[1]

    # determine which job_id/epoch to restrict the dump to
    job_id = None
    epoch = int(args.max_epoch)
    # use job_id and epoch from checkpoint
    if checkpoint_path and args.truncate:
        checkpoint = torch.load(f=checkpoint_path, map_location="cpu")
        job_id = checkpoint["job_id"]
        epoch = checkpoint["epoch"]
    # only use job_id from checkpoint
    elif checkpoint_path:
        checkpoint = torch.load(f=checkpoint_path, map_location="cpu")
        job_id = checkpoint["job_id"]
    # override job_id and epoch with user arguments
    if args.job_id:
        job_id = args.job_id
    if not epoch:
        epoch = float("inf")

    # grep the relevant trace entries
    entries, job_epochs = [], {}
    if not args.search:
        entries, job_epochs = Trace.grep_training_trace_entries(
            tracefile=trace,
            train=args.train,
            test=args.test,
            valid=args.valid,
            example=args.example,
            batch=args.batch,
            job_id=job_id,
            epoch_of_last=epoch,
        )
    if not entries and (args.search or not entry_type_specified):
        # fall back to search-job entries (scope "train" summarizes one child)
        entries = Trace.grep_entries(tracefile=trace,
                                     conjunctions=[f"scope: train"])
        epoch = None
        if entries:
            args.search = True
    if not entries:
        print("No relevant trace entries found.", file=sys.stderr)
        exit(1)

    middle = time.time()
    if not args.yaml:
        csv_writer = csv.writer(sys.stdout)
        # dict[new_name] = (lookup_name, where)
        # if where=="config"/"trace" it will be looked up automatically
        # if where=="sep" it must be added in in the write loop separately
        if args.no_default_keys:
            default_attributes = OrderedDict()
        else:
            default_attributes = OrderedDict([
                ("job_id", ("job_id", "sep")),
                ("dataset", ("dataset.name", "config")),
                ("model", ("model", "sep")),
                ("reciprocal", ("reciprocal", "sep")),
                ("job", ("job", "sep")),
                ("job_type", ("type", "trace")),
                ("split", ("split", "sep")),
                ("epoch", ("epoch", "trace")),
                ("avg_loss", ("avg_loss", "trace")),
                ("avg_penalty", ("avg_penalty", "trace")),
                ("avg_cost", ("avg_cost", "trace")),
                ("metric_name", ("valid.metric", "config")),
                ("metric", ("metric", "sep")),
            ])
            if args.search:
                default_attributes["child_folder"] = ("folder", "trace")
                default_attributes["child_job_id"] = ("child_job_id", "sep")

        if not args.no_header:
            csv_writer.writerow(
                list(default_attributes.keys()) +
                [key for key in keymap.keys()])
    # store configs for job_id's s.t. they need to be loaded only once
    configs = {}
    warning_shown = False
    for entry in entries:
        # NOTE(review): if an entry has no "epoch" key this comparison raises
        # TypeError (None <= float) — presumably all grepped entries carry an
        # epoch; confirm against Trace.grep_training_trace_entries
        if epoch and not entry.get("epoch") <= float(epoch):
            continue
        # filter out not needed entries from a previous job when
        # a job was resumed from the middle
        if entry.get("job") == "train":
            job_id = entry.get("job_id")
            if entry.get("epoch") > job_epochs[job_id]:
                continue

        # find relevant config file
        child_job_id = entry.get(
            "child_job_id") if "child_job_id" in entry else None
        config_key = (entry.get("folder") + "/" + str(child_job_id)
                      if args.search else entry.get("job_id"))
        if config_key in configs.keys():
            config = configs[config_key]
        else:
            if args.search:
                if not child_job_id and not warning_shown:
                    # This warning is from Dec 19, 2019. TODO remove
                    print(
                        "Warning: You are dumping the trace of an older search job. "
                        "This is fine only if "
                        "the config.yaml files in each subfolder have not been modified "
                        "after running the corresponding training job.",
                        file=sys.stderr,
                    )
                    warning_shown = True
                config = get_config_for_job_id(
                    child_job_id, os.path.join(folder_path,
                                               entry.get("folder")))
                entry["type"] = config.get("train.type")
            else:
                config = get_config_for_job_id(entry.get("job_id"),
                                               folder_path)
            configs[config_key] = config

        new_attributes = OrderedDict()
        # when training was reciprocal, report the wrapped base model instead
        if config.get_default("model") == "reciprocal_relations_model":
            model = config.get_default(
                "reciprocal_relations_model.base_model.type")
            # the string that substitutes $base_model in keymap if it exists
            subs_model = "reciprocal_relations_model.base_model"
            reciprocal = 1
        else:
            model = config.get_default("model")
            subs_model = model
            reciprocal = 0
        # resolve the user-requested extra columns
        for new_key in keymap.keys():
            lookup = keymap[new_key]
            if "$base_model" in lookup:
                lookup = lookup.replace("$base_model", subs_model)
            try:
                if lookup == "$folder":
                    val = os.path.abspath(folder_path)
                elif lookup == "$checkpoint":
                    val = os.path.abspath(checkpoint_path)
                elif lookup == "$machine":
                    val = socket.gethostname()
                else:
                    val = config.get_default(lookup)
            except:
                # creates empty field if key is not existing
                # NOTE(review): bare except also swallows KeyboardInterrupt;
                # intent here is "missing config key -> fall back to trace"
                val = entry.get(lookup)
            # normalize booleans to 0/1 for CSV output
            if type(val) == bool and val:
                val = 1
            elif type(val) == bool and not val:
                val = 0
            new_attributes[new_key] = val
        if not args.yaml:
            # find the actual values for the default attributes
            actual_default = default_attributes.copy()
            for new_key in default_attributes.keys():
                lookup, where = default_attributes[new_key]
                if where == "config":
                    actual_default[new_key] = config.get(lookup)
                elif where == "trace":
                    actual_default[new_key] = entry.get(lookup)
            # keys with separate treatment
            # "split" in {train,test,valid} for the datatype
            # "job" in {train,eval,valid,search}
            if entry.get("job") == "train":
                actual_default["split"] = "train"
                actual_default["job"] = "train"
            elif entry.get("job") == "eval":
                actual_default["split"] = entry.get("data")  # test or valid
                if entry.get("resumed_from_job_id"):
                    actual_default["job"] = "eval"  # from "kge eval"
                else:
                    actual_default["job"] = "valid"  # child of training job
            else:
                actual_default["job"] = entry.get("job")
                actual_default["split"] = entry.get("data")
            actual_default["job_id"] = entry.get("job_id").split("-")[0]
            actual_default["model"] = model
            actual_default["reciprocal"] = reciprocal
            # lookup name is in config value is in trace
            actual_default["metric"] = entry.get(
                config.get_default("valid.metric"))
            if args.search:
                actual_default["child_job_id"] = entry.get(
                    "child_job_id").split("-")[0]
            # drop helper keys that are not requested columns
            for key in list(actual_default.keys()):
                if key not in default_attributes:
                    del actual_default[key]
            csv_writer.writerow(
                [actual_default[new_key]
                 for new_key in actual_default.keys()] +
                [new_attributes[new_key] for new_key in new_attributes.keys()])
        else:
            # YAML-ish output: dump the raw entry with braces/quotes stripped
            entry.update({"reciprocal": reciprocal, "model": model})
            if keymap:
                entry.update(new_attributes)
            sys.stdout.write(re.sub("[{}']", "", str(entry)))
            sys.stdout.write("\n")
    end = time.time()
    if args.timeit:
        sys.stdout.write("Grep + processing took {} \n".format(middle - start))
        sys.stdout.write("Writing took {}".format(end - middle))
Exemplo n.º 15
0
    def run(self):
        """Expand the manual search configurations, run them, and summarize."""
        # expand each raw search configuration into a full training config
        search_configs = copy.deepcopy(
            self.config.get("manual_search.configurations"))
        all_keys = set()
        for idx, raw in enumerate(search_configs):
            folder = raw.pop("folder")
            train_config = self.config.clone(folder)
            train_config.set("job.type", "train")
            # could be large, don't copy
            train_config.options.pop("manual_search", None)
            flat = Config.flatten(raw)
            train_config.set_all(flat)
            all_keys.update(flat.keys())
            search_configs[idx] = train_config

        # create folders for search configs (existing folders remain
        # unmodified)
        for train_config in search_configs:
            train_config.init_folder()

        # TODO find a way to create all indexes (train/valid/test sp_to_o and
        # po_to_s) before running the jobs; doing it here via
        # self.dataset.index(...) makes pytorch throw a "too many open files"
        # error.

        # start running/resuming all training jobs and wait for completion
        num_configs = len(search_configs)
        for idx, train_config in enumerate(search_configs):
            self.submit_task(
                kge.job.search._run_train_job,
                (self, idx, train_config, num_configs, all_keys),
            )
        self.wait_task(concurrent.futures.ALL_COMPLETED)

        # if not running the jobs, stop here
        if not self.config.get("manual_search.run"):
            self.config.log(
                "Skipping evaluation of results as requested by user.")
            return

        # collect the best result (and its metric value) of each job
        best_per_job = [None] * num_configs
        best_metric_per_job = [None] * num_configs
        for idx, best, best_metric in self.ready_task_results:
            best_per_job[idx] = best
            best_metric_per_job[idx] = best_metric

        # produce an overall summary and track the overall winner
        self.config.log("Result summary:")
        metric_name = self.config.get("valid.metric")
        overall_best = None
        overall_best_metric = None
        for best, best_metric in zip(best_per_job, best_metric_per_job):
            if not overall_best or overall_best_metric < best_metric:
                overall_best = best
                overall_best_metric = best_metric
            self.config.log(
                "{}={:.3f} after {} epochs in folder {}".format(
                    metric_name, best_metric, best["epoch"], best["folder"]),
                prefix="  ",
            )
        self.config.log("And the winner is:")
        self.config.log(
            "{}={:.3f} after {} epochs in folder {}".format(
                metric_name,
                overall_best_metric,
                overall_best["epoch"],
                overall_best["folder"],
            ),
            prefix="  ",
        )
        self.config.log("Best overall result:")
        self.trace(event="search_completed",
                   echo=True,
                   echo_prefix="  ",
                   log=True,
                   scope="search",
                   **overall_best)
Exemplo n.º 16
0
    def create(config: Config):
        """Factory method for loss function instantiation.

        Reads ``train.loss`` (and, where applicable, ``train.loss_arg``) from
        *config* and returns the corresponding loss object.

        Raises:
            ValueError: if ``train.loss`` holds an unsupported value.
        """
        # perhaps TODO: try class with specified name -> extensibility
        config.check(
            "train.loss",
            [
                "bce",
                "bce_mean",
                "bce_self_adversarial",
                "margin_ranking",
                "ce",
                "kl",
                "soft_margin",
            ],
        )
        # read the loss name once instead of re-querying the config per branch
        loss = config.get("train.loss")

        def _loss_arg(default: float) -> float:
            # read train.loss_arg; if unset (NaN), fall back to *default* and
            # persist the effective value back into the config (logged)
            value = config.get("train.loss_arg")
            if math.isnan(value):
                value = default
                config.set("train.loss_arg", value, log=True)
            return value

        if loss == "bce":
            return BCEWithLogitsKgeLoss(config, offset=_loss_arg(0.0), bce_type=None)
        elif loss == "bce_mean":
            return BCEWithLogitsKgeLoss(config, offset=_loss_arg(0.0), bce_type="mean")
        elif loss == "bce_self_adversarial":
            offset = _loss_arg(0.0)
            try:
                temperature = float(
                    config.get("user.bce_self_adversarial_temperature"))
            except KeyError:
                # no user-specified temperature configured
                temperature = 1.0
            config.log(f"Using adversarial temperature {temperature}")
            return BCEWithLogitsKgeLoss(
                config,
                offset=offset,
                bce_type="self_adversarial",
                temperature=temperature,
            )
        elif loss == "kl":
            return KLDivWithSoftmaxKgeLoss(config)
        elif loss == "margin_ranking":
            return MarginRankingKgeLoss(config, margin=_loss_arg(1.0))
        elif loss == "soft_margin":
            return SoftMarginKgeLoss(config)
        else:
            # NOTE(review): "ce" passes the check above but has no branch here
            # and thus raises — confirm whether a CE loss branch is missing
            raise ValueError("invalid value train.loss={}".format(loss))
Exemplo n.º 17
0
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        parent_job: Job = None,
        model=None,
        optimizer=None,
        forward_only=False,
        parameter_client=None,
    ) -> None:
        """Initialize a training job.

        Args:
            config: job configuration (keys under ``train.*`` drive setup).
            dataset: dataset to train on.
            parent_job: optional job that spawned this one.
            model: optional pre-built model; created from *config* if None.
            optimizer: optional pre-built optimizer; created from *config*
                if None (only when not in forward-only mode).
            forward_only: if True, skip optimizer, LR scheduler, and
                validation-job setup entirely.
            parameter_client: passed through to the parent Job constructor
                (presumably for distributed parameter access — confirm).
        """
        from kge.job import EvaluationJob

        super().__init__(config,
                         dataset,
                         parent_job,
                         parameter_client=parameter_client)

        # model: use the given one or build a fresh one from the config
        if model is None:
            self.model: KgeModel = KgeModel.create(
                config,
                dataset,
            )
        else:
            self.model: KgeModel = model
        self.loss = KgeLoss.create(config)
        self.abort_on_nan: bool = config.get("train.abort_on_nan")
        self.batch_size: int = config.get("train.batch_size")
        # subbatching: process each batch in smaller chunks to bound memory
        self._subbatch_auto_tune: bool = config.get("train.subbatch_auto_tune")
        self._max_subbatch_size: int = config.get("train.subbatch_size")
        self.device: str = self.config.get("job.device")
        self.train_split = config.get("train.split")

        self.config.check("train.trace_level", ["batch", "epoch"])
        self.trace_batch: bool = self.config.get(
            "train.trace_level") == "batch"
        self.epoch: int = 0
        self.is_forward_only = forward_only

        if not self.is_forward_only:
            # put model into training mode (affects dropout/batchnorm etc.)
            self.model.train()

            if optimizer is None:
                self.optimizer = KgeOptimizer.create(
                    config,
                    self.model,
                )
            else:
                self.optimizer = optimizer
            self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
            self._lr_warmup = self.config.get("train.lr_warmup")
            # remember initial learning rates (used e.g. for LR warmup)
            for group in self.optimizer.param_groups:
                group["initial_lr"] = group["lr"]

            # set up the evaluation job used for validation during training;
            # it shares the model and runs on a cloned "eval"-type config
            self.valid_trace: List[Dict[str, Any]] = []
            valid_conf = config.clone()
            valid_conf.set("job.type", "eval")
            if self.config.get("valid.split") != "":
                valid_conf.set("eval.split", self.config.get("valid.split"))
            valid_conf.set("eval.trace_level",
                           self.config.get("valid.trace_level"))
            self.valid_job = EvaluationJob.create(valid_conf,
                                                  dataset,
                                                  parent_job=self,
                                                  model=self.model)

        # attributes filled in by implementing classes
        self.loader = None
        self.num_examples = None
        self.type_str: Optional[str] = None

        # Hooks run after validation. The corresponding valid trace entry can be found
        # in self.valid_trace[-1] Signature: job
        self.post_valid_hooks: List[Callable[[Job], Any]] = []

        # Hooks run on early stopping
        self.early_stop_hooks: List[Callable[[Job], Any]] = []

        # Hooks to add conditions to stop early
        # The hooked function needs to return a boolean
        self.early_stop_conditions: List[Callable[[Job], Any]] = []

        if self.__class__ == TrainingJob:
            for f in Job.job_created_hooks:
                f(self)
Exemplo n.º 18
0
def _dump_trace(args):
    """Execute the 'dump trace' command.

    Greps the ``trace.yaml`` of a training or search job folder (or of the
    folder containing a given checkpoint) and writes the selected entries to
    stdout, either as CSV (default) or as raw dict lines (--yaml). With
    --list_keys, lists the available keys instead of dumping entries.
    """
    # --search is incompatible with all training-trace specific options
    if (args.train or args.valid or args.test or args.truncate or args.job_id
            or args.checkpoint or args.batch or args.example) and args.search:
        sys.exit(
            "--search and any of --train, --valid, --test, --truncate, --job_id,"
            " --checkpoint, --batch, --example are mutually exclusive")

    entry_type_specified = True
    if not (args.train or args.valid or args.test or args.search):
        # no explicit selection: default to dumping all training entry types
        entry_type_specified = False
        args.train = True
        args.valid = True
        args.test = True

    # --truncate may be a flag (take epoch from checkpoint) or an int argument
    truncate_flag = False
    truncate_epoch = None
    if isinstance(args.truncate, bool) and args.truncate:
        truncate_flag = True
    elif not isinstance(args.truncate, bool):
        if not args.truncate.isdigit():
            sys.exit(
                "Integer argument or no argument for --truncate must be used")
        truncate_epoch = int(args.truncate)

    # resolve the job folder and (optionally) a checkpoint file from args.source
    checkpoint_path = None
    if ".pt" in os.path.split(args.source)[-1]:
        # source is a checkpoint file itself
        checkpoint_path = args.source
        folder_path = os.path.split(args.source)[0]
    else:
        # determine job_id and epoch from last/best checkpoint automatically
        if args.checkpoint:
            checkpoint_path = Config.best_or_last_checkpoint_file(args.source)
        folder_path = args.source
    if not checkpoint_path and truncate_flag:
        sys.exit(
            "--truncate can only be used as a flag when a checkpoint is specified."
            " Consider specifying a checkpoint or use an integer argument for the"
            " --truncate option")
    if checkpoint_path and args.job_id:
        sys.exit(
            "--job_id cannot be used together with a checkpoint as the checkpoint"
            " already specifies the job_id")
    trace = os.path.join(folder_path, "trace.yaml")
    if not os.path.isfile(trace):
        sys.exit(
            f"No file 'trace.yaml' found at {os.path.abspath(folder_path)}")

    # process additional keys from --keys and --keysfile
    # each line/argument has the form "name=lookup" (or just "lookup")
    keymap = OrderedDict()
    additional_keys = []
    if args.keysfile:
        with open(args.keysfile, "r") as keyfile:
            additional_keys = keyfile.readlines()
    if args.keys:
        additional_keys += args.keys
    for line in additional_keys:
        line = line.rstrip("\n").replace(" ", "")
        name_key = line.split("=")
        if len(name_key) == 1:
            # no explicit column name given: reuse the lookup key as the name
            name_key += name_key
        keymap[name_key[0]] = name_key[1]

    job_id = None
    # use job_id and truncate_epoch from checkpoint
    if checkpoint_path and truncate_flag:
        checkpoint = torch.load(f=checkpoint_path, map_location="cpu")
        job_id = checkpoint["job_id"]
        truncate_epoch = checkpoint["epoch"]
    # only use job_id from checkpoint
    elif checkpoint_path:
        checkpoint = torch.load(f=checkpoint_path, map_location="cpu")
        job_id = checkpoint["job_id"]
    # no checkpoint specified job_id might have been set manually
    elif args.job_id:
        job_id = args.job_id
    # don't restrict epoch number in case it has not been specified yet
    if not truncate_epoch:
        truncate_epoch = float("inf")

    # grep the relevant trace entries
    entries, job_epochs = [], {}
    if not args.search:
        entries, job_epochs = Trace.grep_training_trace_entries(
            tracefile=trace,
            train=args.train,
            test=args.test,
            valid=args.valid,
            example=args.example,
            batch=args.batch,
            job_id=job_id,
            epoch_of_last=truncate_epoch,
        )
    if not entries and (args.search or not entry_type_specified):
        # fall back to search-job entries (scope "train" summarizes one child)
        entries = Trace.grep_entries(tracefile=trace,
                                     conjunctions=[f"scope: train"])
        truncate_epoch = None
        if entries:
            args.search = True
    if not entries and entry_type_specified:
        sys.exit(
            "No relevant trace entries found. If this was a trace from a search"
            " job, dont use any of --train --valid --test.")
    elif not entries:
        sys.exit("No relevant trace entries found.")

    if args.list_keys:
        all_trace_keys = set()

    if not args.yaml:
        csv_writer = csv.writer(sys.stdout)
        # dict[new_name] = (lookup_name, where)
        # if where=="config"/"trace" it will be looked up automatically
        # if where=="sep" it must be added in in the write loop separately
        if args.no_default_keys:
            default_attributes = OrderedDict()
        else:
            default_attributes = OrderedDict([
                ("job_id", ("job_id", "sep")),
                ("dataset", ("dataset.name", "config")),
                ("model", ("model", "sep")),
                ("reciprocal", ("reciprocal", "sep")),
                ("job", ("job", "sep")),
                ("job_type", ("type", "trace")),
                ("split", ("split", "sep")),
                ("epoch", ("epoch", "trace")),
                ("avg_loss", ("avg_loss", "trace")),
                ("avg_penalty", ("avg_penalty", "trace")),
                ("avg_cost", ("avg_cost", "trace")),
                ("metric_name", ("valid.metric", "config")),
                ("metric", ("metric", "sep")),
            ])
            if args.search:
                default_attributes["child_folder"] = ("folder", "trace")
                default_attributes["child_job_id"] = ("child_job_id", "sep")

        if not (args.no_header or args.list_keys):
            csv_writer.writerow(
                list(default_attributes.keys()) +
                [key for key in keymap.keys()])
    # store configs for job_id's s.t. they need to be loaded only once
    configs = {}
    warning_shown = False
    for entry in entries:
        current_epoch = entry.get("epoch")
        job_type = entry.get("job")
        job_id = entry.get("job_id")
        if truncate_epoch and not current_epoch <= float(truncate_epoch):
            continue
        # filter out entries not relevant to the unique training sequence determined
        # by the options; not relevant for search
        if job_type == "train":
            if current_epoch > job_epochs[job_id]:
                continue
        elif job_type == "eval":
            if "resumed_from_job_id" in entry:
                if current_epoch > job_epochs[entry.get(
                        "resumed_from_job_id")]:
                    continue
            elif "parent_job_id" in entry:
                if current_epoch > job_epochs[entry.get("parent_job_id")]:
                    continue
        # find relevant config file
        child_job_id = entry.get(
            "child_job_id") if "child_job_id" in entry else None
        config_key = (entry.get("folder") + "/" +
                      str(child_job_id) if args.search else job_id)
        if config_key in configs.keys():
            config = configs[config_key]
        else:
            if args.search:
                if not child_job_id and not warning_shown:
                    # This warning is from Dec 19, 2019. TODO remove
                    print(
                        "Warning: You are dumping the trace of an older search job. "
                        "This is fine only if "
                        "the config.yaml files in each subfolder have not been modified "
                        "after running the corresponding training job.",
                        file=sys.stderr,
                    )
                    warning_shown = True
                config = get_config_for_job_id(
                    child_job_id, os.path.join(folder_path,
                                               entry.get("folder")))
                entry["type"] = config.get("train.type")
            else:
                config = get_config_for_job_id(job_id, folder_path)
            configs[config_key] = config
        if args.list_keys:
            # only collecting available keys; nothing is dumped in this mode
            all_trace_keys.update(entry.keys())
            continue
        new_attributes = OrderedDict()
        # when training was reciprocal, use the base_model as model
        if config.get_default("model") == "reciprocal_relations_model":
            model = config.get_default(
                "reciprocal_relations_model.base_model.type")
            # the string that substitutes $base_model in keymap if it exists
            subs_model = "reciprocal_relations_model.base_model"
            reciprocal = 1
        else:
            model = config.get_default("model")
            subs_model = model
            reciprocal = 0
        # search for the additional keys from --keys and --keysfile
        for new_key in keymap.keys():
            lookup = keymap[new_key]
            # search for special keys
            value = None
            if lookup == "$folder":
                value = os.path.abspath(folder_path)
            elif lookup == "$checkpoint" and checkpoint_path:
                value = os.path.abspath(checkpoint_path)
            elif lookup == "$machine":
                value = socket.gethostname()
            if "$base_model" in lookup:
                lookup = lookup.replace("$base_model", subs_model)
            # search for ordinary keys; start searching in trace entry then config
            if not value:
                value = entry.get(lookup)
            if not value:
                try:
                    value = config.get_default(lookup)
                except:
                    pass  # value stays None; creates empty field in csv
            # normalize booleans to 0/1 for CSV output
            if value and isinstance(value, bool):
                value = 1
            elif not value and isinstance(value, bool):
                value = 0
            new_attributes[new_key] = value
        if not args.yaml:
            # find the actual values for the default attributes
            actual_default = default_attributes.copy()
            for new_key in default_attributes.keys():
                lookup, where = default_attributes[new_key]
                if where == "config":
                    actual_default[new_key] = config.get(lookup)
                elif where == "trace":
                    actual_default[new_key] = entry.get(lookup)
            # keys with separate treatment
            # "split" in {train,test,valid} for the datatype
            # "job" in {train,eval,valid,search}
            if job_type == "train":
                if "split" in entry:
                    actual_default["split"] = entry.get("split")
                else:
                    actual_default["split"] = "train"
                actual_default["job"] = "train"
            elif job_type == "eval":
                if "split" in entry:
                    actual_default["split"] = entry.get(
                        "split")  # test or valid
                else:
                    # deprecated
                    actual_default["split"] = entry.get(
                        "data")  # test or valid
                if entry.get("resumed_from_job_id"):
                    actual_default["job"] = "eval"  # from "kge eval"
                else:
                    actual_default["job"] = "valid"  # child of training job
            else:
                actual_default["job"] = job_type
                if "split" in entry:
                    actual_default["split"] = entry.get("split")
                else:
                    # deprecated
                    actual_default["split"] = entry.get(
                        "data")  # test or valid
            actual_default["job_id"] = job_id.split("-")[0]
            actual_default["model"] = model
            actual_default["reciprocal"] = reciprocal
            # lookup name is in config value is in trace
            actual_default["metric"] = entry.get(
                config.get_default("valid.metric"))
            if args.search:
                actual_default["child_job_id"] = entry.get(
                    "child_job_id").split("-")[0]
            # drop helper keys that are not requested columns
            for key in list(actual_default.keys()):
                if key not in default_attributes:
                    del actual_default[key]
            csv_writer.writerow(
                [actual_default[new_key]
                 for new_key in actual_default.keys()] +
                [new_attributes[new_key] for new_key in new_attributes.keys()])
        else:
            entry.update({"reciprocal": reciprocal, "model": model})
            if keymap:
                entry.update(new_attributes)
            print(entry)
    if args.list_keys:
        # only one config needed
        config = configs[list(configs.keys())[0]]
        options = Config.flatten(config.options)
        # hide internal "+++" placeholder keys; sort case-insensitively
        options = sorted(filter(lambda opt: "+++" not in opt, options),
                         key=lambda opt: opt.lower())
        if isinstance(args.list_keys, bool):
            sep = ", "
        else:
            sep = args.list_keys
        print("Default keys for CSV: ")
        print(*default_attributes.keys(), sep=sep)
        print("")
        print("Special keys: ")
        print(*["$folder", "$checkpoint", "$machine", "$base_model"], sep=sep)
        print("")
        print("Keys found in trace: ")
        print(*sorted(all_trace_keys), sep=sep)
        print("")
        print("Keys found in config: ")
        print(*options, sep=sep)
Exemplo n.º 19
0
def main():
    """Entry point of the kge command-line tool.

    Parses arguments, expands meta-commands (create/eval/test/valid),
    dispatches the dump and package commands, loads or resumes a
    configuration, applies command-line overrides, seeds all PRNGs, and
    finally creates and runs the requested job.
    """
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later (when they are set on the config).
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands: rewrite them into the base commands start/resume
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # package command
    if args.command == "package":
        package_model(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print(
                "WARNING: No configuration specified; using " + args.config,
                file=sys.stderr,
            )

        if not vars(args)["console.quiet"]:
            print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        # a folder may be given instead of a config file; resolve it
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        if not vars(args)["console.quiet"]:
            print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        # skip arguments that are not configuration keys
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                # device pool is given as a comma-separated list of devices
                value = "".join(value).split(",")
            try:
                # coerce string values for boolean config keys
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                # pull in the model's own configuration defaults
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            checkpoint_file = get_checkpoint_file(config, args.checkpoint)

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds
        def get_seed(what):
            # Return the configured seed for PRNG `what`; if unset (<0),
            # derive one from random_seed.default so each PRNG differs.
            seed = config.get(f"random_seed.{what}")
            if seed < 0 and config.get(f"random_seed.default") >= 0:
                import hashlib

                # we add an md5 hash to the default seed so that different PRNGs get a
                # different seed
                seed = (
                    config.get(f"random_seed.default")
                    + int(hashlib.md5(what.encode()).hexdigest(), 16)
                ) % 0xFFFF  # stay 32-bit

            return seed

        if get_seed("python") > -1:
            import random

            random.seed(get_seed("python"))
        if get_seed("torch") > -1:
            import torch

            torch.manual_seed(get_seed("torch"))
        if get_seed("numpy") > -1:
            import numpy.random

            numpy.random.seed(get_seed("numpy"))
        if get_seed("numba") > -1:
            import numpy as np, numba

            # numba has its own PRNG state; seed it inside a jitted function
            @numba.njit
            def seed_numba(seed):
                np.random.seed(seed)

            seed_numba(get_seed("numba"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.create(config)

            # let's go
            if args.command == "resume":
                if checkpoint_file is not None:
                    checkpoint = load_checkpoint(
                        checkpoint_file, config.get("job.device")
                    )
                    job = Job.create_from(
                        checkpoint, new_config=config, dataset=dataset
                    )
                else:
                    job = Job.create(config, dataset)
                    job.config.log(
                        "No checkpoint found or specified, starting from scratch..."
                    )
            else:
                job = Job.create(config, dataset)
            job.run()
    except BaseException:
        # write the traceback to the log file only, then re-raise unchanged
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        raise
Exemplo n.º 20
0
    def __init__(
        self,
        config: Config,
        dataset: Dataset,
        configuration_key: str,
        vocab_size: int,
        init_for_load_only=False,
    ):
        """Create one child embedder per configured modality.

        The first modality must be "struct"; its dimension (``self.dim``)
        is propagated to the relation embedder and to any modality whose
        configured dim is negative.
        """
        super().__init__(
            config,
            dataset,
            configuration_key,
            init_for_load_only=init_for_load_only,
        )

        # read config
        self.config.check("train.trace_level", ["batch", "epoch"])
        self.vocab_size = vocab_size

        modalities = self.get_option("modalities")
        if modalities[0] != "struct":
            raise ValueError("DKRL assumes that struct is the first modality")

        # Propagate the entity embedder's dim to the relation embedder: a
        # search may only set a single dim for both. NOTE(review): this
        # assumes the entity embedder is created first.
        relation_key = configuration_key.replace(
            "entity_embedder", "relation_embedder"
        )
        if relation_key == configuration_key:
            raise ValueError("Cannot set the relation embedding size")
        config.set(f"{relation_key}.dim", self.dim)

        # Build one embedder per modality. A negative modality dim inherits
        # this embedder's dim (e.g. DKRL ties text dim to struct dim, while
        # LiteralE may configure it independently).
        self.embedder = torch.nn.ModuleDict()
        for modality in modalities:
            if self.get_option(f"{modality}.dim") < 0:
                config.set(
                    f"{self.configuration_key}.{modality}.dim", self.dim
                )
            self.embedder[modality] = KgeEmbedder.create(
                config,
                dataset,
                f"{self.configuration_key}.{modality}",
                vocab_size=self.vocab_size,
                init_for_load_only=init_for_load_only,
            )

        # HACK: a child embedder with regularize_args.weighted=True indexes
        # kwargs["indexes"], which stays None unless this parent embedder is
        # weighted too — so mirror the struct embedder's setting here.
        if self.embedder["struct"].get_option("regularize_args.weighted"):
            config.set(
                self.configuration_key + ".regularize_args.weighted", True
            )

        # TODO handling negative dropout because using it with ax searches for now
        dropout = self.get_option("dropout")
        if dropout < 0 and config.get("train.auto_correct"):
            config.log("Setting {}.dropout to 0, "
                       "was set to {}.".format(configuration_key, dropout))
            dropout = 0
        self.dropout = torch.nn.Dropout(dropout)
Exemplo n.º 21
0
def create_parser(config, additional_args=None):
    """Build the ``kge`` command-line parser.

    Every flattened configuration key of ``config`` becomes a ``--key``
    option (typed after the key's current value) shared by all subcommands.

    Args:
        config: Config whose flattened options define the ``--key`` arguments.
        additional_args: optional iterable of extra argument strings (e.g.
            ``--some.key``) to register verbatim; used on the second parsing
            pass to accept keys unknown on the first pass.

    Returns:
        An ``argparse.ArgumentParser`` with subcommands ``start``, ``create``,
        ``resume``, ``eval``, ``valid``, ``test``, plus dump/package parsers.
    """
    # fix: the default used to be a mutable `[]`; use a None sentinel instead
    if additional_args is None:
        additional_args = []

    # define short option names
    short_options = {
        "dataset.name": "-d",
        "job.type": "-j",
        "train.max_epochs": "-e",
        "model": "-m",
    }

    # create parser for config
    parser_conf = argparse.ArgumentParser(add_help=False)
    for key, value in Config.flatten(config.options).items():
        short = short_options.get(key)
        argtype = type(value)
        if argtype == bool:
            # accept true/false-like strings instead of Python's truthy cast
            argtype = argparse_bool_type
        if short:
            parser_conf.add_argument("--" + key, short, type=argtype)
        else:
            parser_conf.add_argument("--" + key, type=argtype)

    # add additional arguments
    for key in additional_args:
        parser_conf.add_argument(key)

    # add argument to abort on outdated data
    parser_conf.add_argument(
        "--abort-when-cache-outdated",
        action="store_const",
        const=True,
        default=False,
        help="Abort processing when an outdated cached dataset file is found "
        "(see description of `dataset.pickle` configuration key). "
        "Default is to recompute such cache files.",
    )

    # create main parsers and subparsers
    parser = argparse.ArgumentParser("kge")
    subparsers = parser.add_subparsers(title="command", dest="command")
    subparsers.required = True

    # start and its meta-commands
    parser_start = subparsers.add_parser(
        "start", help="Start a new job (create and run it)", parents=[parser_conf]
    )
    parser_create = subparsers.add_parser(
        "create", help="Create a new job (but do not run it)", parents=[parser_conf]
    )
    for p in [parser_start, parser_create]:
        p.add_argument("config", type=str, nargs="?")
        p.add_argument("--folder", "-f", type=str, help="Output folder to use")
        p.add_argument(
            "--run",
            default=p is parser_start,  # "start" runs immediately, "create" does not
            type=argparse_bool_type,
            help="Whether to immediately run the created job",
        )

    # resume and its meta-commands
    parser_resume = subparsers.add_parser(
        "resume", help="Resume a prior job", parents=[parser_conf]
    )
    parser_eval = subparsers.add_parser(
        "eval", help="Evaluate the result of a prior job", parents=[parser_conf]
    )
    parser_valid = subparsers.add_parser(
        "valid",
        help="Evaluate the result of a prior job using validation data",
        parents=[parser_conf],
    )
    parser_test = subparsers.add_parser(
        "test",
        help="Evaluate the result of a prior job using test data",
        parents=[parser_conf],
    )
    for p in [parser_resume, parser_eval, parser_valid, parser_test]:
        p.add_argument("config", type=str)
        p.add_argument(
            "--checkpoint",
            type=str,
            help=(
                "Which checkpoint to use: 'default', 'last', 'best', a number "
                "or a file name"
            ),
            default="default",
        )
    add_dump_parsers(subparsers)
    add_package_parser(subparsers)
    return parser
Exemplo n.º 22
0
    def __init__(self,
                 config: Config,
                 dataset: Dataset,
                 parent_job: Job = None,
                 model=None) -> None:
        """Initialize a training job.

        Args:
            config: job configuration; read for optimizer, loss, batch size,
                device, split and trace settings.
            dataset: dataset to train on.
            parent_job: optional job that spawned this one (e.g. a search job).
            model: optional pre-built model to train. If ``None``, a model is
                created from ``config`` and ``dataset``. Accepting an existing
                model mirrors the ``model=`` argument this job itself passes
                to ``EvaluationJob.create`` below, allowing model sharing.
        """
        from kge.job import EvaluationJob

        super().__init__(config, dataset, parent_job)
        # adopt the given model or create a fresh one
        if model is None:
            self.model: KgeModel = KgeModel.create(config, dataset)
        else:
            self.model: KgeModel = model
        self.optimizer = KgeOptimizer.create(config, self.model)
        self.kge_lr_scheduler = KgeLRScheduler(config, self.optimizer)
        self.loss = KgeLoss.create(config)
        self.abort_on_nan: bool = config.get("train.abort_on_nan")
        self.batch_size: int = config.get("train.batch_size")
        self.device: str = self.config.get("job.device")
        self.train_split = config.get("train.split")
        # build the validation job on a cloned config so its overrides
        # (job.type, eval.split, trace level) do not leak into training
        valid_conf = config.clone()
        valid_conf.set("job.type", "eval")
        if self.config.get("valid.split") != "":
            valid_conf.set("eval.split", self.config.get("valid.split"))
        valid_conf.set("eval.trace_level",
                       self.config.get("valid.trace_level"))
        self.valid_job = EvaluationJob.create(valid_conf,
                                              dataset,
                                              parent_job=self,
                                              model=self.model)
        self.config.check("train.trace_level", ["batch", "epoch"])
        self.trace_batch: bool = self.config.get(
            "train.trace_level") == "batch"
        self.epoch: int = 0
        self.valid_trace: List[Dict[str, Any]] = []
        self.is_prepared = False
        self.model.train()

        # attributes filled in by implementing classes
        self.loader = None
        self.num_examples = None
        self.type_str: Optional[str] = None

        #: Hooks run after training for an epoch.
        #: Signature: job, trace_entry
        self.post_epoch_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

        #: Hooks run before starting a batch.
        #: Signature: job
        self.pre_batch_hooks: List[Callable[[Job], Any]] = []

        #: Hooks run before outputting the trace of a batch. Can modify trace entry.
        #: Signature: job, trace_entry
        self.post_batch_trace_hooks: List[Callable[[Job, Dict[str, Any]],
                                                   Any]] = []

        #: Hooks run before outputting the trace of an epoch. Can modify trace entry.
        #: Signature: job, trace_entry
        self.post_epoch_trace_hooks: List[Callable[[Job, Dict[str, Any]],
                                                   Any]] = []

        #: Hooks run after a validation job.
        #: Signature: job, trace_entry
        self.post_valid_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

        #: Hooks run after training
        #: Signature: job, trace_entry
        self.post_train_hooks: List[Callable[[Job, Dict[str, Any]], Any]] = []

        # run creation hooks only for the base class, not for subclasses
        if self.__class__ == TrainingJob:
            for f in Job.job_created_hooks:
                f(self)
Exemplo n.º 23
0
def main():
    """Entry point of the kge command-line tool (older variant).

    Parses arguments, expands meta-commands (create/eval/test/valid),
    dispatches the dump command, loads or resumes a configuration, applies
    command-line overrides, seeds the PRNGs, and creates and runs the job.
    """
    # default config
    config = Config()

    # now parse the arguments
    parser = create_parser(config)
    args, unknown_args = parser.parse_known_args()

    # If there were unknown args, add them to the parser and reparse. The correctness
    # of these arguments will be checked later (when they are set on the config).
    if len(unknown_args) > 0:
        parser = create_parser(
            config, filter(lambda a: a.startswith("--"), unknown_args)
        )
        args = parser.parse_args()

    # process meta-commands: rewrite them into the base commands start/resume
    process_meta_command(args, "create", {"command": "start", "run": False})
    process_meta_command(args, "eval", {"command": "resume", "job.type": "eval"})
    process_meta_command(
        args, "test", {"command": "resume", "job.type": "eval", "eval.split": "test"}
    )
    process_meta_command(
        args, "valid", {"command": "resume", "job.type": "eval", "eval.split": "valid"}
    )
    # dump command
    if args.command == "dump":
        dump(args)
        exit()

    # start command
    if args.command == "start":
        # use toy config file if no config given
        if args.config is None:
            args.config = kge_base_dir() + "/" + "examples/toy-complex-train.yaml"
            print("WARNING: No configuration specified; using " + args.config)

        print("Loading configuration {}...".format(args.config))
        config.load(args.config)

    # resume command
    if args.command == "resume":
        # a folder may be given instead of a config file; resolve it
        if os.path.isdir(args.config) and os.path.isfile(args.config + "/config.yaml"):
            args.config += "/config.yaml"
        print("Resuming from configuration {}...".format(args.config))
        config.load(args.config)
        config.folder = os.path.dirname(args.config)
        if not config.folder:
            config.folder = "."
        if not os.path.exists(config.folder):
            raise ValueError(
                "{} is not a valid config file for resuming".format(args.config)
            )

    # overwrite configuration with command line arguments
    for key, value in vars(args).items():
        # skip arguments that are not configuration keys
        if key in [
            "command",
            "config",
            "run",
            "folder",
            "checkpoint",
            "abort_when_cache_outdated",
        ]:
            continue
        if value is not None:
            if key == "search.device_pool":
                # device pool is given as a comma-separated list of devices
                value = "".join(value).split(",")
            try:
                # coerce string values for boolean config keys
                if isinstance(config.get(key), bool):
                    value = argparse_bool_type(value)
            except KeyError:
                pass
            config.set(key, value)
            if key == "model":
                # pull in the model's own configuration defaults
                config._import(value)

    # initialize output folder
    if args.command == "start":
        if args.folder is None:  # means: set default
            config_name = os.path.splitext(os.path.basename(args.config))[0]
            config.folder = os.path.join(
                kge_base_dir(),
                "local",
                "experiments",
                datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + "-" + config_name,
            )
        else:
            config.folder = args.folder

    # catch errors to log them
    try:
        if args.command == "start" and not config.init_folder():
            raise ValueError("output folder {} exists already".format(config.folder))
        config.log("Using folder: {}".format(config.folder))

        # determine checkpoint to resume (if any)
        if hasattr(args, "checkpoint"):
            if args.checkpoint == "default":
                # default resolution: "best" for evaluation, last otherwise
                if config.get("job.type") in ["eval", "valid"]:
                    checkpoint_file = config.checkpoint_file("best")
                else:
                    checkpoint_file = None  # means last
            elif is_number(args.checkpoint, int) or args.checkpoint == "best":
                checkpoint_file = config.checkpoint_file(args.checkpoint)
            else:
                # otherwise, treat it as a filename
                checkpoint_file = args.checkpoint

        # disable processing of outdated cached dataset files globally
        Dataset._abort_when_cache_outdated = args.abort_when_cache_outdated

        # log configuration
        config.log("Configuration:")
        config.log(yaml.dump(config.options), prefix="  ")
        config.log("git commit: {}".format(get_git_revision_short_hash()), prefix="  ")

        # set random seeds (only for PRNGs with an explicitly configured seed)
        if config.get("random_seed.python") > -1:
            import random

            random.seed(config.get("random_seed.python"))
        if config.get("random_seed.torch") > -1:
            import torch

            torch.manual_seed(config.get("random_seed.torch"))
        if config.get("random_seed.numpy") > -1:
            import numpy.random

            numpy.random.seed(config.get("random_seed.numpy"))

        # let's go
        if args.command == "start" and not args.run:
            config.log("Job created successfully.")
        else:
            # load data
            dataset = Dataset.load(config)

            # let's go
            job = Job.create(config, dataset)
            if args.command == "resume":
                job.resume(checkpoint_file)
            job.run()
    except BaseException as e:
        # write the traceback to the log file only, then re-raise
        tb = traceback.format_exc()
        config.log(tb, echo=False)
        # NOTE(review): `from None` suppresses exception-context chaining;
        # a bare `raise` would preserve it — confirm this is intentional
        raise e from None