Code Example #1
def build_optimizer(cfg_solver: dict) -> optimizer_v2.OptimizerV2:
    """Build the optimizer.

    Args:
        cfg_solver: dict, solver subsection of config.

    Returns:
        optimizer_v2.OptimizerV2, tf.keras v2 optimizer.

    Raises:
        TypeError, optimizer not an OptimizerV2.
    """
    path = cfg_solver["optimizer"]["import"]
    params = cfg_solver["optimizer"].get("params", {})
    learning_rate = cfg_solver["optimizer"]["learning_rate"]

    if isinstance(learning_rate, dict):
        lr_cls = import_utils.import_obj_with_search_modules(
            learning_rate["import"], ["tensorflow.keras.optimizers.schedules"])
        lr = lr_cls(**learning_rate["params"])
        if not isinstance(lr, learning_rate_schedule.LearningRateSchedule):
            raise TypeError(
                f"import learning rate: {lr} is not a LearningRateSchedule")
    else:
        lr = learning_rate

    opt_cls = import_utils.import_obj_with_search_modules(
        path, ["tensorflow.keras.optimizers"], True)
    opt = opt_cls(learning_rate=lr, **params)

    if not isinstance(opt, optimizer_v2.OptimizerV2):
        raise TypeError(f"import optimizer: {opt} is not an OptimizerV2")

    return opt
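
A solver config that exercises both branches of build_optimizer could look like the sketch below. The key names ("optimizer", "import", "learning_rate", "params") are taken directly from the function above; the Adam and ExponentialDecay values are only an illustration, resolved against the tensorflow.keras search modules.

# Hypothetical solver config; "import" names resolve against
# tensorflow.keras.optimizers (and .schedules for the learning rate dict).
cfg_solver = {
    "optimizer": {
        "import": "Adam",
        "learning_rate": {
            "import": "ExponentialDecay",
            "params": {
                "initial_learning_rate": 1e-3,
                "decay_steps": 1000,
                "decay_rate": 0.9,
            },
        },
        "params": {"beta_1": 0.9},
    }
}
opt = build_optimizer(cfg_solver)  # returns a tf.keras OptimizerV2 instance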
Code Example #2
def test_import_obj_with_search_modules(python_path, search_modules,
                                        both_cases, expected):
    if expected:
        obj = import_utils.import_obj_with_search_modules(
            python_path, search_modules, both_cases)
        assert obj() == "test"
    else:
        with pytest.raises(ImportError):
            import_utils.import_obj_with_search_modules(
                python_path, search_modules, both_cases)
Code Example #3
def sequential_from_config(layers: List[dict], **kwargs) -> tf.keras.Model:
    """Build a sequential model from a list of layer specifications.
    Supports references to network_params computed inside Transformers by specifying
    {{variable name}}.

    Args:
        layers: list[dict], layer imports.

    Returns:
        tf.keras.Model, network.
    """
    layers = config._render_params(layers, kwargs)
    network = tf.keras.models.Sequential()
    for layer in layers:
        if "import" not in layer:
            raise KeyError(f"layer {layer} missing 'import' key")
        if not layer.keys() <= {"import", "params"}:
            unexpected_keys = set(layer.keys()).difference(
                {"import", "params"})
            raise KeyError(
                f"layer {layer} unexpected key(s): {unexpected_keys}")

        layer_cls = import_utils.import_obj_with_search_modules(
            layer["import"], search_modules=["tensorflow.keras.layers"])
        layer_params = layer.get("params", {})
        network.add(layer_cls(**layer_params))

    return network
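
Each entry in layers must contain only "import" and an optional "params"; anything else raises a KeyError. A minimal sketch, assuming standard tf.keras layer names and a hypothetical num_classes value substituted through the {{...}} syntax by config._render_params:

# Hypothetical layer list; "import" names resolve against tensorflow.keras.layers.
layers = [
    {"import": "Dense", "params": {"units": 64, "activation": "relu"}},
    {"import": "Dropout", "params": {"rate": 0.5}},
    # {{num_classes}} is rendered from kwargs (typically the transformer's
    # network_params) by config._render_params.
    {"import": "Dense", "params": {"units": "{{num_classes}}", "activation": "softmax"}},
]
model = sequential_from_config(layers, num_classes=10)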
Code Example #4
def build_network(cfg_model: dict, transform_params: dict) -> tf.keras.Model:
    """Build the network.

    Args:
        cfg_model: dict, model subsection of config.
        transform_params: dict, params from transformer.

    Returns:
        tf.keras.Model, network.

    Raises:
        TypeError, network not a tf.keras.Model.
    """
    path = cfg_model["network"]["import"]
    network_params = cfg_model["network"].get("params", {})

    net_func = import_utils.import_obj_with_search_modules(path)
    net = net_func(**network_params, **transform_params)

    if not isinstance(net, tf.keras.Model):
        raise TypeError(f"import network: {net} is not a tf.keras.Model")

    return net
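
build_network only requires that the imported callable return a tf.keras.Model, so "import" can be a dotted path to any builder function, for example the sequential_from_config helper above. The dotted path used here is an assumption about where that helper lives, and the transform_params dict is a stand-in for the values a fitted transformer would provide.

# Hypothetical model config; the dotted path to the builder is assumed.
cfg_model = {
    "network": {
        "import": "barrage.model.sequential_from_config",
        "params": {
            "layers": [
                {"import": "Dense", "params": {"units": 32, "activation": "relu"}},
                {"import": "Dense", "params": {"units": "{{num_outputs}}"}},
            ]
        },
    }
}
net = build_network(cfg_model, transform_params={"num_outputs": 1})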
Code Example #5
File: dataset.py  Project: elephantgrass7/barrage
    def __init__(
        self,
        artifact_dir: str,
        cfg_dataset: dict,
        records: Union[pd.DataFrame, core.Records],
        mode: core.RecordMode,
        batch_size: int,
    ):

        if not isinstance(mode, core.RecordMode):
            raise TypeError("mode must be type RecordMode")

        if isinstance(records, pd.DataFrame):
            records.reset_index(drop=True, inplace=True)
            self.records = records.to_dict(orient="records")
        elif all(isinstance(record, dict) for record in records):
            self.records = records
        else:
            raise TypeError(
                "records must be a list of dicts or a pandas DataFrame")

        self.num_records = len(records)
        logger.info(f"Building {mode} dataset with {self.num_records} records")
        self.mode = mode
        self.batch_size = batch_size

        self.seed = cfg_dataset.get("seed")
        np.random.seed(self.seed)

        sample_count = cfg_dataset.get("sample_count")
        if self.mode == core.RecordMode.TRAIN and sample_count is not None:
            self._sample_inds = convert_sample_count_to_inds(
                [record[sample_count] for record in self.records])
        else:
            self._sample_inds = list(range(self.num_records))
        self.shuffle()

        logger.info(f"Creating record loader")
        loader_cls = import_utils.import_obj_with_search_modules(
            cfg_dataset["loader"]["import"], search_modules=SEARCH_MODULES)
        self.loader = loader_cls(
            mode=mode, params=cfg_dataset["loader"].get("params", {}))
        if not isinstance(self.loader, RecordLoader):
            raise TypeError(
                f"loader {self.loader} is not of type RecordLoader")

        logger.info(f"Creating record transformer")
        transformer_cls = import_utils.import_obj_with_search_modules(
            cfg_dataset["transformer"]["import"],
            search_modules=SEARCH_MODULES)
        self.transformer = transformer_cls(
            mode=self.mode,
            loader=self.loader,
            params=cfg_dataset["transformer"].get("params", {}),
        )
        if not isinstance(self.transformer, RecordTransformer):
            raise TypeError(
                f"transformer {self.transformer} is not of type RecordTransformer"
            )

        dataset_dir = os.path.join(artifact_dir, "dataset")
        if self.mode == core.RecordMode.TRAIN:
            logger.info("Creating record augmentor")
            self.augmentor = RecordAugmentor(cfg_dataset["augmentor"])
            logger.info(
                f"Fitting transform: {self.transformer.__class__.__name__}")
            self.transformer.fit(copy.deepcopy(self.records))
            logger.info(
                f"Transformer network params: {self.transformer.network_params}"
            )
            logger.info("Saving transformer")
            self.transformer.save(dataset_dir)
        else:
            logger.info(
                f"Loading transform: {self.transformer.__class__.__name__}")
            self.transformer.load(dataset_dir)
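
Pulling together the keys this constructor reads, a hypothetical cfg_dataset and call could look like the sketch below. The class name RecordDataset, the loader/transformer import names, and their params are placeholders illustrating the structure, not verified project symbols.

import pandas as pd

# Hypothetical dataset config mirroring the keys read in __init__ above; the
# loader/transformer "import" values are resolved against SEARCH_MODULES.
cfg_dataset = {
    "seed": 42,
    "sample_count": None,  # or a column name holding per-record sample counts
    "loader": {"import": "KeySelector", "params": {"inputs": {"x": ["feature"]}}},
    "transformer": {"import": "IdentityTransformer"},
    "augmentor": [],  # only consulted in TRAIN mode
}

records = pd.DataFrame({"feature": [0.1, 0.2], "label": [0, 1]})
dataset = RecordDataset(  # assumed class name for the __init__ shown above
    artifact_dir="artifacts",
    cfg_dataset=cfg_dataset,
    records=records,
    mode=core.RecordMode.TRAIN,
    batch_size=2,
)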