예제 #1
0
파일: search.py 프로젝트: Allensmile/kge-1
        def copy_to_search_trace(job, trace_entry):
            """Copy a validation trace entry of the training job to the search trace.

            Works on a deep copy of ``trace_entry``. The copy is extended with the
            value of every key in ``trace_keys`` (read from the training job's
            config), the training job's folder name, the metric name and value,
            and the search job's id as ``parent_job_id``. It is then written to
            the search job's trace and appended to ``valid_trace``.

            ``job`` is not used by this function.

            Raises:
                KeyError: if processing a deprecated key yields more than one key.
            """
            entry = copy.deepcopy(trace_entry)
            for key in trace_keys:
                # Deprecated options are handled to some extent: key renames are
                # supported, value renames are not.
                renamed = {key: None}
                _process_deprecated_options(renamed)
                if len(renamed) > 1:
                    raise KeyError(
                        f"{key} is deprecated but cannot be handled automatically"
                    )
                entry[key] = train_job_config.get(next(iter(renamed)))

            entry["folder"] = os.path.split(train_job_config.folder)[1]
            # The arguments are evaluated before the update is applied, so the
            # metric is computed on the entry as it is at this point.
            entry.update(
                metric_name=metric_name,
                metric_value=Trace.get_metric(entry, metric_name),
                parent_job_id=search_job.job_id,
            )
            search_job.config.trace(**entry)
            valid_trace.append(entry)
