Example #1
def train_model(model: Any,
                config_dict: Dict[str, Dict[str, Any]],
                datasets: dict = None) -> Dict[str, Any]:
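    """Train `model` according to the parsed `config_dict` (datasets,
    objectives, optimizers, etc.) and return a dict of result trackers.
    """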

    # TODO: option to reinitialize model?

    # unpack configurations
    model_cdict: Dict[str, Any] = config_dict["model"]
    meta_cdict: Dict[str, Any] = config_dict["meta"]
    log_cdict: Dict[str, Any] = config_dict["logging"]
    data_cdict: Dict[str, Any] = config_dict["data"]
    hp_cdict: Dict[str, Any] = config_dict["hyper_parameters"]

    perf_cdict: Dict[str, Any] = config_dict["performance"]
    optim_cdict: Dict[str, Any] = config_dict["optimize"]
    cb_cdict: Dict[str, Any] = config_dict["callbacks"]

    return_dict = {}

    full_exp_path = (pathlib.Path(meta_cdict["yeahml_dir"]).joinpath(
        meta_cdict["data_name"]).joinpath(
            meta_cdict["experiment_name"]).joinpath(model_cdict["name"]))

    # build paths and obtain tb writers
    model_run_path = create_model_run_path(full_exp_path)
    # profile_path = model_run_path.joinpath("tf_logs").joinpath("profile")
    save_model_path, save_best_param_path = create_model_training_paths(
        model_run_path)
    tr_writer, v_writer = get_tb_writers(model_run_path)
    log_model_params(tr_writer, 0, model)

    logger = config_logger(model_run_path, log_cdict, "train")
    # get datasets
    # train_ds, val_ds = get_datasets(datasets, data_cdict, hp_cdict)
    dataset_dict = get_datasets(datasets, data_cdict, hp_cdict)

    # {optimizer_name: {"optimizer": <tf optimizer>, "objectives": [objective_name, ...]}}
    optimizers_dict = get_optimizers(optim_cdict)

    # {objective_name: {"in_config": {...}, "loss": {...}, "metrics": {...}}}
    # TODO: "train", "val" should be obtained from the config
    objectives_dict = get_objectives(perf_cdict["objectives"],
                                     dataset_dict,
                                     target_splits=["train", "val"])

    # create callbacks
    custom_callbacks = get_callbacks(cb_cdict)
    cbc = CBC(
        custom_callbacks,
        optimizer_names=list(optimizers_dict.keys()),
        dataset_names=list(dataset_dict.keys()),
        objective_names=list(objectives_dict.keys()),
    )
    # TODO: call all cbc methods at the appropriate time

    # create a tf.function for applying gradients for each optimizer
    # TODO: I am not 100% sure about this logic for mapping the optimizer to
    #   the apply_gradient fn... this needs to be confirmed to work as expected
    opt_to_validation_fn = {}
    opt_to_get_grads_fn, opt_to_app_grads_fn = {}, {}
    opt_to_steps = {}
    # used to determine which objectives to loop to calculate losses
    opt_to_loss_objectives = {}
    # used to determine which objectives to obtain to calculate metrics
    opt_to_metrics_objectives = {}

    for cur_optimizer_name, cur_optimizer_config in optimizers_dict.items():

        # TODO: check config to see which fn to get supervised/etc
        opt_to_get_grads_fn[cur_optimizer_name] = get_get_supervised_grads_fn()
        opt_to_app_grads_fn[cur_optimizer_name] = get_apply_grad_fn()
        opt_to_validation_fn[cur_optimizer_name] = get_validation_step_fn()
        opt_to_steps[cur_optimizer_name] = 0

        loss_objective_names = []
        metrics_objective_names = []
        for cur_objective in cur_optimizer_config["objectives"]:
            cur_objective_dict = objectives_dict[cur_objective]
            if "loss" in cur_objective_dict.keys():
                if cur_objective_dict["loss"]:
                    loss_objective_names.append(cur_objective)
            if "metrics" in cur_objective_dict.keys():
                if cur_objective_dict["metrics"]:
                    metrics_objective_names.append(cur_objective)
        opt_to_loss_objectives[cur_optimizer_name] = loss_objective_names
        opt_to_metrics_objectives[cur_optimizer_name] = metrics_objective_names

    # TODO: training_directive may be empty.
    # {
    #     "YEAHML_1": {"optimizers": ["YEAHML_0", "second_opt"], "operation": "&"},
    #     "YEAHML_0": {"optimizers": ["main_opt", "second_opt"], "operation": ","},
    # }
    # TODO: I will need to parse this to create a cleaner directive to follow
    # training_directive = optim_cdict["directive"]

    main_tracker_dict = create_full_dict(
        optimizers_dict=optimizers_dict,
        objectives_dict=objectives_dict,
        datasets_dict=dataset_dict,
    )

    dataset_iter_dict = convert_to_single_pass_iterator(dataset_dict)
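    # NOTE: an exhausted single-pass iterator (a falsy value from
    # get_next_batch below) marks the end of a pass through the dataset; the
    # iterator is then re-initialized (see re_init_iter below)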

    # TODO: create an ordered list of directives to loop through -- I'm no
    # longer convinced this is the best approach; ideally, the order would be
    # adaptive and learned during training. This is related to deciding how
    # "long" to train here... the 'right' answer likely depends on the losses
    # (train and val), but there is probably a shorter-term answer as well.
    # TODO: this needs to be driven by the directive, not just a walkthrough

    obj_ds_to_epoch = {}
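    # obj_ds_to_epoch: {objective_name: {ds_name: {split_name: epoch_count}}}
    # (built up by update_epoch_dict as dataset passes complete)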

    # initialize to True
    is_training = True
    num_training_ops = 0
    # a core issue here is that we're doing this entire loop for a single batch
    # NOTE: consider changing is_training to `switch_optimizer`

    # dictionary to keep track of what optimizers are still training on what
    # datasets
    opt_obj_ds_to_training = {}
    for opt_name, opt_conf in optimizers_dict.items():
        opt_obj_ds_to_training[opt_name] = {}
        loss_objective_names = opt_to_loss_objectives[opt_name]
        for ln in loss_objective_names:
            opt_obj_ds_to_training[opt_name][ln] = {}
            ds_name = objectives_dict[ln]["in_config"]["dataset"]
            # init all to True
            # currently there is only one ds per objective
            opt_obj_ds_to_training[opt_name][ln][ds_name] = {"train": True}

    # TODO: this is hardcoded for supervised settings
    # tf.keras models return their outputs in a list; we need the index of
    # each prediction we care about in that list so it can be used in the loss
    # function
    # NOTE: I'm not sure how I feel about this -- is it better to have multiple
    # "tf.models" that share params (is that even possible) -- or is it better
    # to do this where it is one "tf.model"?
    if isinstance(model.output, list):
        MODEL_OUTPUT_ORDER = [n.name.split("/")[0] for n in model.output]
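        # tf tensor names look like "<layer_name>/...", so splitting on "/"
        # recovers the layer name for each model output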
        objective_to_output_index = {}
        for obj_name, obj_dict in objectives_dict.items():
            try:
                pred_name = obj_dict["in_config"]["options"]["prediction"]
                out_index = MODEL_OUTPUT_ORDER.index(pred_name)
                objective_to_output_index[obj_name] = out_index
            except KeyError:
                # TODO: perform check later
                objective_to_output_index[obj_name] = None
    else:
        # TODO: this is hardcoded to assume supervised
        objective_to_output_index = {}
        for obj_name, obj_dict in objectives_dict.items():
            objective_to_output_index[obj_name] = None

    list_of_optimizers = list(optimizers_dict.keys())
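    # optimizers are removed from this list below once they hit their epoch
    # limit, so select_optimizer only considers optimizers still training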

    logger.info("START - training")
    while is_training:
        cur_optimizer_name = select_optimizer(list_of_optimizers)
        cur_optimizer_config = optimizers_dict[cur_optimizer_name]
        logger.info(f"optimizer: {cur_optimizer_name}")
        continue_optimizer = True
        # continue_optimizer is used to keep training with a single optimizer

        # get optimizer
        cur_tf_optimizer = cur_optimizer_config["optimizer"]

        # loss
        # opt_name :loss :main_obj :ds_name :split_name :loss_name:desc_name
        # opt_name :metric :main_obj: ds_name :split_name :metric_name
        opt_tracker_dict = main_tracker_dict[cur_optimizer_name]

        # NOTE: if there are multiple objectives, they will be trained *jointly*
        # cur_optimizer_config:
        #   {'optimizer': <tf.opt{}>, 'objectives': ['main_obj']}
        # cur_apply_grad_fn = opt_name_to_gradient_fn[cur_optimizer_name]
        get_grads_fn = opt_to_get_grads_fn[cur_optimizer_name]
        apply_grads_fn = opt_to_app_grads_fn[cur_optimizer_name]

        # TODO: these should really be grouped by the in_config (likely by
        # creating a hash); this would allow us to group objectives by the
        # dataset they're using so that we can reuse the same batch.
        # NOTE: for now, I'm saving the prediction and gt (if supervised) in
        # the grad_dict
        loss_objective_names = opt_to_loss_objectives[cur_optimizer_name]
        metrics_objective_names = opt_to_metrics_objectives[cur_optimizer_name]

        obj_to_grads = {}
        # TODO: the losses should be grouped by the dataset used so that we
        # only obtain and run the batch once, ensuring it's the same batch
        loss_update_dict, update_metrics_dict = {}, {}
        while continue_optimizer:
            cur_objective = select_objective(loss_objective_names)
            logger.info(f"objective: {cur_objective}")
            continue_objective = True

            # TODO: next step -- continue_objective = True
            # each loss may be optimized using data from a different dataset
            cur_ds_name = objectives_dict[cur_objective]["in_config"][
                "dataset"]
            loss_conf = objectives_dict[cur_objective]["loss"]
            tf_train_loss_descs_to_update = get_losses_to_update(
                loss_conf, "train")

            cur_train_iter = get_train_iter(dataset_iter_dict, cur_ds_name,
                                            "train")

            while continue_objective:
                cur_batch = get_next_batch(cur_train_iter)
                if not cur_batch:

                    # dataset pass is complete
                    obj_ds_to_epoch = update_epoch_dict(
                        obj_ds_to_epoch, cur_objective, cur_ds_name, "train")

                    if (obj_ds_to_epoch[cur_objective][cur_ds_name]["train"] >=
                            hp_cdict["epochs"]):

                        # update this particular combination to false -
                        # eventually this logic will be "smarter" i.e. not
                        # based entirely on number of epochs.
                        opt_obj_ds_to_training[cur_optimizer_name][
                            cur_objective][cur_ds_name]["train"] = False

                        # this objective is done. see if they're all done
                        is_training = determine_if_training(
                            opt_obj_ds_to_training)

                        # TODO: this isn't the "best" way to handle this;
                        # ideally, we would decide (in an intelligent way) when
                        # we're done training a group of objectives by
                        # evaluating the loss curves
                        list_of_optimizers.remove(cur_optimizer_name)
                        logger.info(
                            f"{cur_optimizer_name} removed from list of opt. remaining: {list_of_optimizers}"
                        )
                        logger.info(f"is_training: {is_training}")
                        # TODO: determine whether to move to the next objective
                        # NOTE: currently, move to the next objective
                        if not is_training:
                            # need to break from all loops
                            continue_optimizer = False
                            continue_objective = False

                        # TODO: there is likely a better way to handle the case
                        # where we have reached the 'set' number of epochs for
                        # this problem

                    # the original dict is updated here in case another
                    # objective needs to use the dataset iter -- this could
                    # likely be optimized, but the impact would be minimal
                    # right now
                    cur_train_iter = re_init_iter(cur_ds_name, "train",
                                                  dataset_dict)
                    dataset_iter_dict[cur_ds_name]["train"] = cur_train_iter

                    logger.info(
                        f"epoch {cur_objective} - {cur_ds_name} {'train'}:"
                        f" {obj_ds_to_epoch[cur_objective][cur_ds_name]['train']}"
                    )

                    # perform validation after each pass through the training
                    # dataset
                    # NOTE: the location of this 'validation' may change
                    # TODO: there is an error here where the first objective
                    # will be validated on the last epoch and then one more
                    # time.
                    # TODO: ensure the metrics are reset
                    # validation currently runs after iterating over the entire
                    # training set; this will/should change to update on a set
                    # frequency -- also, maybe we don't want to run the "full"
                    # validation, only a (random) subset?

                    # validation pass
                    cur_val_update = inference_dataset(
                        model,
                        loss_objective_names,
                        metrics_objective_names,
                        dataset_iter_dict,
                        opt_to_validation_fn[cur_optimizer_name],
                        opt_tracker_dict,
                        cur_objective,
                        cur_ds_name,
                        dataset_dict,
                        opt_to_steps[cur_optimizer_name],
                        num_training_ops,
                        objective_to_output_index,
                        objectives_dict,
                        v_writer,
                        logger,
                        split_name="val",
                    )

                    # log params used during validation in other location
                    log_model_params(v_writer, num_training_ops, model)

                    # TODO: the entire ds has been run -- for now, break out of
                    # this ds. eventually, something smarter will need to be
                    # done here in the training loop, not just after an epoch
                    continue_objective = False

                else:

                    grad_dict = get_grads_fn(
                        model,
                        cur_batch,
                        loss_conf["object"],
                        objective_to_output_index[cur_objective],
                        tf_train_loss_descs_to_update,
                    )
                    # grad_dict contains {
                    #     "gradients": grads,
                    #     "predictions": prediction,
                    #     "final_loss": final_loss,
                    #     "losses": loss,
                    # }

                    # TODO: see note above about ensuring the same batch is used for
                    # losses with the same dataset specified
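                    # opt_to_steps counts samples seen (the batch dimension),
                    # while num_training_ops counts gradient computations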
                    opt_to_steps[cur_optimizer_name] += cur_batch[0].shape[0]
                    num_training_ops += 1
                    # if num_training_ops > 5:
                    #     start_profiler(profile_path, profiling)
                    # elif num_training_ops > 10:
                    #     stop_profiler()

                    # TODO: currently this only stores the last grad dict per objective
                    obj_to_grads[cur_objective] = grad_dict

                    # NOTE: the steps here aren't accurate (due to the note
                    # above about using the same batches for objectives/losses
                    # that specify the same datasets)
                    # update_tf_loss_descriptions(
                    #     grad_dict, tf_train_loss_descs_to_update
                    # )
                    # # TODO: add to tensorboard

                    # create histograms of model parameters
                    if log_cdict["track"]["tensorboard"]["param_steps"] > 0:
                        if (num_training_ops % log_cdict["track"]
                            ["tensorboard"]["param_steps"] == 0):
                            log_model_params(tr_writer, num_training_ops,
                                             model)

                    # update Tracker
                    if log_cdict["track"]["tracker_steps"] > 0:
                        if num_training_ops % log_cdict["track"][
                                "tracker_steps"] == 0:
                            cur_loss_tracker_dict = opt_tracker_dict[
                                cur_objective]["loss"][cur_ds_name]["train"]
                            cur_loss_update = update_loss_trackers(
                                loss_conf["track"]["train"],
                                cur_loss_tracker_dict,
                                opt_to_steps[cur_optimizer_name],
                                num_training_ops,
                                tb_writer=tr_writer,
                                ds_name=cur_ds_name,
                                objective_name=cur_objective,
                            )

                            loss_update_dict[cur_objective] = cur_loss_update

                    # TODO: this is a hacky way of seeing if training on a batch was run
                    if obj_to_grads:
                        update_model_params(apply_grads_fn, obj_to_grads,
                                            model, cur_tf_optimizer)

                        update_metric_objects(
                            metrics_objective_names,
                            objectives_dict,
                            obj_to_grads,
                            "train",
                        )

                        if log_cdict["track"]["tracker_steps"] > 0:
                            if (num_training_ops %
                                    log_cdict["track"]["tracker_steps"] == 0):
                                update_metrics_dict = update_metrics_tracking(
                                    metrics_objective_names,
                                    objectives_dict,
                                    opt_tracker_dict,
                                    obj_to_grads,
                                    opt_to_steps[cur_optimizer_name],
                                    num_training_ops,
                                    "train",
                                    tb_writer=tr_writer,
                                    ds_name=cur_ds_name,
                                    objective_name=cur_objective,
                                )

                update_dict = {
                    "loss": loss_update_dict,
                    "metrics": update_metrics_dict
                }
            continue_optimizer = False
        # one pass of training (a batch from each objective) with the
        # current optimizer

    # TODO: I think the 'joint' should likely be the optimizer name, not the
    # combination of loss names; this would also simplify the creation of these

    return_dict = {"tracker": main_tracker_dict}

    return return_dict
Example #2
def eval_model(
        model: Any,
        config_dict: Dict[str, Dict[str, Any]],
        datasets: Any = None,
        weights_path: str = "",
        eval_split="test",
        pred_dict=None,  # stupid hacky fix
) -> Dict[str, Any]:
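    """Evaluate `model` on the `eval_split` split of each configured dataset
    and return a nested dict of results keyed by dataset name and in_config
    hash.
    """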

    # TODO: allow for multiple splits to evaluate on

    # TODO: load the best weights
    # model = load_targeted_weights(full_exp_path, weights_path)

    # NOTE: should I reset the metrics?
    # # reset metrics (should already be reset)
    # for eval_metric_fn in eval_metric_fns:
    #     eval_metric_fn.reset_states()

    # # TODO: log each instance

    # unpack configurations
    model_cdict: Dict[str, Any] = config_dict["model"]

    # set up loop for performing inference more efficiently
    perf_cdict: Dict[str, Any] = config_dict["performance"]
    ds_to_chash, chash_to_in_config = create_ds_to_lm_mapping(perf_cdict)
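    # (inferred from usage below)
    # ds_to_chash: {ds_name: {in_config_hash: {"inference_fn": ..., "loss": [...], "metric": [...]}}}
    # chash_to_in_config: {in_config_hash: in_config}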

    # obtain datasets
    # TODO: hyperparams (depending on implementation) may not be relevant here
    data_cdict: Dict[str, Any] = config_dict["data"]
    hp_cdict: Dict[str, Any] = config_dict["hyper_parameters"]
    dataset_dict = get_datasets(datasets, data_cdict, hp_cdict)
    dataset_iter_dict = convert_to_single_pass_iterator(dataset_dict)

    # obtain logger
    log_cdict: Dict[str, Any] = config_dict["logging"]
    meta_cdict: Dict[str, Any] = config_dict["meta"]
    full_exp_path = (pathlib.Path(meta_cdict["yeahml_dir"]).joinpath(
        meta_cdict["data_name"]).joinpath(
            meta_cdict["experiment_name"]).joinpath(model_cdict["name"]))
    # build paths and obtain tb writers
    model_run_path = create_model_run_path(full_exp_path)
    logger = config_logger(model_run_path, log_cdict, "eval")

    # create output index
    chash_to_output_index = create_output_index(model, chash_to_in_config)

    # objectives to objects
    # TODO: "test" should be obtained from the config
    # this returns an in_config, which isn't really needed.
    # TODO: is this always only going to be a single split?
    split_name = eval_split  # TODO: this needs to be double checked
    objectives_to_objects = get_objectives(perf_cdict["objectives"],
                                           dataset_dict,
                                           target_splits=split_name)

    logger.info("START - evaluating")
    ret_dict = {}
    for cur_ds_name, chash_conf_d in ds_to_chash.items():
        ret_dict[cur_ds_name] = {}
        logger.info(f"current dataset: {cur_ds_name}")

        for in_hash, cur_hash_conf in chash_conf_d.items():
            logger.info(f"in_hash: {in_hash}")
            ret_dict[cur_ds_name][in_hash] = {}
            cur_objective_config = chash_to_in_config[in_hash]
            assert (
                cur_objective_config["type"] == "supervised"
            ), f"only supervised is currently allowed, not {cur_objective_config['type']} :("
            logger.info(f"current config: {cur_objective_config}")

            cur_inference_fn = cur_hash_conf["inference_fn"]
            cur_metrics_objective_names = cur_hash_conf["metric"]
            cur_loss_objective_names = cur_hash_conf["loss"]
            cur_dataset_iter = dataset_iter_dict[cur_ds_name][split_name]
            cur_pred_index = chash_to_output_index[in_hash]
            cur_target_name = cur_objective_config["options"]["target"]

            temp_ret = inference_on_ds(
                model,
                cur_dataset_iter,
                cur_inference_fn,
                cur_loss_objective_names,
                cur_metrics_objective_names,
                objectives_to_objects,
                cur_pred_index,
                cur_target_name,
                eval_split,
                logger,
                pred_dict,
            )
            ret_dict[cur_ds_name][in_hash] = temp_ret

            # reinitialize the evaluation split iterator
            dataset_iter_dict[cur_ds_name][split_name] = re_init_iter(
                cur_ds_name, split_name, dataset_dict)

    return ret_dict
Example #3
def create_configs(main_path: str) -> dict:
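    """Parse and validate the config at `main_path`, set up the experiment
    directories, determine the model's inputs/outputs, and return the full
    config dict (including the analyzed graph).
    """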

    # parse + validate
    config_dict = ccm.generate(main_path, TEMPLATE)

    # TODO: bandaid fix
    if "callbacks" not in config_dict.keys():
        config_dict["callbacks"] = {"objects": {}}

    # custom parsers
    config_dict["model"]["layers"] = layers_parser()(
        config_dict["model"]["layers"])
    config_dict["performance"]["objectives"] = performances_parser()(
        config_dict["performance"]["objectives"])

    # TODO: ---- below
    model_hash = make_hash(config_dict["model"], MODEL_IGNORE_HASH_KEYS)
    config_dict["model"]["model_hash"] = model_hash

    full_exp_path = (Path(config_dict["meta"]["yeahml_dir"]).joinpath(
        config_dict["meta"]["data_name"]).joinpath(
            config_dict["meta"]["experiment_name"]).joinpath(
                config_dict["model"]["name"]))
    logger = config_logger(full_exp_path, config_dict["logging"], "config")

    exp_root_dir = (Path(config_dict["meta"]["yeahml_dir"]).joinpath(
        config_dict["meta"]["data_name"]).joinpath(
            config_dict["meta"]["experiment_name"]))

    try:
        override_yml_dir = config_dict["meta"]["start_fresh"]
    except KeyError:
        # leave existing experiment information
        override_yml_dir = False

    if os.path.exists(exp_root_dir):
        if override_yml_dir:
            shutil.rmtree(exp_root_dir)
            logger.info(f"directory {exp_root_dir} removed")

    if not os.path.exists(exp_root_dir):
        Path(exp_root_dir).mkdir(parents=True, exist_ok=True)
        logger.info(f"directory {exp_root_dir} created")

    model_root_dir = exp_root_dir.joinpath(config_dict["model"]["name"])
    try:
        override_model_dir = config_dict["model"]["start_fresh"]
    except KeyError:
        # leave existing model information
        override_model_dir = False

    _maybe_create_dir(model_root_dir,
                      wipe_dirs=override_model_dir,
                      logger=logger)

    # build the order of inputs into the model. This logic will likely need to
    # change as inputs become more complex
    input_order = []
    for ds_name, ds_config in config_dict["data"]["datasets"].items():
        for feat_name, config in ds_config["in"].items():
            if config["startpoint"]:
                if not config["label"]:
                    input_order.append(feat_name)
    if not input_order:
        raise ValueError("no inputs have been specified to the model")

    # loop over the model layers to ensure all outputs are accounted for
    output_order = []
    for name, config in config_dict["model"]["layers"].items():
        if config["endpoint"]:
            output_order.append(name)
    if not output_order:
        raise ValueError("no outputs have been specified for the model")

    # TODO: maybe this should be a dictionary
    # TODO: this is a sneaky band-aid to ensure we don't specify duplicate
    # inputs that share the same name -- in reality this does not address the
    # root issue, which is that we should be able to allow some intermediate
    # layers to accept input from either layer_a or layer_b, not only layer_a
    input_order = list(set(input_order))

    config_dict["model_io"] = {"inputs": input_order, "outputs": output_order}

    # validate graph
    graph_dict, graph_dependencies = static_analysis(config_dict)
    config_dict["graph_dict"] = graph_dict
    config_dict["graph_dependencies"] = graph_dependencies

    return config_dict
Example #4
def _primary_config(main_path: str) -> dict:
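    """Parse the raw main config, apply per-section defaults, create the
    experiment directory structure, and return the config dict keyed by
    config section.
    """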
    main_config_raw = get_raw_dict_from_string(main_path)
    cur_keys = main_config_raw.keys()
    missing_keys = []
    for key in CONFIG_KEYS:
        if key not in cur_keys:
            missing_keys.append(key)
            # not all of these *need* to be present, but for now that will be
            # enforced
    if missing_keys:
        raise ValueError(
            f"The main config does not contain the key(s) {missing_keys}:"
            f" current keys: {cur_keys}")

    # build dict containing configs
    config_dict = {}
    for config_type in CONFIG_KEYS:
        # try block?
        raw_config = main_config_raw[config_type]
        raw_config = _maybe_extract_from_path(raw_config)

        formatted_config = parse_default(raw_config,
                                         DEFAULT_CONFIG[f"{config_type}"])
        if config_type == "model":
            model_hash = make_hash(formatted_config, IGNORE_HASH_KEYS)
            formatted_config["model_hash"] = model_hash

        config_dict[config_type] = formatted_config

    full_exp_path = (Path(config_dict["meta"]["yeahml_dir"]).joinpath(
        config_dict["meta"]["data_name"]).joinpath(
            config_dict["meta"]["experiment_name"]).joinpath(
                config_dict["model"]["name"]))
    logger = config_logger(full_exp_path, config_dict["logging"], "config")

    unused_keys = check_for_unused_keys(config_dict, main_config_raw, [], [])
    if unused_keys:
        _maybe_message(unused_keys, main_config_raw, logger)

    # TODO: this should probably be made once and stored? maybe in the meta config?
    exp_root_dir = (Path(config_dict["meta"]["yeahml_dir"]).joinpath(
        config_dict["meta"]["data_name"]).joinpath(
            config_dict["meta"]["experiment_name"]))

    try:
        override_yml_dir = config_dict["meta"]["start_fresh"]
    except KeyError:
        # leave existing experiment information
        override_yml_dir = False

    if os.path.exists(exp_root_dir):
        if override_yml_dir:
            shutil.rmtree(exp_root_dir)
    if not os.path.exists(exp_root_dir):
        Path(exp_root_dir).mkdir(parents=True, exist_ok=True)

    model_root_dir = exp_root_dir.joinpath(config_dict["model"]["name"])
    try:
        override_model_dir = config_dict["model"]["start_fresh"]
    except KeyError:
        # leave existing model information
        override_model_dir = False

    _create_exp_dir(model_root_dir, wipe_dirs=override_model_dir)

    return config_dict
Example #5
def build_model(config_dict: Dict[str, Dict[str, Any]]) -> Any:
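    """Build and return a tf.keras.Model by constructing each node in the
    config's static graph and wiring the configured inputs to outputs.
    """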

    # unpack configuration
    model_cdict: Dict[str, Any] = config_dict["model"]
    meta_cdict: Dict[str, Any] = config_dict["meta"]
    log_cdict: Dict[str, Any] = config_dict["logging"]
    # data_cdict: Dict[str, Any] = config_dict["data"]
    graph_dict: Dict[str, Any] = config_dict["graph_dict"]
    graph_dependencies: Dict[str, Any] = config_dict["graph_dependencies"]

    model_io_cdict: Dict[str, Any] = config_dict["model_io"]

    full_exp_path = (Path(meta_cdict["yeahml_dir"]).joinpath(
        meta_cdict["data_name"]).joinpath(
            meta_cdict["experiment_name"]).joinpath(model_cdict["name"]))
    logger = config_logger(full_exp_path, log_cdict, "build")
    logger.info("-> START building graph")

    try:
        reset_graph_deterministic(meta_cdict["seed"])
    except KeyError:
        reset_graph()

    # g_logger = config_logger(full_exp_path, log_cdict, "graph")

    # configure/build all layers and save in lookup table
    built_nodes = {}
    # {"<layer_name>": {"func": <layer_func>, "out": <output_of_layer>}}
    for name, node in graph_dict.items():

        node_config = get_node_config_by_name(node.name, config_dict)
        if not node_config:
            raise ValueError(
                f"layer {name} can't be found in {node.config_location}")
        blueprint = node_config["object_dict"]

        if node.startpoint:
            if not node.label:
                func = None
                out = _configure_input(name, blueprint)
        else:
            func = _configure_layer(name, blueprint)
            out = None

        # TODO: this is a quick fix. the issue is that a node that is a label
        # does not need to be built as a layer in the graph -- it is only used
        # as a target during training and therefore does not need to be included
        # here
        if not node.label:
            built_nodes[name] = {"out": out, "func": func}

    for group_of_nodes in graph_dependencies:
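        # graph_dependencies yields groups of node names (presumably in
        # dependency order); each node's output is created from the outputs of
        # the nodes listed in its in_name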
        list_of_nodes = list(group_of_nodes)
        for cur_node in list_of_nodes:
            if cur_node:
                if not graph_dict[cur_node].label:
                    if not _is_valid_output(built_nodes[cur_node]["out"]):
                        # create the output (it doesn't exist yet)
                        in_names = graph_dict[cur_node].in_name
                        prev_outputs = []
                        for in_name in in_names:
                            prev_out = built_nodes[in_name]["out"]
                            prev_outputs.append(prev_out)

                        # if only one previous output is present, remove list
                        if len(prev_outputs) == 1:
                            prev_outputs = prev_outputs[0]

                        # connect
                        cur_out = built_nodes[cur_node]["func"](prev_outputs)
                        built_nodes[cur_node]["out"] = cur_out
                    else:
                        pass

    model_input_tensors = []
    for name in model_io_cdict["inputs"]:
        try:
            node_d = built_nodes[name]
        except KeyError:
            raise KeyError(
                f"{name} not found in built nodes when creating inputs")

        try:
            out = node_d["out"]
        except KeyError:
            raise KeyError(
                f"out was not created for {name} when creating tensor inputs")
        model_input_tensors.append(out)
    if not model_input_tensors:
        raise ValueError(f"not model inputs are available")

    model_output_tensors = []
    for name in model_io_cdict["outputs"]:
        try:
            node_d = built_nodes[name]
        except KeyError:
            raise KeyError(
                f"{name} not found in built nodes when creating outputs")

        try:
            out = node_d["out"]
        except KeyError:
            raise KeyError(
                f"out was not created for {name} when creating tensor outputs")
        model_output_tensors.append(out)
    if not model_output_tensors:
        raise ValueError(f"not model outputs are available")

    # ---------------------------------------------

    # TODO: inputs may be more complex than an ordered list
    # TODO: outputs could be a list
    # TODO: right now it is assumed that the last layer defined in the config is the
    # output layer -- this may not be true. named outputs would be better.
    model = tf.keras.Model(
        inputs=model_input_tensors,
        outputs=model_output_tensors,
        name=model_cdict["name"],
    )

    # write meta.json including model hash
    if write_build_information(model_cdict, meta_cdict):
        logger.info("information json file created")

    return model