def check_experiment_type_can_train(self, param_file):
        """Smoke-test that the experiment configured in ``param_file`` trains.

        The HOCON config is parsed and its ``model_class`` is resolved through
        ``concrete_models``. The parameters are then rewritten so the run is
        cheap and self-contained:

        - serialisation is disabled, because real configs point at /net/efs
          paths that the test run would otherwise try to create;
        - train/validation paths are redirected to the dummy test fixtures;
        - training is limited to a single epoch;
        - a pretrained word-embedding file, if configured, is replaced with
          ``self.PRETRAINED_VECTORS_GZIP``.

        The model is built with ``self.get_model`` and trained once.

        :param param_file: path to a HOCON experiment configuration.
        """
        config = Params(replace_none(pyhocon.ConfigFactory.parse_file(param_file)))
        model_class = concrete_models[config.pop("model_class")]

        config["model_serialization_prefix"] = None
        if len(config["train_files"]) <= 1:
            config["train_files"] = [self.TRAIN_FILE]
            # NOTE(review): validation reuses the training fixture here, unlike
            # the multi-file branch below; presumably intentional — confirm.
            config["validation_files"] = [self.TRAIN_FILE]
        else:
            config["train_files"] = [self.TRAIN_FILE, self.TRAIN_BACKGROUND]
            config["validation_files"] = [
                self.VALIDATION_FILE,
                self.VALIDATION_BACKGROUND,
            ]
        config["num_epochs"] = 1

        try:
            word_embeddings = config["embeddings"]["words"]
            pretrained_file = word_embeddings["pretrained_file"]
        except KeyError:
            # No embeddings.words.pretrained_file entry: nothing to redirect.
            pass
        else:
            if pretrained_file:
                word_embeddings["pretrained_file"] = self.PRETRAINED_VECTORS_GZIP

        self.get_model(model_class, config).train()
Пример #2
0
def serve(port: int, param_file: str):
    """Build or load the module-level retrieval model, then start the server.

    The config must contain a ``retrieval`` section, which is used to
    construct ``VectorBasedRetrieval`` and stored in the module-level
    ``retrieval`` global.

    - If a ``corpus`` path is given, the background corpus is read, the model
      is fit, and the model is saved.
    - Otherwise a previously saved model is loaded.

    :param port: port the server listens on.
    :param param_file: path to a Typesafe/HOCON-style configuration file.
    """
    # read in the Typesafe-style config file
    params = pyhocon.ConfigFactory.parse_file(param_file)
    params = Params(replace_none(params))
    retrieval_params = params.pop('retrieval')
    corpus_file = params.pop('corpus', None)
    # The key is consumed here, but its value is not used in this function.
    # NOTE(review): presumably the request handler should read it — confirm.
    params.pop('num_neighbors', 10)

    global retrieval
    retrieval = VectorBasedRetrieval(retrieval_params)
    if corpus_file is not None:
        retrieval.read_background(corpus_file)
        retrieval.fit()
        retrieval.save_model()
    else:
        retrieval.load_model()

    # Start the server on the requested port. Previously `port` was ignored,
    # so the server silently used the framework's default port.
    # Host '0.0.0.0' listens on all network interfaces.
    print("starting server")
    app.run(host='0.0.0.0', port=port)
Пример #3
0
 def __init__(self, params: Params):
     """Configure the verb-semantics model from ``params``.

     Two keys are removed from ``params`` before the rest is handed to the
     parent constructor:

     - ``num_stacked_rnns`` (default 1), stored on ``self``;
     - ``instance_type`` (default "VerbSemanticsInstance"), resolved through
       ``concrete_instances`` and stored as ``self.instance_type``.

     NOTE(review): ``concrete_instances`` is presumably a name -> instance-class
     registry; confirm.
     """
     self.num_stacked_rnns = params.pop('num_stacked_rnns', 1)
     self.instance_type = concrete_instances[
         params.pop('instance_type', "VerbSemanticsInstance")
     ]
     # Delegate only after this class's own keys have been popped off.
     super(VerbSemanticsModel, self).__init__(params)