예제 #2
0
def _dump_trace(args):
    """Execute the 'dump trace' command.

    Reads ``trace.yaml`` from the folder given by ``args.source`` (or, when
    ``args.source`` names a ``.pt`` file, from the folder containing it),
    selects the relevant trace entries and writes them to stdout: as CSV rows
    by default, or as one printed entry dict per line when ``args.yaml`` is
    set. When ``args.list_keys`` is set, no entries are written; the available
    keys are listed instead.

    Args:
        args: Parsed command-line arguments. Attributes read here: ``train``,
            ``valid``, ``test``, ``search``, ``truncate``, ``job_id``,
            ``checkpoint``, ``batch``, ``example``, ``source``, ``keysfile``,
            ``keys``, ``list_keys``, ``yaml``, ``no_default_keys`` and
            ``no_header``. ``train``/``valid``/``test`` are set to True and
            ``search`` may be set to True on this object as a side effect.

    Exits via ``sys.exit`` with a message when the options conflict, when no
    ``trace.yaml`` exists, or when no relevant trace entries are found.
    """
    if (args.train or args.valid or args.test or args.truncate or args.job_id
            or args.checkpoint or args.batch or args.example) and args.search:
        sys.exit(
            "--search and any of --train, --valid, --test, --truncate, --job_id,"
            " --checkpoint, --batch, --example are mutually exclusive")

    # without an explicit entry type, all training-related types are dumped
    entry_type_specified = True
    if not (args.train or args.valid or args.test or args.search):
        entry_type_specified = False
        args.train = True
        args.valid = True
        args.test = True

    # --truncate is either a plain flag (bool) or carries an epoch number
    truncate_flag = False
    truncate_epoch = None
    if isinstance(args.truncate, bool) and args.truncate:
        truncate_flag = True
    elif not isinstance(args.truncate, bool):
        if not args.truncate.isdigit():
            sys.exit(
                "Integer argument or no argument for --truncate must be used")
        truncate_epoch = int(args.truncate)

    checkpoint_path = None
    if ".pt" in os.path.split(args.source)[-1]:
        checkpoint_path = args.source
        folder_path = os.path.split(args.source)[0]
    else:
        # determine job_id and epoch from last/best checkpoint automatically
        if args.checkpoint:
            checkpoint_path = Config.best_or_last_checkpoint_file(args.source)
        folder_path = args.source
    if not checkpoint_path and truncate_flag:
        sys.exit(
            "--truncate can only be used as a flag when a checkpoint is specified."
            " Consider specifying a checkpoint or use an integer argument for the"
            " --truncate option")
    if checkpoint_path and args.job_id:
        sys.exit(
            "--job_id cannot be used together with a checkpoint as the checkpoint"
            " already specifies the job_id")
    trace = os.path.join(folder_path, "trace.yaml")
    if not os.path.isfile(trace):
        sys.exit(
            f"No file 'trace.yaml' found at {os.path.abspath(folder_path)}")

    # process additional keys from --keys and --keysfile; each line is either
    # "name=key" or just "key" (in which case the key is also the column name)
    keymap = OrderedDict()
    additional_keys = []
    if args.keysfile:
        with open(args.keysfile, "r") as keyfile:
            additional_keys = keyfile.readlines()
    if args.keys:
        additional_keys += args.keys
    for line in additional_keys:
        line = line.rstrip("\n").replace(" ", "")
        name_key = line.split("=")
        if len(name_key) == 1:
            name_key += name_key
        keymap[name_key[0]] = name_key[1]

    job_id = None
    if checkpoint_path:
        # NOTE(review): torch.load unpickles the file; only use this on
        # checkpoints from a trusted source.
        checkpoint = torch.load(f=checkpoint_path, map_location="cpu")
        job_id = checkpoint["job_id"]
        if truncate_flag:
            # --truncate given as a flag: take the epoch from the checkpoint
            truncate_epoch = checkpoint["epoch"]
    elif args.job_id:
        # no checkpoint specified; job_id might have been set manually
        job_id = args.job_id
    # don't restrict epoch number in case it has not been specified yet
    if not truncate_epoch:
        truncate_epoch = float("inf")

    entries, job_epochs = [], {}
    if not args.search:
        entries, job_epochs = Trace.grep_training_trace_entries(
            tracefile=trace,
            train=args.train,
            test=args.test,
            valid=args.valid,
            example=args.example,
            batch=args.batch,
            job_id=job_id,
            epoch_of_last=truncate_epoch,
        )
    if not entries and (args.search or not entry_type_specified):
        # fall back to treating the trace as the trace of a search job
        entries = Trace.grep_entries(tracefile=trace,
                                     conjunctions=["scope: train"])
        truncate_epoch = None
        if entries:
            args.search = True
    if not entries and entry_type_specified:
        sys.exit(
            "No relevant trace entries found. If this was a trace from a search"
            " job, dont use any of --train --valid --test.")
    elif not entries:
        sys.exit("No relevant trace entries found.")

    if args.list_keys:
        all_trace_keys = set()

    # dict[new_name] = (lookup_name, where)
    # if where=="config"/"trace" it will be looked up automatically
    # if where=="sep" it must be added in in the write loop separately
    #
    # Built independently of the output format: --list_keys prints these names
    # even when --yaml is given.
    if args.no_default_keys:
        default_attributes = OrderedDict()
    else:
        default_attributes = OrderedDict([
            ("job_id", ("job_id", "sep")),
            ("dataset", ("dataset.name", "config")),
            ("model", ("model", "sep")),
            ("reciprocal", ("reciprocal", "sep")),
            ("job", ("job", "sep")),
            ("job_type", ("type", "trace")),
            ("split", ("split", "sep")),
            ("epoch", ("epoch", "trace")),
            ("avg_loss", ("avg_loss", "trace")),
            ("avg_penalty", ("avg_penalty", "trace")),
            ("avg_cost", ("avg_cost", "trace")),
            ("metric_name", ("valid.metric", "config")),
            ("metric", ("metric", "sep")),
        ])
        if args.search:
            default_attributes["child_folder"] = ("folder", "trace")
            default_attributes["child_job_id"] = ("child_job_id", "sep")

    if not args.yaml:
        csv_writer = csv.writer(sys.stdout)
        if not (args.no_header or args.list_keys):
            csv_writer.writerow(list(default_attributes) + list(keymap))

    # store configs for job_id's s.t. they need to be loaded only once
    configs = {}
    warning_shown = False
    for entry in entries:
        current_epoch = entry.get("epoch")
        job_type = entry.get("job")
        job_id = entry.get("job_id")
        if truncate_epoch and not current_epoch <= float(truncate_epoch):
            continue
        # filter out entries not relevant to the unique training sequence determined
        # by the options; not relevant for search
        if job_type == "train":
            if current_epoch > job_epochs[job_id]:
                continue
        elif job_type == "eval":
            if "resumed_from_job_id" in entry:
                if current_epoch > job_epochs[entry.get(
                        "resumed_from_job_id")]:
                    continue
            elif "parent_job_id" in entry:
                if current_epoch > job_epochs[entry.get("parent_job_id")]:
                    continue
        # find relevant config file
        child_job_id = entry.get("child_job_id")
        config_key = (entry.get("folder") + "/" +
                      str(child_job_id) if args.search else job_id)
        if config_key in configs:
            config = configs[config_key]
        else:
            if args.search:
                if not child_job_id and not warning_shown:
                    # This warning is from Dec 19, 2019. TODO remove
                    print(
                        "Warning: You are dumping the trace of an older search job. "
                        "This is fine only if "
                        "the config.yaml files in each subfolder have not been modified "
                        "after running the corresponding training job.",
                        file=sys.stderr,
                    )
                    warning_shown = True
                config = get_config_for_job_id(
                    child_job_id, os.path.join(folder_path,
                                               entry.get("folder")))
                entry["type"] = config.get("train.type")
            else:
                config = get_config_for_job_id(job_id, folder_path)
            configs[config_key] = config
        if args.list_keys:
            # only collect the key names; nothing is written per entry
            all_trace_keys.update(entry.keys())
            continue
        new_attributes = OrderedDict()
        # when training was reciprocal, use the base_model as model
        if config.get_default("model") == "reciprocal_relations_model":
            model = config.get_default(
                "reciprocal_relations_model.base_model.type")
            # the string that substitutes $base_model in keymap if it exists
            subs_model = "reciprocal_relations_model.base_model"
            reciprocal = 1
        else:
            model = config.get_default("model")
            subs_model = model
            reciprocal = 0
        # search for the additional keys from --keys and --keysfile
        for new_key, lookup in keymap.items():
            # search for special keys
            value = None
            if lookup == "$folder":
                value = os.path.abspath(folder_path)
            elif lookup == "$checkpoint" and checkpoint_path:
                value = os.path.abspath(checkpoint_path)
            elif lookup == "$machine":
                value = socket.gethostname()
            if "$base_model" in lookup:
                lookup = lookup.replace("$base_model", subs_model)
            # search for ordinary keys; start searching in trace entry then config
            if not value:
                value = entry.get(lookup)
            if not value:
                try:
                    value = config.get_default(lookup)
                except Exception:
                    # Unknown key: keep the value found so far (usually None),
                    # which creates an empty field in the CSV. The exception
                    # type raised by get_default is not visible from here, but
                    # KeyboardInterrupt/SystemExit must not be swallowed.
                    pass
            # booleans are written as 1/0
            if isinstance(value, bool):
                value = int(value)
            new_attributes[new_key] = value
        if not args.yaml:
            # find the actual values for the default attributes
            actual_default = default_attributes.copy()
            for new_key, (lookup, where) in default_attributes.items():
                if where == "config":
                    actual_default[new_key] = config.get(lookup)
                elif where == "trace":
                    actual_default[new_key] = entry.get(lookup)
            # keys with separate treatment
            # "split" in {train,test,valid} for the datatype
            # "job" in {train,eval,valid,search}
            if job_type == "train":
                if "split" in entry:
                    actual_default["split"] = entry.get("split")
                else:
                    actual_default["split"] = "train"
                actual_default["job"] = "train"
            elif job_type == "eval":
                if "split" in entry:
                    actual_default["split"] = entry.get(
                        "split")  # test or valid
                else:
                    # deprecated
                    actual_default["split"] = entry.get(
                        "data")  # test or valid
                if entry.get("resumed_from_job_id"):
                    actual_default["job"] = "eval"  # from "kge eval"
                else:
                    actual_default["job"] = "valid"  # child of training job
            else:
                actual_default["job"] = job_type
                if "split" in entry:
                    actual_default["split"] = entry.get("split")
                else:
                    # deprecated
                    actual_default["split"] = entry.get(
                        "data")  # test or valid
            actual_default["job_id"] = job_id.split("-")[0]
            actual_default["model"] = model
            actual_default["reciprocal"] = reciprocal
            # lookup name is in config value is in trace
            actual_default["metric"] = entry.get(
                config.get_default("valid.metric"))
            if args.search:
                # Traces of older search jobs carry no child_job_id (see the
                # warning above); leave the field empty instead of crashing.
                actual_default["child_job_id"] = (
                    child_job_id.split("-")[0] if child_job_id else None)
            # drop everything that is not a requested default column (relevant
            # with --no_default_keys, where default_attributes is empty)
            for key in list(actual_default.keys()):
                if key not in default_attributes:
                    del actual_default[key]
            csv_writer.writerow(
                list(actual_default.values()) + list(new_attributes.values()))
        else:
            entry.update({"reciprocal": reciprocal, "model": model})
            if keymap:
                entry.update(new_attributes)
            print(entry)
    if args.list_keys:
        if configs:
            # only one config needed
            config = next(iter(configs.values()))
            options = Config.flatten(config.options)
            options = sorted(filter(lambda opt: "+++" not in opt, options),
                             key=lambda opt: opt.lower())
        else:
            # every entry was filtered out before a config was loaded
            options = []
        if isinstance(args.list_keys, bool):
            sep = ", "
        else:
            sep = args.list_keys
        print("Default keys for CSV: ")
        print(*default_attributes.keys(), sep=sep)
        print("")
        print("Special keys: ")
        print(*["$folder", "$checkpoint", "$machine", "$base_model"], sep=sep)
        print("")
        print("Keys found in trace: ")
        print(*sorted(all_trace_keys), sep=sep)
        print("")
        print("Keys found in config: ")
        print(*options, sep=sep)
