예제 #1
0
파일: run.py 프로젝트: inyukwo1/rat-sql
def main():
    """Command-line entry point: preprocess data, train, or evaluate a model.

    The experiment is described by a Jsonnet file (``exp_config_file``) that
    provides ``model_config`` and optionally ``model_config_args``,
    ``logdir`` and the ``eval_*`` settings read by the ``eval`` mode.
    ``--model_config_args`` is evaluated as Jsonnet and shallow-merged on top
    of the experiment's ``model_config_args`` (CLI values win).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('mode',
                        help="preprocess/train/eval",
                        choices=["preprocess", "train", "eval"])
    parser.add_argument('exp_config_file', help="jsonnet file for experiments")
    parser.add_argument('--model_config_args',
                        help="optional overrides for model config args")
    parser.add_argument('--logdir', help="optional override for logdir")
    parser.add_argument('--ray', action='store_true')
    args = parser.parse_args()

    exp_config = json.loads(_jsonnet.evaluate_file(args.exp_config_file))
    model_config_file = exp_config["model_config"]
    if "model_config_args" in exp_config:
        model_config_args = exp_config["model_config_args"]
        if args.model_config_args is not None:
            model_config_args_json = _jsonnet.evaluate_snippet(
                "", args.model_config_args)
            # Shallow merge: top-level keys from the CLI replace the config's.
            model_config_args.update(json.loads(model_config_args_json))
        # Downstream configs receive the args as a JSON string.
        model_config_args = json.dumps(model_config_args)
    elif args.model_config_args is not None:
        model_config_args = _jsonnet.evaluate_snippet("",
                                                      args.model_config_args)
    else:
        model_config_args = None

    logdir = args.logdir or exp_config["logdir"]

    if args.mode == "preprocess":
        preprocess_config = PreprocessConfig(model_config_file,
                                             model_config_args)
        preprocess.main(preprocess_config)
    elif args.mode == "train":
        # NOTE(review): the worker count is hard-coded to 8; confirm it matches
        # the available devices.
        mp.spawn(train_model,
                 nprocs=8,
                 args=(model_config_file, model_config_args, logdir, args.ray))

    elif args.mode == "eval":
        for step in exp_config["eval_steps"]:
            infer_output_path = f"{exp_config['eval_output']}/{exp_config['eval_name']}-step{step}.infer"
            infer_config = InferConfig(
                model_config_file,
                model_config_args,
                logdir,
                exp_config["eval_section"],
                exp_config["eval_beam_size"],
                infer_output_path,
                step,
                use_heuristic=exp_config["eval_use_heuristic"])
            infer.main(infer_config)

            eval_output_path = f"{exp_config['eval_output']}/{exp_config['eval_name']}-step{step}.eval"
            eval_config = EvalConfig(model_config_file, model_config_args,
                                     logdir, exp_config["eval_section"],
                                     infer_output_path, eval_output_path)
            eval.main(eval_config)

            # Close the report right away instead of leaking one handle per step.
            with open(eval_output_path) as eval_file:
                res_json = json.load(eval_file)
            print(step, res_json['total_scores']['all']['exact'])
예제 #2
0
    def process(code):
        """Prefix ``code`` with ``preamble``, evaluate it as Jsonnet, decode the JSON.

        ``import_callback`` is forwarded only when it is truthy.
        """
        source = "\n".join((preamble, code))
        options = {"import_callback": import_callback} if import_callback else {}
        rendered = _jsonnet.evaluate_snippet("snippet", source, **options)
        return json.loads(rendered)
예제 #3
0
파일: config.py 프로젝트: mbencherif/RSPNet
def get_config(args: Args) -> ConfigTree:
    """Compose the experiment configuration and return it as a ``ConfigTree``.

    The snippet produced by ``config_snippet`` may import pseudo-paths that the
    nested callback resolves:

    * names matched by ``arg_regex``: rendered from ``args.ext_config[index]``;
    * ``__base_config__``: the file at ``args.config``;
    * ``__addition_config__``: ``addition.libsonnet`` beside ``args.config``;
    * anything else: a path joined onto the importing directory.

    The final configuration is logged in HOCON form before being returned.
    """

    def import_callback(base_dir, rel):
        ext_match = arg_regex.match(rel)
        if ext_match is not None:
            ext_index = int(ext_match.group(1))
            return rel, ext_config_template(args.ext_config[ext_index])

        if rel == '__base_config__':
            target = Path(args.config)
        elif rel == '__addition_config__':
            target = Path(args.config).with_name('addition.libsonnet')
        else:
            target = Path(rel)
        resolved = str(base_dir / target)
        with open(resolved) as source_file:
            return resolved, source_file.read()

    composed = evaluate_snippet(
        '__composed_config__',
        config_snippet(len(args.ext_config)),
        import_callback=import_callback,
    )
    cfg = ConfigFactory.from_dict(json.loads(composed))

    logger.info(f'Config = \n{HOCONConverter.to_hocon(cfg)}')

    return cfg
예제 #4
0
    def test_regex_matches_are_initialized_correctly(self):
        """Initializer regexes must reach the intended parameters.

        ``"conv"`` is expected to set every ``conv`` parameter to 5, and
        ``"funky_na.*bi"`` is expected to hit the bias of
        ``linear_1_with_funky_name`` (a match in the middle of the name) and set it to 7.
        """

        class Net(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_1_with_funky_name = torch.nn.Linear(5, 10)
                self.linear_2 = torch.nn.Linear(10, 5)
                self.conv = torch.nn.Conv1d(5, 5, 5)

            def forward(self, inputs):  # pylint: disable=arguments-differ
                pass

        def is_filled_with(tensor, value):
            # True when every element of ``tensor`` equals ``value``.
            return torch.equal(tensor.data, torch.ones(tensor.size()) * value)

        # Make sure we handle regexes properly
        json_params = """{"initializer": [
        ["conv", {"type": "constant", "val": 5}],
        ["funky_na.*bi", {"type": "constant", "val": 7}]
        ]}
        """
        evaluated = _jsonnet.evaluate_snippet("", json_params)
        params = Params(json.loads(evaluated))
        apply_initializers = InitializerApplicator.from_params(params['initializer'])
        net = Net()
        apply_initializers(net)

        for weight in net.conv.parameters():
            assert is_filled_with(weight, 5)

        assert is_filled_with(net.linear_1_with_funky_name.bias, 7)
예제 #5
0
def parse_overrides(serialized_overrides: str) -> Dict[str, Any]:
    """Evaluate a Jsonnet overrides string and return the result via ``unflatten``.

    The mapping from ``_environment_variables()`` (presumably the process
    environment; verify) is exposed to the snippet as Jsonnet ext vars.
    An empty or falsy ``serialized_overrides`` yields ``{}``.
    """
    if not serialized_overrides:
        return {}
    ext_vars = _environment_variables()
    evaluated = evaluate_snippet("", serialized_overrides, ext_vars=ext_vars)
    return unflatten(json.loads(evaluated))
예제 #6
0
 def _eval(self):
     """Render ``self.body`` with Jsonnet and return the decoded JSON value.

     ``self.native_callbacks`` is forwarded to the evaluator. The snippet is
     labelled ``'template'`` in Jsonnet error messages.
     """
     rendered = _jsonnet.evaluate_snippet(
         'template',
         self.body,
         native_callbacks=self.native_callbacks,
     )
     return json.loads(rendered)
예제 #7
0
파일: cli.py 프로젝트: chikin-4x/jsonnet
def render_jsonnet(jsonnet_path=None,
                   functions_path=None,
                   tla_str=None,
                   ext_var=None):
    account_alias = 'fakealias'
    env = 'prod'
    if jsonnet_path:
        jsonnet_path = pathlib.Path(jsonnet_path)
    else:
        jsonnet_path = pathlib.Path('./manifest.jsonnet')
    if functions_path:
        jsonnet_functions_path = pathlib.Path(functions_path)
    else:
        jsonnet_functions_path = jsonnet_path.parent / 'jsonnet_functions.py'
    func_dict = None
    ext_vars = {}
    if ext_vars == {}:
        ext_vars = {
            'environment': env,
            "account_alias": account_alias,
            "nonprod_account_alias": account_alias
        }
    func_dict = get_native_dict(get_file_contents(jsonnet_functions_path))
    jsonnet_contents = get_file_contents(jsonnet_path)
    manifest_json = _jsonnet.evaluate_snippet(
        'manifest.jsonnet',
        jsonnet_contents,
        ext_vars={e.split('=')[0]: e.split('=')[1]
                  for e in ext_var},
        tla_vars={e.split('=')[0]: e.split('=')[1]
                  for e in tla_str},
        native_callbacks=func_dict)
    return json.loads(json.dumps(manifest_json))
예제 #8
0
        def train_func(config, reporter):
            """Trainable: sample hyperparameters, render the config, train once.

            ``config`` (presumably supplied by Ray Tune; verify) is expanded into
            ``HyperparameterSearch``. The sampled values, stringified, are
            overlaid on a copy of ``config`` and exposed to
            ``parameter_file_snippet`` as Jsonnet ext vars.

            Args:
                config: Search-space settings; not modified by this function.
                reporter: Called with ``done=True`` once training finishes.
            """
            # The variable is not guaranteed to be set; a debug log line must
            # not abort the trial with KeyError.
            logger.debug(
                f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

            for package_name in getattr(args, "include_package", ()):
                import_submodules(package_name)

            search_space = HyperparameterSearch(**config)
            sample = search_space.sample()
            # Build the ext vars in a copy rather than mutating the caller's dict.
            ext_vars = dict(config)
            ext_vars.update((k, str(v)) for k, v in sample.items())

            params_dict = json.loads(
                _jsonnet.evaluate_snippet("config",
                                          parameter_file_snippet,
                                          tla_codes={},
                                          ext_vars=ext_vars))
            if args.num_gpus == 0:
                logger.warning("No GPU specified, using CPU.")
                params_dict["trainer"]["cuda_device"] = -1

            if args.cpus_per_trial > 0:
                torch.set_num_threads(args.cpus_per_trial)

            params = Params(params_dict)

            logger.debug(f"AllenNLP Configuration: {params.as_dict()}")

            train_model(params=params, serialization_dir="trial")

            reporter(done=True)
예제 #9
0
파일: config.py 프로젝트: huww98/ml-lab
def get_config(args) -> Dict:
    """Compose the experiment configuration and return it as a plain dict.

    Imports inside the composed snippet are resolved by the nested callback:
    ``arg_regex`` matches render ``args.ext_config[index]``; the pseudo-names
    ``__base_config__`` / ``__addition_config__`` map to ``args.config`` and
    its sibling ``addition.libsonnet``; other names are read from disk
    relative to the importing directory.
    """

    def import_callback(base_dir, rel):
        ext_match = arg_regex.match(rel)
        if ext_match is not None:
            return rel, ext_config_template(args.ext_config[int(ext_match.group(1))])

        if rel == '__base_config__':
            rel_path = Path(args.config)
        elif rel == '__addition_config__':
            rel_path = Path(args.config).with_name('addition.libsonnet')
        else:
            rel_path = Path(rel)
        resolved = str(base_dir / rel_path)
        with open(resolved) as source_file:
            return resolved, source_file.read()

    return json.loads(
        evaluate_snippet(
            '__composed_config__',
            config_snippet(len(args.ext_config)),
            import_callback=import_callback,
        )
    )
예제 #10
0
 def test_evaluate_snippet(self):
     """Evaluating ``self.input_snippet`` with the module-level import and
     native callbacks must produce exactly ``self.expected_str``."""
     rendered = _jsonnet.evaluate_snippet(
         "snippet",
         self.input_snippet,
         native_callbacks=native_callbacks,
         import_callback=import_callback,
     )
     self.assertEqual(rendered, self.expected_str)
예제 #11
0
def main():
    """Command-line entry point: preprocess data, train, or run inference.

    The experiment Jsonnet file supplies ``model_config`` and optionally
    ``model_config_args``, ``logdir`` and the ``test_*`` settings read by the
    ``infer`` mode. ``--model_config_args`` is evaluated as Jsonnet and
    shallow-merged over the experiment's ``model_config_args``.

    Raises:
        ValueError: for an unsupported mode (normally rejected earlier by argparse).
    """
    parser = argparse.ArgumentParser()
    # The handled modes are preprocess/train/infer; the old help text wrongly
    # advertised "eval", which always ended in ValueError.
    parser.add_argument('mode',
                        help="preprocess/train/infer",
                        choices=["preprocess", "train", "infer"])
    parser.add_argument('exp_config_file', help="jsonnet file for experiments")
    parser.add_argument('--model_config_args',
                        help="optional overrides for model config args")
    parser.add_argument('--logdir', help="optional override for logdir")
    args = parser.parse_args()

    exp_config = json.loads(_jsonnet.evaluate_file(args.exp_config_file))
    model_config_file = exp_config["model_config"]
    if "model_config_args" in exp_config:
        model_config_args = exp_config["model_config_args"]
        if args.model_config_args is not None:
            model_config_args_json = _jsonnet.evaluate_snippet(
                "", args.model_config_args)
            # Shallow merge: top-level keys from the CLI replace the config's.
            model_config_args.update(json.loads(model_config_args_json))
        model_config_args = json.dumps(model_config_args)
    elif args.model_config_args is not None:
        model_config_args = _jsonnet.evaluate_snippet("",
                                                      args.model_config_args)
    else:
        model_config_args = None

    log_dir = args.logdir or exp_config["logdir"]

    if args.mode == "preprocess":
        preprocess_config = PreprocessConfig(model_config_file,
                                             model_config_args)
        preprocess.main(preprocess_config)
    elif args.mode == "train":
        train_config = TrainConfig(model_config_file, model_config_args,
                                   log_dir)
        train.main(train_config)
    elif args.mode == "infer":
        infer_output_path = f"{exp_config['test_output']}/{exp_config['test_name']}.jsonl"
        infer_config = InferConfig(
            model_config_file,
            model_config_args,
            log_dir,
            exp_config["test_section"],
            infer_output_path,
        )
        infer.main(infer_config)
    else:
        raise ValueError(f"command {args.mode} is not supported")
예제 #12
0
파일: utils.py 프로젝트: Lord-Haji/pyMon
def load_js_obj_literal(j):
    """Terrible hack: parse the first JS object literal in ``j`` via Jsonnet.

    Everything before the first ``{`` is dropped. Then all newlines, tabs and
    semicolons are removed, and any ``//...`` run up to the next ``{`` is cut
    out. Because newlines are already gone, a ``//`` comment otherwise
    swallows the rest of the text. The result is evaluated as Jsonnet and
    decoded.

    Raises:
        ValueError: if ``j`` contains no ``{``.
    """
    body = j[j.index('{'):]
    for junk in ('\n', '\t', ';'):
        body = body.replace(junk, '')
    body = re.sub(r'//.*?{', r'{', body)
    return json.loads(_jsonnet.evaluate_snippet('snippet', body))
예제 #13
0
def load_config_from_json(config_file: str) -> Dict:
    """Read ``config_file``, evaluate it as Jsonnet and return the decoded JSON.

    Raises:
        ValueError: if ``config_file`` is not an existing regular file.
    """
    if os.path.isfile(config_file):
        with open(config_file, "r") as handle:
            source = handle.read()
        return json.loads(_jsonnet.evaluate_snippet("", source))
    raise ValueError("given configuration file doesn't exist")
예제 #14
0
    def __load_jsonnet(self, params_path):
        """Read the Jsonnet file at ``params_path`` and return its decoded JSON value."""
        with open(params_path) as params_file:
            source = params_file.read()
        rendered = _jsonnet.evaluate_snippet("snippet", source)
        return json.loads(rendered)
예제 #15
0
파일: jsonnet.py 프로젝트: brettviren/moo
def loads(jtext, paths=(), **kwds):
    """Evaluate Jsonnet source text and return the decoded JSON value.

    ``paths`` go through ``clean_paths`` and into an ``ImportCallback``
    (presumably the import search path; verify). Extra keyword arguments are
    forwarded to ``evaluate_snippet``. The snippet is labelled ``"<stdin>"``.
    """
    callback = ImportCallback(clean_paths(paths))
    rendered = evaluate_snippet("<stdin>", jtext, import_callback=callback, **kwds)
    return json.loads(rendered)
예제 #16
0
 def test_evaluate_snippet(self):
     """``evaluate_snippet`` with import and native callbacks must reproduce
     ``self.expected_str`` exactly for ``self.input_snippet``."""
     actual = _jsonnet.evaluate_snippet(
         "snippet",
         self.input_snippet,
         native_callbacks=native_callbacks,
         import_callback=import_callback,
     )
     self.assertEqual(actual, self.expected_str)
예제 #17
0
def train(options):
    """Render the training config to ``TMP_FILENAME`` and launch ``TRAIN_CMD``.

    The Jsonnet text from ``prepare_config(options)`` is evaluated and written
    as indented JSON. ``TRAIN_CMD`` is then formatted with that file and the
    serialization directory and run via ``os.system``.
    """
    rendered = _jsonnet.evaluate_snippet("snippet", prepare_config(options))
    config = json.loads(rendered)
    # The override flag in allennlp was finicky, so the whole config goes
    # through a temporary file instead.
    with open(TMP_FILENAME, "w") as tmp_file:
        json.dump(config, tmp_file, indent=2)
    command = TRAIN_CMD.format(
        directory=make_serialization_dirname(options),
        config=TMP_FILENAME,
    )
    # NOTE(review): os.system runs `command` through the shell. If `options`
    # can carry user-controlled text, quote it (shlex.quote) or switch to
    # subprocess.run with an argument list.
    os.system(command)
예제 #18
0
def parse_overrides(serialized_overrides):
    """Evaluate a Jsonnet overrides string and return it passed through ``unflatten``.

    A copy of ``os.environ`` is exposed to the snippet as ext vars. An empty or
    falsy input yields ``{}``.
    """
    if not serialized_overrides:
        return {}
    environment = dict(os.environ)
    rendered = evaluate_snippet(u"", serialized_overrides, ext_vars=environment)
    return unflatten(json.loads(rendered))
예제 #19
0
 def render_jsonnet(self, name, s, object_pairs_hook=OrderedDict):
     """Evaluate Jsonnet source ``s`` and decode it with ``object_pairs_hook``.

     ``name`` labels the snippet in Jsonnet errors. The snippet sees the ext var
     ``VERIFY_NAMESPACES`` as ``'1'`` or ``'0'``, depending on
     ``self.verify_namespace``.
     """
     rendered = _jsonnet.evaluate_snippet(
         name,
         s,
         import_callback=self.import_callback,
         native_callbacks=self.native_callbacks,
         ext_vars={'VERIFY_NAMESPACES': '1' if self.verify_namespace else '0'},
     )
     return json.loads(rendered, object_pairs_hook=object_pairs_hook)
예제 #20
0
파일: params.py 프로젝트: himkt/allennlp
def parse_overrides(
        serialized_overrides: str,
        ext_vars: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
    """Evaluate a Jsonnet overrides string and return the decoded JSON.

    Ext vars are the mapping from ``_environment_variables()`` overlaid with
    ``ext_vars``; caller-supplied entries win. An empty or falsy
    ``serialized_overrides`` yields ``{}``.
    """
    if not serialized_overrides:
        return {}
    merged_vars = dict(_environment_variables())
    merged_vars.update(ext_vars or {})
    return json.loads(
        evaluate_snippet("", serialized_overrides, ext_vars=merged_vars))
예제 #21
0
def parse_overrides(serialized_overrides: str) -> Dict[str, Any]:
    """Evaluate a Jsonnet overrides string and return it passed through ``unflatten``.

    A snapshot of ``os.environ`` is exposed to the snippet as ext vars.
    An empty or falsy input yields ``{}``.
    """
    if not serialized_overrides:
        return {}
    env_snapshot = dict(os.environ)
    rendered = _jsonnet.evaluate_snippet("", serialized_overrides, ext_vars=env_snapshot)
    return unflatten(json.loads(rendered))
예제 #22
0
def jsonnet_loads(jsonnet_str, ext_vars=None):
    """Evaluate a Jsonnet string and return the decoded JSON value.

    :param jsonnet_str: Jsonnet source text.
    :param ext_vars: Optional external variables, e.g. ``{'SOME_PARAM': 'AI2'}``,
        readable in the snippet as ``std.extVar("SOME_PARAM")``.
    :return: The decoded JSON value.
    """
    rendered = _jsonnet.evaluate_snippet("snippet", jsonnet_str, ext_vars=ext_vars)
    return json.loads(rendered)
예제 #23
0
def getBricksFromJsFileContent(jsContent):
    """Evaluate the literal starting at the first ``{`` in JS source and list its items.

    Surrounding whitespace and ``;`` are stripped, and backticks become single
    quotes before Jsonnet evaluation. The result is the list obtained by
    iterating the decoded value; for a JSON object that is its keys.

    Raises:
        ValueError: if ``jsContent`` has no ``{``.
    """
    literal = jsContent[jsContent.index('{'):].strip().strip(';')
    literal = literal.replace('`', "'")
    decoded = json.loads(_jsonnet.evaluate_snippet('snippet', literal))
    return list(decoded)
예제 #24
0
    def get_config(self, spread_feed: dict):
        """Reads the JSON config file from disk and returns it as a Python object.

        The file at ``self.filename`` is evaluated as Jsonnet, with imports
        resolved by ``self._spread_feed_import_callback(spread_feed)``. The
        parsed result is cached on the instance. It is returned without
        re-reading the file while both the file's modification time and
        ``spread_feed`` compare equal to the cached ones.

        Args:
            spread_feed: Passed to the import callback; also part of the cache key.

        Returns:
            Current configuration as a `dict` or `list` object.
        """
        # NOTE(review): `assert` is skipped under `python -O`; raise TypeError
        # instead if this check must always run.
        assert (isinstance(spread_feed, dict))

        mtime = os.path.getmtime(self.filename)

        # If the modification time has not changed since the last time we have read the file,
        # we return the last content without opening and parsing it. It saves us around ~ 30ms.
        #
        # Ultimately something like `watchdog` (<https://pythonhosted.org/watchdog/index.html>)
        # should be used to watch the filesystem changes asynchronously.
        if self._config is not None and self._mtime is not None:
            if mtime == self._mtime and spread_feed == self._spread_feed:
                return self._config

        with open(self.filename) as data_file:
            content_file = data_file.read()
            content_config = _jsonnet.evaluate_snippet(
                "snippet",
                content_file,
                ext_vars={},
                import_callback=self._spread_feed_import_callback(spread_feed))
            result = json.loads(content_config)

            # Report if file has been newly loaded or reloaded.
            # checksum_file covers the raw file text. checksum_config covers the
            # evaluated output, which can also change when only the
            # spread-feed-dependent imported content changes.
            checksum_file = zlib.crc32(content_file.encode('utf-8'))
            checksum_config = zlib.crc32(content_config.encode('utf-8'))
            if self._checksum_file is None:
                self.logger.info(
                    f"Loaded configuration from '{self.filename}'")
                self.logger.debug(f"Config file is: " +
                                  json.dumps(result, indent=4))
            elif self._checksum_file != checksum_file:
                self.logger.info(
                    f"Reloaded configuration from '{self.filename}'")
                self.logger.debug(f"Reloaded config file is: " +
                                  json.dumps(result, indent=4))
            elif self._checksum_config != checksum_config:
                self.logger.debug(
                    f"Parsed configuration from '{self.filename}'")
                self.logger.debug(f"Parsed config file is: " +
                                  json.dumps(result, indent=4))

            self._checksum_file = checksum_file
            self._checksum_config = checksum_config
            self._config = result
            self._mtime = mtime
            # NOTE(review): this caches a reference to the caller's dict, not a
            # copy. If a caller mutates the same dict in place and passes it
            # again, the cache check compares the dict with itself and can
            # return a stale config.
            self._spread_feed = spread_feed

            return result
예제 #25
0
파일: config.py 프로젝트: ngs-mstb/micgent
    def __call__(self, vars=None):
        """Evaluate ``self.config_str`` as Jsonnet and return the decoded JSON.

        Ext vars are ``self.vars_default`` overlaid with ``vars``, then passed
        through ``_jsonnet_vars_convert``.

        Args:
            vars: Optional mapping of ext-var overrides. ``None`` (the default)
                means no overrides. It replaces the former shared mutable
                ``{}`` default.
        """
        import _jsonnet
        import json

        all_vars = self.vars_default.copy()
        if vars is not None:
            all_vars.update(vars)
        all_vars = _jsonnet_vars_convert(all_vars)
        # self.config_file here is needed only to annotate stacktraces in case of errors
        return json.loads(
            _jsonnet.evaluate_snippet(self.config_file,
                                      self.config_str,
                                      ext_vars=all_vars))
예제 #26
0
    def render_jsonnet(self, manifeststr, tla_codes=None):
        """Evaluate a Jsonnet manifest and return the decoded JSON.

        On a Jsonnet ``RuntimeError``, prints ``tla_codes`` and the manifest
        lines that do not start with optional spaces plus ``#``, then re-raises.
        Uses ``print()`` calls, since the former Python 2 ``print`` statements
        are a SyntaxError on Python 3.
        """
        try:
            # Very large GC thresholds (presumably to avoid frequent collections
            # in the Jsonnet VM; verify intent).
            json_str = _jsonnet.evaluate_snippet("snippet", manifeststr,
                                                 import_callback=self.import_callback,
                                                 native_callbacks=filters.jsonnet_callbacks(), tla_codes=tla_codes,
                                                 gc_min_objects=9999999, gc_growth_trigger=9999999)

        except RuntimeError:
            print("tla_codes: %s" % (str(tla_codes)))
            # NOTE(review): numbering counts only the kept lines, from 0, so it
            # may not line up with line numbers in the Jsonnet error message.
            print("\n".join(["%s %s" % (i, line) for i, line in
                             enumerate([l for l in manifeststr.split("\n") if re.match(r"^ *#", l) is None])]))
            raise
        return json.loads(json_str)
def process_cm_data(data, ext_libs=None, user_args=None):
    """Processes data field from jsonnet config map.

    Iterates through jsonnet files in configMap (.libsonnet files first)
    and generates json data. ``.libsonnet`` entries are saved to a temporary
    ``./libsonnets`` folder, which is removed afterwards whether or not
    evaluation succeeds.

    Args:
        data (dict): Data from config map labeled as jsonnet code.
        ext_libs (:obj:`list of str`, optional): List of paths to
            external jsonnet libs. ``None`` means none.
        user_args (:obj:`dict`, optional): Keyword arguments to jsonnet build
            function. ``None`` means none.

    Returns:
        list of (str, dict): Generated json data.

    Raises:
        JsonnetConfigMapError: Raised if jsonnet evaluation fails.
    """
    # Fresh containers per call instead of shared mutable defaults.
    ext_libs = [] if ext_libs is None else ext_libs
    user_args = {} if user_args is None else user_args
    libsonnet_folder = "./libsonnets"
    jsons = []

    def _extension(key):
        # Real file extension. The old `key.split(".")[1]` raised IndexError for
        # keys without a dot and picked the wrong part for names like "a.b.jsonnet".
        return os.path.splitext(key)[1]

    try:
        # Sort by extension: ".libsonnet" > ".jsonnet", so reverse order puts
        # the library files first.
        for data_key in sorted(data, key=_extension, reverse=True):
            if _extension(data_key) == ".libsonnet":
                utils.save_text_to_file(libsonnet_folder, data_key, data[data_key])
                continue

            try:
                json_ = _jsonnet.evaluate_snippet(data_key,
                                                  data[data_key],
                                                  jpathdir=ext_libs,
                                                  **user_args)
            except RuntimeError as e:
                log.error(f"{data_key} is not a valid jsonnet, raised error: {e}")
                raise JsonnetConfigMapError from e

            json_filename = utils.replace_extension(data_key, "json")
            jsons.append((json_filename, json.loads(json_)))
    finally:
        # Clean up on success and on any failure, not only RuntimeError.
        if os.path.exists(libsonnet_folder):
            utils.remove_folder(libsonnet_folder)

    return jsons
예제 #28
0
    def render_jsonnet(self, manifeststr, tla_codes=None):
        """Evaluate a Jsonnet manifest and return the decoded JSON.

        On a Jsonnet ``RuntimeError``, prints ``tla_codes`` and a numbered
        listing of the manifest lines that are not ``#`` comments (optionally
        space-indented), then re-raises the error.
        """
        try:
            json_str = _jsonnet.evaluate_snippet(
                "snippet",
                manifeststr,
                import_callback=self.import_callback,
                native_callbacks=filters.jsonnet_callbacks(),
                tla_codes=tla_codes,
            )
        except RuntimeError as e:
            print("tla_codes: %s" % (str(tla_codes)))
            kept_lines = [l for l in manifeststr.split("\n") if re.match(r"^ *#", l) is None]
            numbered = ["%s %s" % (i, line) for i, line in enumerate(kept_lines)]
            print("\n".join(numbered))
            raise e
        return json.loads(json_str)
def evaluate_jsonnet_build_annotations(annotations):
    """Evaluates jsonnet build annotations.

    Each value is parsed as a Python literal with ``ast.literal_eval``. It is
    then probed by passing it, under its annotation key, as a keyword argument
    to a trivial ``_jsonnet.evaluate_snippet`` call. Entries that fail either
    step are logged and skipped.

    Args:
        annotations (dict of str: str):  Annotations from jsonnet configmap.

    Returns:
        dict: Valid jsonnet build arguments with evaluated values.
    """
    evaluated_args = {}
    for key, value in annotations.items():
        try:
            # literal_eval raises ValueError/SyntaxError on malformed text and
            # TypeError for e.g. unhashable dict keys. These used to escape
            # (the old JSONDecodeError handler could never trigger).
            evaluated_arg = ast.literal_eval(value)
        except (ValueError, SyntaxError, TypeError) as e:
            log.error(
                f"Evaluation of build argument {key} from annotations failed,"
                f" error: {e}")
            continue
        try:
            # Probe: unknown keyword names or wrongly typed values are rejected here.
            _jsonnet.evaluate_snippet("dummy", "{}", **{key: evaluated_arg})
        except (TypeError, RuntimeError) as e:
            log.error(
                f"Build argument from annotations {key} is invalid, error: {e}"
            )
            continue
        evaluated_args[key] = evaluated_arg
    return evaluated_args
예제 #30
0
        def train_func(config, reporter):
            """Trainable: sample hyperparameters, render the config, train once.

            Each ``config`` entry is JSON-encoded and passed as a Jsonnet
            top-level argument (``tla_codes``). Values sampled from the
            configured search environment are stringified and exposed both as
            Jsonnet ext vars and as process environment variables.

            Args:
                config: Trial settings (presumably from Ray Tune; verify).
                reporter: Called with ``done=True`` once training finishes.
            """
            # The variable is not guaranteed to be set; don't fail on a debug log.
            logger.debug(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

            for package_name in getattr(run_args, "include_package", ()):
                import_submodules(package_name)

            run_parameters = {k: json.dumps(v) for k, v in config.items()}
            search_config = SEARCH_ENVIRONMENTS[default_args.search_config]
            search_space = HyperparameterSearch(**search_config)
            sample = search_space.sample()
            logger.info(f"Hyperparameter Configuration: {sample}")

            for k, v in sample.items():
                sample[k] = str(v)
                os.environ[k] = str(v)

            params_dict = json.loads(
                _jsonnet.evaluate_snippet(
                    "config", parameter_file_snippet, tla_codes=run_parameters, ext_vars=sample
                )
            )
            if default_args.num_gpus == 0:
                logger.warning("No GPU specified, using CPU.")
                params_dict["trainer"]["cuda_device"] = -1

            # NOTE: rewriting relative train/validation data paths to absolute
            # ones is currently disabled. Ray workers do not share the driver's
            # working directory, so relative paths resolve against the worker's
            # cwd.

            params = Params(params_dict)

            logger.debug(f"AllenNLP Configuration: {params.as_dict()}")

            train_model(params=params, serialization_dir="./trial/")

            reporter(done=True)
예제 #31
0
    def test_regex_match_prevention_prevents_and_overrides(self):
        """``prevent_regexes`` must shield matching parameters from ``regexes``.

        ``linear_1``/``linear_2`` must be set to 10. The ``*_transfer`` and
        ``pretrained_*`` modules also match the initializer regexes but must
        keep values other than 10.
        """

        class Net(torch.nn.Module):
            def __init__(self):
                super().__init__()
                self.linear_1 = torch.nn.Linear(5, 10)
                self.linear_2 = torch.nn.Linear(10, 5)
                # typical actual usage: modules loaded from allennlp.model.load(..)
                self.linear_3_transfer = torch.nn.Linear(5, 10)
                self.linear_4_transfer = torch.nn.Linear(10, 5)
                self.pretrained_conv = torch.nn.Conv1d(5, 5, 5)

            def forward(self, inputs):
                pass

        def is_all_tens(parameter):
            # True when every element of ``parameter`` equals 10.
            return torch.equal(parameter.data, torch.ones(parameter.size()) * 10)

        json_params = """{"initializer": {
        "regexes": [
            [".*linear.*", {"type": "constant", "val": 10}],
            [".*conv.*", {"type": "constant", "val": 10}]
            ],
        "prevent_regexes": [".*_transfer.*", ".*pretrained.*"]
        }}
        """
        evaluated = _jsonnet.evaluate_snippet("", json_params)
        params = Params(json.loads(evaluated))
        apply_initializers = InitializerApplicator.from_params(
            params=params["initializer"])
        net = Net()
        apply_initializers(net)

        for module in (net.linear_1, net.linear_2):
            for parameter in module.parameters():
                assert is_all_tens(parameter)

        protected = (
            net.linear_3_transfer,
            net.linear_4_transfer,
            net.pretrained_conv,
        )
        for module in protected:
            for parameter in module.parameters():
                assert not is_all_tens(parameter)
예제 #32
0
        def train_func(config, reporter):
            """Trainable: render the parameter file, apply overrides, train once.

            Each ``config`` entry is JSON-encoded and passed as a Jsonnet
            top-level argument. ``run_args.overrides`` (via ``parse_overrides``)
            takes precedence over the rendered file. Relative data paths are
            made absolute against ``default_args.cwd``.

            Args:
                config: Trial settings (presumably from Ray Tune; verify).
                reporter: Called with ``done=True`` once training finishes.
            """
            # The variable is not guaranteed to be set; don't fail on a debug log.
            logger.debug(
                f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES')}")

            for package_name in getattr(run_args, "include_package", ()):
                import_submodules(package_name)

            run_parameters = {k: json.dumps(v) for k, v in config.items()}

            file_dict = json.loads(
                _jsonnet.evaluate_snippet("config",
                                          parameter_file_snippet,
                                          tla_codes=run_parameters))
            if default_args.num_gpus == 0:
                logger.warning("No GPU specified, using CPU.")
                file_dict["trainer"]["cuda_device"] = -1

            overrides_dict = parse_overrides(run_args.overrides)

            params_dict = with_fallback(preferred=overrides_dict,
                                        fallback=file_dict)

            # Make sure path is absolute (as Ray workers do not use the same working dir)
            train_data_path = params_dict["train_data_path"]
            validation_data_path = params_dict.get("validation_data_path")

            if not os.path.isabs(train_data_path):
                params_dict["train_data_path"] = os.path.abspath(
                    os.path.join(default_args.cwd, train_data_path))

            if validation_data_path and not os.path.isabs(
                    validation_data_path):
                params_dict["validation_data_path"] = os.path.abspath(
                    os.path.join(default_args.cwd, validation_data_path))

            params = Params(params_dict)

            logger.debug(f"AllenNLP Configuration: {params.as_dict()}")

            train_model(params=params, serialization_dir="./trial/")

            reporter(done=True)
예제 #33
0
def json_merge_patch(target_json: str, patch_json: str) -> str:
    """Apply ``patch_json`` to ``target_json`` using JSON Merge Patch (RFC 7396).

    Both arguments are spliced verbatim into a Jsonnet program that calls
    ``std.mergePatch(target, patch)``, so each must be a valid Jsonnet
    expression (any JSON text is). Spec: https://tools.ietf.org/html/rfc7396.

    Args:
        target_json: the json to be overwritten by the patch json.
        patch_json: the json used to overwrite the target json.

    Returns:
        The merged document as a JSON string.
    """
    template = """local target = {target_json};
                     local patch = {patch_json};
                     std.mergePatch(target, patch)"""
    program: str = template.format(target_json=target_json, patch_json=patch_json)
    return _jsonnet.evaluate_snippet("snippet", program)
    def get_config(self, spread_feed: dict):
        """Reads the JSON config file from disk and returns it as a Python object.

        The file at ``self.filename`` is evaluated as Jsonnet, with imports
        resolved by ``self._spread_feed_import_callback(spread_feed)``. The
        result is cached. It is returned unchanged while the file's
        modification time and ``spread_feed`` both compare equal to the cached
        values.

        Args:
            spread_feed: Passed to the import callback; also part of the cache key.

        Returns:
            Current configuration as a `dict` or `list` object.
        """
        # NOTE(review): `assert` is skipped under `python -O`.
        assert(isinstance(spread_feed, dict))

        mtime = os.path.getmtime(self.filename)

        # If the modification time has not changed since the last time we have read the file,
        # we return the last content without opening and parsing it. It saves us around ~ 30ms.
        #
        # Ultimately something like `watchdog` (<https://pythonhosted.org/watchdog/index.html>)
        # should be used to watch the filesystem changes asynchronously.
        if self._config is not None and self._mtime is not None:
            if mtime == self._mtime and spread_feed == self._spread_feed:
                return self._config

        with open(self.filename) as data_file:
            content_file = data_file.read()
            content_config = _jsonnet.evaluate_snippet("snippet", content_file, ext_vars={},
                                                       import_callback=self._spread_feed_import_callback(spread_feed))
            result = json.loads(content_config)

            # Report if file has been newly loaded or reloaded.
            # The checksum covers the *evaluated* output, so a change in the
            # spread-feed-dependent imports is also reported as a reload.
            checksum = zlib.crc32(content_config.encode('utf-8'))
            if self._checksum is None:
                self.logger.info(f"Loaded configuration from '{self.filename}'")
                self.logger.debug(f"Config file is: " + json.dumps(result, indent=4))
            elif self._checksum != checksum:
                self.logger.info(f"Reloaded configuration from '{self.filename}'")
                self.logger.debug(f"Reloaded config file is: " + json.dumps(result, indent=4))

            self._checksum = checksum
            self._config = result
            self._mtime = mtime
            # NOTE(review): caches a reference to the caller's dict, not a copy.
            # An in-place mutation by the caller would make the cache check
            # above compare the dict with itself.
            self._spread_feed = spread_feed

            return result
예제 #35
0
def parse_overrides(serialized_overrides: str) -> Dict[str, Any]:
    """Evaluate a Jsonnet overrides snippet and return the overrides as a dict.

    Every variable of the current process environment is passed to the snippet
    as a Jsonnet external variable, so the snippet can read it with
    ``std.extVar``. The JSON produced by the snippet is decoded and handed to
    ``unflatten``. Presumably ``unflatten`` expands dotted keys into nested
    dicts; confirm against its definition.

    Args:
        serialized_overrides: Jsonnet source text. Any falsy value (``None``,
            ``""``) means "no overrides".

    Returns:
        The unflattened overrides, or an empty dict when there are none.
    """
    if not serialized_overrides:
        return {}
    environment_vars = dict(os.environ)
    rendered = evaluate_snippet("", serialized_overrides, ext_vars=environment_vars)
    return unflatten(json.loads(rendered))
예제 #36
0
import _jsonnet

# The script takes exactly one command-line argument: the Jsonnet snippet.
if len(sys.argv[1:]) != 1:
    raise Exception("Usage: <snippet>")

def try_path(dir, rel):
    """Resolve an import path and read the file if it exists.

    ``rel`` is used unchanged when it starts with ``/``. Otherwise it is
    appended directly to ``dir``. No separator is inserted, so ``dir`` must
    carry its own trailing ``/`` for the result to be a valid path.

    Returns:
        ``(full_path, content)``. ``content`` is the file's text, or ``None``
        when ``full_path`` is not an existing regular file.

    Raises:
        RuntimeError: if ``rel`` is empty or the resolved path ends with ``/``.
    """
    if not rel:
        raise RuntimeError('Got invalid filename (empty string).')
    full_path = rel if rel[0] == '/' else dir + rel
    if full_path[-1] == '/':
        raise RuntimeError('Attempted to import a directory')

    if os.path.isfile(full_path):
        with open(full_path) as handle:
            text = handle.read()
        return full_path, text
    return full_path, None


def import_callback(dir, rel):
    """Jsonnet import callback: return ``(full_path, content)`` for ``rel``.

    Path resolution and reading are delegated to ``try_path``.

    Raises:
        RuntimeError: if the resolved path is not an existing file, or if
            ``try_path`` rejects the path.
    """
    full_path, content = try_path(dir, rel)
    # Compare against None explicitly. An existing but empty file yields "",
    # which must be returned (e.g. for `importstr`) rather than reported missing.
    if content is not None:
        return full_path, content
    raise RuntimeError('File not found')

# Evaluate the command-line snippet (imports resolved by import_callback) and print the JSON.
sys.stdout.write(
    _jsonnet.evaluate_snippet(
        "snippet",
        sys.argv[1],
        import_callback=import_callback,
    )
)
예제 #37
0
    if full_path[-1] == '/':
        raise RuntimeError('Attempted to import a directory')

    if not os.path.isfile(full_path):
        return full_path, None
    with open(full_path) as f:
        return full_path, f.read()


def import_callback(dir, rel):
    """Jsonnet import callback: return ``(full_path, content)`` for ``rel``.

    Path resolution and reading are delegated to ``try_path``.

    Raises:
        RuntimeError: if the resolved path is not an existing file, or if
            ``try_path`` rejects the path.
    """
    full_path, content = try_path(dir, rel)
    # Compare against None explicitly. An existing but empty file yields "",
    # which must be returned (e.g. for `importstr`) rather than reported missing.
    if content is not None:
        return full_path, content
    raise RuntimeError('File not found')

# Test native extensions
def concat(a, b):
    """Native callback exposed to Jsonnet: return ``a + b``."""
    joined = a + b
    return joined

# Maps each native function name to (parameter names, Python callable).
native_callbacks = dict(concat=(('a', 'b'), concat))

# Evaluate the command-line snippet with file imports resolved by import_callback
# and the functions in native_callbacks registered as native extensions.
json_str = _jsonnet.evaluate_snippet("snippet", sys.argv[1],
                                     import_callback=import_callback,
                                     native_callbacks=native_callbacks)
sys.stdout.write(json_str)
예제 #38
0
# Copyright 2014 Google Inc. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import _jsonnet

# Exactly one command-line argument is expected: the Jsonnet snippet.
if len(sys.argv[1:]) != 1:
    raise Exception("Usage: <snippet>")

# Evaluate it with no import callback or external variables and print the JSON.
sys.stdout.write(
    _jsonnet.evaluate_snippet("snippet", sys.argv[1])
)
예제 #39
0
    def run(self, tmp=None, task_vars=None):
        """Render a Jsonnet template locally and install the result on the remote host.

        The template named by the task's ``src`` argument is used, or the first
        match from ``first_available_file``. It is evaluated with
        ``_jsonnet.evaluate_snippet``, receiving ``cfg`` as a top-level argument
        that holds the task's ``vars`` (without ``hostvars``) encoded as JSON.

        If the remote checksum is ``'1'``, or ``force`` is set and the checksums
        differ, the rendered text is transferred and installed with the ``copy``
        module. Presumably ``'1'`` is Ansible's "remote file missing" sentinel;
        confirm. Otherwise the ``file`` module is run with the task's arguments
        and ``src`` cleared.

        Args:
            tmp: Remote temporary directory to use. One is created when falsy.
            task_vars: Variables available to the task. Defaults to an empty dict.

        Returns:
            The result dict. On error it contains ``failed=True`` and a ``msg``.
        """
        if task_vars is None:
            task_vars = dict()

        result = super(ActionModule, self).run(tmp, task_vars)

        source = self._task.args.get('src', None)
        dest = self._task.args.get('dest', None)
        faf = self._task.first_available_file
        force = boolean(self._task.args.get('force', True))
        state = self._task.args.get('state', None)

        if state is not None:
            result['failed'] = True
            result['msg'] = "'state' cannot be specified on a template"
            return result
        # A source is required unless first_available_file supplies candidates.
        elif (source is None and not faf) or dest is None:
            result['failed'] = True
            result['msg'] = "src and dest are required"
            return result

        if faf:
            # 'templates' is an argument of _get_first_available_file. dict.get()
            # accepts at most two arguments and would raise TypeError.
            source = self._get_first_available_file(faf, task_vars.get('_original_file', None), 'templates')
            if source is None:
                result['failed'] = True
                result['msg'] = "could not find src in first_available_file list"
                return result
        else:
            if self._task._role is not None:
                source = self._loader.path_dwim_relative(self._task._role._role_path, 'templates', source)
            else:
                source = self._loader.path_dwim_relative(self._loader.get_basedir(), 'templates', source)

        # Expand any user home dir specification
        dest = self._remote_expand_user(dest)

        directory_prepended = False
        if dest.endswith(os.sep):
            directory_prepended = True
            base = os.path.basename(source)
            dest = os.path.join(dest, base)

        # template the source data locally & get ready to transfer
        try:
            with open(source, 'r') as f:
                template_data = to_unicode(f.read())

            temp_vars = task_vars.copy()
            if 'hostvars' in temp_vars['vars']:
                # task_vars.copy() is shallow. Rebind 'vars' to a filtered copy
                # instead of deleting in place, which would also remove 'hostvars'
                # from the caller's task_vars['vars'].
                temp_vars['vars'] = {k: v for k, v in temp_vars['vars'].items() if k != 'hostvars'}

            # Directories handed to import_callback(). Presumably it resolves Jsonnet
            # imports against them (defined elsewhere; confirm). They are the
            # loader's basedir, the template's directory and, inside a role, the
            # role path plus any configured roles paths.
            searchpath = [self._loader._basedir, os.path.dirname(source)]
            if self._task._role is not None:
                if C.DEFAULT_ROLES_PATH:
                    searchpath[:0] = C.DEFAULT_ROLES_PATH
                searchpath.insert(1, self._task._role._role_path)

            resultant = _jsonnet.evaluate_snippet(
                source,
                template_data,
                tla_codes={'cfg': json.dumps(temp_vars['vars'])},
                import_callback=import_callback(searchpath),
            )
        except Exception as e:
            # Report any read/evaluation failure as a task failure rather than crashing.
            result['failed'] = True
            result['msg'] = type(e).__name__ + ": " + str(e)
            return result

        remote_user = task_vars.get('ansible_ssh_user') or self._play_context.remote_user
        if not tmp:
            tmp = self._make_tmp_path(remote_user)
            self._cleanup_remote_tmp = True

        local_checksum = checksum_s(resultant)
        remote_checksum = self.get_checksum(dest, task_vars, not directory_prepended, source=source, tmp=tmp)
        if isinstance(remote_checksum, dict):
            # Error from remote_checksum is a dict.  Valid return is a str
            result.update(remote_checksum)
            # Clean up the temp dir on this early exit too, as the normal path does below.
            self._remove_tmp_path(tmp)
            return result

        diff = {}
        new_module_args = self._task.args.copy()

        if (remote_checksum == '1') or (force and local_checksum != remote_checksum):

            result['changed'] = True
            # if showing diffs, we need to get the remote value
            if self._play_context.diff:
                diff = self._get_diff_data(dest, resultant, task_vars, source_file=False)

            if not self._play_context.check_mode:  # do actual work through copy
                xfered = self._transfer_data(self._connection._shell.join_path(tmp, 'source'), resultant)

                # fix file permissions when the copy is done as a different user
                self._fixup_perms(tmp, remote_user, recursive=True)

                # run the copy module
                new_module_args.update(
                    dict(
                        src=xfered,
                        dest=dest,
                        original_basename=os.path.basename(source),
                        follow=True,
                    ),
                )
                result.update(self._execute_module(module_name='copy', module_args=new_module_args, task_vars=task_vars, tmp=tmp, delete_remote_tmp=False))

            if result.get('changed', False) and self._play_context.diff:
                result['diff'] = diff

        else:
            # when running the file module based on the template data, we do
            # not want the source filename (the name of the template) to be used,
            # since this would mess up links, so we clear the src param and tell
            # the module to follow links.  When doing that, we have to set
            # original_basename to the template just in case the dest is
            # a directory.
            new_module_args.update(
                dict(
                    src=None,
                    original_basename=os.path.basename(source),
                    follow=True,
                ),
            )
            result.update(self._execute_module(module_name='file', module_args=new_module_args, task_vars=task_vars, tmp=tmp, delete_remote_tmp=False))

        self._remove_tmp_path(tmp)

        return result