예제 #3
0
def _dump_trace(args):
    """Execute the 'dump trace' command.

    Reads ``trace.yaml`` from the folder given by ``args.source`` (or, when
    ``args.source`` names a ``.pt`` file, from the folder containing it),
    selects the relevant trace entries and writes them to stdout: as CSV rows
    by default, or as one stringified entry per line when ``args.yaml`` is
    set.

    Args:
        args: Parsed command-line arguments. Attributes read here: ``train``,
            ``valid``, ``test``, ``search``, ``source``, ``checkpoint``,
            ``truncate``, ``keysfile``, ``keys``, ``max_epoch``, ``job_id``,
            ``example``, ``batch``, ``yaml``, ``no_default_keys``,
            ``no_header`` and ``timeit``. ``train``/``valid``/``test`` are set
            to True and ``search`` may be set to True on this object as a side
            effect.

    Raises:
        ValueError: if ``args.truncate`` is set for a folder source without
            ``args.checkpoint``.
        SystemExit: with code 1 when the options conflict, when no trace file
            exists or when no relevant entries are found.
    """
    start = time.time()
    if (args.train or args.valid or args.test) and args.search:
        print(
            "--search and --train, --valid, --test are mutually exclusive",
            file=sys.stderr,
        )
        # sys.exit instead of the builtin exit(), which is only installed by
        # the site module
        sys.exit(1)
    # without an explicit entry type, all training-related types are dumped
    entry_type_specified = True
    if not (args.train or args.valid or args.test or args.search):
        entry_type_specified = False
        args.train = True
        args.valid = True
        args.test = True

    checkpoint_path = None
    if ".pt" in os.path.split(args.source)[-1]:
        checkpoint_path = args.source
        folder_path = os.path.split(args.source)[0]
    else:
        # determine job_id and epoch from last/best checkpoint automatically
        if args.checkpoint:
            checkpoint_path = Config.get_best_or_last_checkpoint(args.source)
        folder_path = args.source
        if not args.checkpoint and args.truncate:
            raise ValueError(
                "You can only use --truncate when a checkpoint is specified. "
                "Consider using --checkpoint or provide a checkpoint file as source"
            )
    trace = os.path.join(folder_path, "trace.yaml")
    if not os.path.isfile(trace):
        sys.stderr.write("No trace found at {}\n".format(trace))
        sys.exit(1)

    # additional keys from --keys and --keysfile; each line is either
    # "name=key" or just "key" (in which case the key is also the column name)
    keymap = OrderedDict()
    additional_keys = []
    if args.keysfile:
        with open(args.keysfile, "r") as keyfile:
            additional_keys = keyfile.readlines()
    if args.keys:
        additional_keys += args.keys
    for line in additional_keys:
        line = line.rstrip("\n").replace(" ", "")
        name_key = line.split("=")
        if len(name_key) == 1:
            name_key += name_key
        keymap[name_key[0]] = name_key[1]

    job_id = None
    epoch = int(args.max_epoch)
    if checkpoint_path:
        # NOTE(review): torch.load unpickles the file; only use this on
        # checkpoints from a trusted source.
        checkpoint = torch.load(f=checkpoint_path, map_location="cpu")
        job_id = checkpoint["job_id"]
        if args.truncate:
            # use the epoch from the checkpoint as well
            epoch = checkpoint["epoch"]
    # override job_id with user argument
    if args.job_id:
        job_id = args.job_id
    # a falsy epoch means "no restriction"
    if not epoch:
        epoch = float("inf")

    entries, job_epochs = [], {}
    if not args.search:
        entries, job_epochs = Trace.grep_training_trace_entries(
            tracefile=trace,
            train=args.train,
            test=args.test,
            valid=args.valid,
            example=args.example,
            batch=args.batch,
            job_id=job_id,
            epoch_of_last=epoch,
        )
    if not entries and (args.search or not entry_type_specified):
        # fall back to treating the trace as the trace of a search job
        entries = Trace.grep_entries(tracefile=trace,
                                     conjunctions=["scope: train"])
        epoch = None
        if entries:
            args.search = True
    if not entries:
        print("No relevant trace entries found.", file=sys.stderr)
        sys.exit(1)

    middle = time.time()
    if not args.yaml:
        csv_writer = csv.writer(sys.stdout)
        # dict[new_name] = (lookup_name, where)
        # if where=="config"/"trace" it will be looked up automatically
        # if where=="sep" it must be added in in the write loop separately
        if args.no_default_keys:
            default_attributes = OrderedDict()
        else:
            default_attributes = OrderedDict([
                ("job_id", ("job_id", "sep")),
                ("dataset", ("dataset.name", "config")),
                ("model", ("model", "sep")),
                ("reciprocal", ("reciprocal", "sep")),
                ("job", ("job", "sep")),
                ("job_type", ("type", "trace")),
                ("split", ("split", "sep")),
                ("epoch", ("epoch", "trace")),
                ("avg_loss", ("avg_loss", "trace")),
                ("avg_penalty", ("avg_penalty", "trace")),
                ("avg_cost", ("avg_cost", "trace")),
                ("metric_name", ("valid.metric", "config")),
                ("metric", ("metric", "sep")),
            ])
            if args.search:
                default_attributes["child_folder"] = ("folder", "trace")
                default_attributes["child_job_id"] = ("child_job_id", "sep")

        if not args.no_header:
            csv_writer.writerow(list(default_attributes) + list(keymap))
    # store configs for job_id's s.t. they need to be loaded only once
    configs = {}
    warning_shown = False
    for entry in entries:
        if epoch and not entry.get("epoch") <= float(epoch):
            continue
        # filter out not needed entries from a previous job when
        # a job was resumed from the middle
        if entry.get("job") == "train":
            job_id = entry.get("job_id")
            if entry.get("epoch") > job_epochs[job_id]:
                continue

        # find relevant config file
        child_job_id = entry.get("child_job_id")
        config_key = (entry.get("folder") + "/" + str(child_job_id)
                      if args.search else entry.get("job_id"))
        if config_key in configs:
            config = configs[config_key]
        else:
            if args.search:
                if not child_job_id and not warning_shown:
                    # This warning is from Dec 19, 2019. TODO remove
                    print(
                        "Warning: You are dumping the trace of an older search job. "
                        "This is fine only if "
                        "the config.yaml files in each subfolder have not been modified "
                        "after running the corresponding training job.",
                        file=sys.stderr,
                    )
                    warning_shown = True
                config = get_config_for_job_id(
                    child_job_id, os.path.join(folder_path,
                                               entry.get("folder")))
                entry["type"] = config.get("train.type")
            else:
                config = get_config_for_job_id(entry.get("job_id"),
                                               folder_path)
            configs[config_key] = config

        new_attributes = OrderedDict()
        # when training was reciprocal, report the base model as model
        if config.get_default("model") == "reciprocal_relations_model":
            model = config.get_default(
                "reciprocal_relations_model.base_model.type")
            # the string that substitutes $base_model in keymap if it exists
            subs_model = "reciprocal_relations_model.base_model"
            reciprocal = 1
        else:
            model = config.get_default("model")
            subs_model = model
            reciprocal = 0
        for new_key, lookup in keymap.items():
            if "$base_model" in lookup:
                lookup = lookup.replace("$base_model", subs_model)
            if lookup == "$folder":
                val = os.path.abspath(folder_path)
            elif lookup == "$checkpoint":
                # no checkpoint in use: empty field
                val = (os.path.abspath(checkpoint_path)
                       if checkpoint_path else entry.get(lookup))
            elif lookup == "$machine":
                val = socket.gethostname()
            else:
                try:
                    val = config.get_default(lookup)
                except Exception:
                    # Not a config key: fall back to the trace entry; creates
                    # an empty field if the key does not exist there either.
                    # The exception type raised by get_default is not visible
                    # from here, but KeyboardInterrupt/SystemExit must not be
                    # swallowed.
                    val = entry.get(lookup)
            # booleans are written as 1/0
            if isinstance(val, bool):
                val = int(val)
            new_attributes[new_key] = val
        if not args.yaml:
            # find the actual values for the default attributes
            actual_default = default_attributes.copy()
            for new_key, (lookup, where) in default_attributes.items():
                if where == "config":
                    actual_default[new_key] = config.get(lookup)
                elif where == "trace":
                    actual_default[new_key] = entry.get(lookup)
            # keys with separate treatment
            # "split" in {train,test,valid} for the datatype
            # "job" in {train,eval,valid,search}
            if entry.get("job") == "train":
                actual_default["split"] = "train"
                actual_default["job"] = "train"
            elif entry.get("job") == "eval":
                actual_default["split"] = entry.get("data")  # test or valid
                if entry.get("resumed_from_job_id"):
                    actual_default["job"] = "eval"  # from "kge eval"
                else:
                    actual_default["job"] = "valid"  # child of training job
            else:
                actual_default["job"] = entry.get("job")
                actual_default["split"] = entry.get("data")
            actual_default["job_id"] = entry.get("job_id").split("-")[0]
            actual_default["model"] = model
            actual_default["reciprocal"] = reciprocal
            # lookup name is in config value is in trace
            actual_default["metric"] = entry.get(
                config.get_default("valid.metric"))
            if args.search:
                # Traces of older search jobs carry no child_job_id (see the
                # warning above); leave the field empty instead of crashing.
                actual_default["child_job_id"] = (
                    child_job_id.split("-")[0] if child_job_id else None)
            # drop everything that is not a requested default column (relevant
            # with --no_default_keys, where default_attributes is empty)
            for key in list(actual_default.keys()):
                if key not in default_attributes:
                    del actual_default[key]
            csv_writer.writerow(
                list(actual_default.values()) + list(new_attributes.values()))
        else:
            entry.update({"reciprocal": reciprocal, "model": model})
            if keymap:
                entry.update(new_attributes)
            # strip braces and quotes from the dict representation
            sys.stdout.write(re.sub("[{}']", "", str(entry)))
            sys.stdout.write("\n")
    end = time.time()
    if args.timeit:
        sys.stdout.write("Grep + processing took {} \n".format(middle - start))
        sys.stdout.write("Writing took {}\n".format(end - middle))
예제 #4
0
파일: search.py 프로젝트: AdrianKs/dist-kge
def _run_train_job(sicnk, device=None):
    """Run a training job and return the trace entry of its best validation result.

    Also takes care of appropriate tracing.

    Args:
        sicnk: Tuple ``(search_job, train_job_index, train_job_config,
            train_job_count, trace_keys)``.
        device: If not None, written to the ``job.device`` option of the
            training job's config before the job is started.

    Returns:
        Tuple ``(train_job_index, best, best_metric)``. ``best`` and
        ``best_metric`` are None when running was skipped for a manual search
        job, or when the trial failed and ``search_job.on_error`` is
        ``"continue"``.

    Raises:
        Re-raises whatever the trial raised when ``search_job.on_error`` is
        not ``"continue"``.
    """

    search_job, train_job_index, train_job_config, train_job_count, trace_keys = sicnk

    try:
        # load the job
        if device is not None:
            train_job_config.set("job.device", device)
        search_job.config.log(
            "Starting training job {} ({}/{}) on device {}...".format(
                train_job_config.folder,
                train_job_index + 1,
                train_job_count,
                train_job_config.get("job.device"),
            ))

        # process the trace entries so far (in case of a resumed job)
        metric_name = search_job.config.get("valid.metric")
        valid_trace = []
        # No job object exists in this process when the job is run
        # distributed without a checkpoint; bind the name up front so that
        # the cleanup at the end does not fail with an UnboundLocalError.
        job = None

        def copy_to_search_trace(job, trace_entry=None):
            """Copy a validation trace entry to the search job's trace.

            When used as a post-valid hook, ``trace_entry`` is None and the
            last entry of ``job.valid_trace`` is used.
            """
            if trace_entry is None:
                trace_entry = job.valid_trace[-1]
            trace_entry = copy.deepcopy(trace_entry)
            for key in trace_keys:
                # Process deprecated options to some extent. Support key renames, but
                # not value renames.
                actual_key = {key: None}
                _process_deprecated_options(actual_key)
                if len(actual_key) > 1:
                    raise KeyError(
                        f"{key} is deprecated but cannot be handled automatically"
                    )
                actual_key = next(iter(actual_key.keys()))
                value = train_job_config.get(actual_key)
                trace_entry[key] = value

            trace_entry["folder"] = os.path.split(train_job_config.folder)[1]
            metric_value = Trace.get_metric(trace_entry, metric_name)
            trace_entry["metric_name"] = metric_name
            trace_entry["metric_value"] = metric_value
            trace_entry["parent_job_id"] = search_job.job_id
            search_job.config.trace(**trace_entry)
            valid_trace.append(trace_entry)

        checkpoint_file = get_checkpoint_file(train_job_config)
        if checkpoint_file is not None:
            checkpoint = load_checkpoint(checkpoint_file,
                                         train_job_config.get("job.device"))
            job = Job.create_from(
                checkpoint=checkpoint,
                new_config=train_job_config,
                dataset=search_job.dataset,
                parent_job=search_job,
            )
            for trace_entry in job.valid_trace:
                copy_to_search_trace(None, trace_entry)
        else:
            if train_job_config.get("job.distributed.num_workers") > 0:
                from kge.distributed.funcs import create_and_run_distributed
                valid_trace = create_and_run_distributed(
                    config=train_job_config, dataset=search_job.dataset)
            else:
                job = Job.create(
                    config=train_job_config,
                    dataset=search_job.dataset,
                    parent_job=search_job,
                )

        if train_job_config.get("job.distributed.num_workers") <= 0:
            # NOTE(review): this rebinding also discards the entries that were
            # copied above from a resumed checkpoint, so they do not take part
            # in the best-result analysis below -- confirm this is intended.
            valid_trace = []

            # run the job (adding new trace entries as we go)
            # TODO make this less hacky (easier once integrated into SearchJob)
            from kge.job import ManualSearchJob

            if not isinstance(search_job, ManualSearchJob
                              ) or search_job.config.get("manual_search.run"):
                job.post_valid_hooks.append(copy_to_search_trace)
                job.run()
            else:
                search_job.config.log(
                    "Skipping running of training job as requested by user.")
                return (train_job_index, None, None)

        # analyze the result
        search_job.config.log("Best result in this training job:")
        best = None
        best_metric = None
        for trace_entry in valid_trace:
            if train_job_config.get("job.distributed.num_workers") <= 0:
                metric = trace_entry["metric_value"]
            else:
                # We cannot do this via a post-valid hook in the distributed
                # setting because the function copy_to_search_trace cannot be
                # pickled.
                metric = Trace.get_metric(trace_entry, metric_name)
                trace_entry["metric_value"] = metric
                trace_entry["metric_name"] = metric_name
            if not best or Metric(search_job).better(metric, best_metric):
                best = trace_entry
                best_metric = metric

        # record the best result of this job
        best["child_job_id"] = best["job_id"]
        for k in ["job", "job_id", "type", "parent_job_id", "scope", "event"]:
            if k in best:
                del best[k]
        search_job.trace(
            event="search_completed",
            echo=True,
            echo_prefix="  ",
            log=True,
            scope="train",
            **best,
        )

        # force releasing the GPU memory of the job to avoid memory leakage
        del job
        gc.collect()

        return (train_job_index, best, best_metric)
    except BaseException as e:
        search_job.config.log("Trial {:05d} failed: {}".format(
            train_job_index, repr(e)))
        if search_job.on_error == "continue":
            return (train_job_index, None, None)
        else:
            search_job.config.log(
                "Aborting search due to failure of trial {:05d}".format(
                    train_job_index))
            raise