Example #1
def run(config_dict=None, db=None, config_id=None):
    if not config_dict:
        config_dict = get_config(config_id)
        bootstrap(config_dict)

    logs = get_logs_from_config_id(config_dict.general.id)

    if not logs["config"]:
        raise ValueError("Empty log file")
    metric_keys = {"imagination_log_likelihood", "loss", "time_taken"}
    plot_dir = config_dict.plot.base_path
    for mode in ["train", "val"]:
        for key in logs[mode]:
            if key in metric_keys:
                plot(logs[mode][key], mode, key, plot_dir)
    if USE_DATABASE:
        best_metric_logs = log_to_spreadsheet(logs)
        try:
            if not db:
                db = Database(connect_to_firebase=False)
            db.update_job(job_id=config_dict.general.id,
                          project_id=PROJECT,
                          data_to_update={"status": "recorded"})
        except FileNotFoundError as f:
            print("Could not log results to journal: {}".format(f))
        return best_metric_logs
    else:
        return None
Example #2
def test_parser():
    """Method to test if the config parser can load the config file correctly"""
    config_name = "sample_config"
    config = get_config(config_name)
    set_logger(config)
    write_message_logs("torch version = {}".format(torch.__version__))
    assert config.general.id == config_name
Example #3
def get_config_from_appid(app_id):
    config = get_config(read_cmd_args=False)
    log_file_path = config[LOG][FILE_PATH]
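    # Strip the run-specific tail from the configured log path to recover the
    # top-level logs directory, e.g. (hypothetical layout)
    # "/base/logs/<run_id>/log.txt" -> "/base/logs".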
    logs_dir = "/".join(log_file_path.split("log.txt")[0].split("/")[:-2])
    log_file_path = Path(logs_dir, app_id, "log.txt")
    logs = parse_log_file(log_file_path)
    return logs[CONFIG][0]
Example #4
def bootstrap_config(config_id, seed=-1):
    """Method    to generate the config (using config id) and set seeds"""
    config = get_config(config_id, experiment_id=0)
    if seed > 0:
        set_seed(seed=seed)
    else:
        set_seed(seed=config.general.seed)
    return config
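
A usage sketch (the "sample_config" id is borrowed from Example #2; the seed value is arbitrary):

config = bootstrap_config("sample_config", seed=42)
print(config.general.id)  # expected to echo the config id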
Example #5
def bootstrap(config_id):
    config_dict = get_config(config_id=config_id)
    print(config_dict.log)
    set_logger(config_dict)
    write_message_logs("Starting Experiment at {}".format(
        time.asctime(time.localtime(time.time()))))
    write_message_logs("torch version = {}".format(torch.__version__))
    write_config_log(config_dict)
    set_seed(seed=config_dict.general.seed)
    return config_dict
Example #6
def run(config_id):
    print("torch version = {}".format(torch.__version__))
    config_dict = get_config(config_id=config_id)
    set_seed(seed=config_dict.general.seed)
    module_name = "codes.data.loader.loaders"
    dataset = importlib.import_module(module_name).RolloutSequenceDataset(
        config=config_dict, mode="train")
    dataset.load_next_buffer()
    for idx in range(1):
        # First observation of the first rollout; scale to 0-255 and move
        # channels last ((C, H, W) -> (H, W, C)) for display.
        obs = dataset[idx][0][0]
        show_tensor_as_image((obs * 255).numpy().transpose(1, 2, 0))
Example #7
def run_multiple(config_dicts=None, db=None, config_ids=None):
    if not config_dicts:
        config_dicts = list(
            map(
                lambda config_id: get_config(config_id=config_id,
                                             should_make_dir=False),
                config_ids))
    for _dict in config_dicts:
        bootstrap(_dict)
    legend_list = list(map(lambda _dict: _dict.general.id, config_dicts))
    log_file_paths = list(map(lambda _dict: _dict.log.file_path, config_dicts))
    logs_list = list(map(parse_log_file, log_file_paths))
    metric_keys = {"imagination_log_likelihood"}
    for mode in ["val"]:
        for key in logs_list[0][mode]:
            if key in metric_keys:
                plot_multiple(map(lambda x: x[mode][key], logs_list),
                              legend_list, mode, key)
Example #8
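# The def line of this snippet was truncated; the name and first parameter
# are reconstructed from the __main__ call below, and *args absorbs the
# second argument, whose role is not recoverable here.
def choose_model(config, *args):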
    """
    Dynamically load both encoder and decoder
    :param config:
    :return:
    """
    model_config = prepare_config_for_model(config)
    encoder_model_name = model_config.encoder.name
    encoder_module = _import_module(encoder_model_name)
    decoder_model_name = model_config.decoder.name
    decoder_module = _import_module(decoder_model_name)
    return encoder_module(model_config), decoder_module(model_config)


def _import_module(full_module_name):
    """
    Import className from python file
    https://stackoverflow.com/a/8790232
    :param full_module_name: full resolvable module name
    :return: module
    """
    path, name = full_module_name.rsplit('.', 1)
    base_module = importlib.import_module(path)
    module = getattr(base_module, name)
    return module


if __name__ == "__main__":
    config = get_config()
    model = choose_model(config, 10)
    print(model)
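
A usage sketch for _import_module (torch.nn.Linear is only an illustrative target; a real config would name project model classes):

linear_cls = _import_module("torch.nn.Linear")  # imports torch.nn, returns the Linear class
layer = linear_cls(10, 10)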
Example #9
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')

from sacred import Experiment
ex = Experiment('dummy_name')


@ex.main
def start(_config, _run):
    config = Dict(_config)
    set_seed(seed=config.seed)
    run_experiment(config, _run)


if __name__ == '__main__':
    config_id = argument_parser()
    print(config_id)
    config = get_config(config_id=config_id)
    ex.add_config(config)
    options = {}
    options['--name'] = 'exp_{}'.format(config_id)
    if config.logging.use_mongo:
        options['--mongo_db'] = '{}:{}:{}'.format(config.log.mongo_host,
                                                  config.log.mongo_port,
                                                  config.log.mongo_db)
    else:
        base_path = str(
            os.path.dirname(os.path.realpath(__file__)).split('/codes')[0])
        log_path = os.path.join(base_path, config.logging.dir)
        ex.observers.append(FileStorageObserver.create(log_path))
    ex.run(options=options)
Example #10
def test_serialization():
    """Method to test if the config object is serializable"""
    config_name = "sample_config"
    config = get_config(config_name)
    set_logger(config)
    assert write_config_log(config) is None
Example #11
                data.graphs[mode][indices],
                data.queries[mode][indices],
                indices,
                data.world_graph,
            )
        else:
            graphs = [data.graphs[mode][idx] for idx in indices]
            return (graphs, data.queries[mode][indices], indices,
                    data.world_graph)

    def get_input_range(self, mode="train"):
        data = self.graphworld_list[self.current_graphworld_idx]
        indices = range(data.get_num_graphs(mode))
        # type(x) == np.array never holds (np.array is a function, not a
        # type); isinstance against np.ndarray is the intended check
        if isinstance(data.graphs[mode], np.ndarray):
            return (data.graphs[mode][indices],
                    data.queries[mode][indices], indices)
        else:
            graphs = [data.graphs[mode][idx] for idx in indices]
            return graphs, data.queries[mode][indices], indices


if __name__ == "__main__":
    config = get_config("sample_config")
    meta_mode = "train"
    mode = "train"
    metadata = TaskFamily(config, mode=meta_mode)
    target_fn = metadata.sample_task()
    graphs, queries, indices = metadata.sample_inputs(batch_size=32, mode=mode)
    for graph, query, idx in zip(graphs, queries, indices):
        print(graph, query, target_fn(idx, mode))
Example #12
def create_configs(config_id):
    base_config = get_config(config_id=config_id)
    current_id = 0
    # hyperparameters swept regardless of config id
    hyperparams_dict = {
        "model": {
            "optimiser": {
                "learning_rate": [0.1, 0.01, 0.001, 0.0001]
            },
            "embedding": {
                "dim": [50, 100, 150, 200, 250, 300]
            }
        }
    }

    if config_id in ('rn', 'rn_tpr', 'mac'):
        # these three models share the same sweep (g_theta / f_theta follow
        # the Relation Networks naming); note that dict.update replaces the
        # whole "model" entry, so the general sweep above is dropped here
        hyperparams_dict.update({
            "model": {
                "rn": {
                    "g_theta_dim": [64, 128, 256],
                    "f_theta": {
                        "dim_1": [64, 128, 256, 512],
                        "dim_2": [64, 128, 256, 512]
                    }
                }
            }
        })

    if config_id == 'gat_clean':
        hyperparams_dict.update({
            "model": {
                "graph": {
                    "message_dim": [50, 100, 150, 200],
                    "num_message_rounds": [1, 2, 3, 4, 5]
                }
            }
        })

    path = os.path.dirname(os.path.realpath(__file__)).split('/codes')[0]
    target_dir = os.path.join(path, "config")

    for hyperparams in create_list_of_Hyperparams(hyperparams_dict):
        new_config = deepcopy(base_config)
        current_str_id = config_id + "_hp_" + str(current_id)
        new_config["general"]["id"] = current_str_id
        new_config["model"]["checkpoint"] = False
        # new_config["log"]["base_path"] = "/checkpoint/koustuvs/clutrr/"
        for hyperparam in hyperparams:
            setInDict(new_config, hyperparam.key_list, hyperparam.value)
        new_config_file = os.path.join(target_dir,
                                       "{}.yaml".format(current_str_id))
        with open(new_config_file, "w") as f:
            # round-trip through JSON to coerce the config into plain dicts
            # before dumping; safe_load avoids PyYAML's unsafe default loader
            f.write(
                yaml.dump(yaml.safe_load(json.dumps(new_config)),
                          default_flow_style=False))
        current_id += 1
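
A minimal sketch of what the setInDict helper presumably does (the project's actual implementation may differ): walk key_list through the nested config and overwrite the leaf value.

def set_in_dict(nested_dict, key_list, value):
    """Hypothetical stand-in for setInDict: set nested_dict[k0]...[kn] = value."""
    for key in key_list[:-1]:
        nested_dict = nested_dict[key]
    nested_dict[key_list[-1]] = value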
Example #13
    ds.max_ents = max_ents
    logging.info("Processing words...")
    for data in datas:
        ds.preprocess(data)

    # save dictionary
    dictionary = {
        'word2id': ds.word2id,
        'id2word': ds.id2word,
        'target_word2id': ds.target_word2id,
        'target_id2word': ds.target_id2word,
        'max_ents': ds.max_ents,
        'max_vocab': ds.max_vocab,
        'max_entity_id': ds.max_entity_id,
        'entity_ids': ds.entity_ids,
        'dummy_entity': ds.dummy_entity,
        'entity_map': ds.entity_map
    }
    json.dump(dictionary, open(dictionary_file, 'w'))
    logging.info("Saved dictionary at {}".format(dictionary_file))


if __name__ == '__main__':
    # Generate the dictionary once and re-use it across runs, so that
    # generalizability experiments do not hit unknown (out-of-vocabulary)
    # elements. Build it from the last training file, which has the
    # longest paths.
    parent_dir = os.path.abspath(os.pardir).split('/codes')[0]
    config = get_config(config_id='gat_clean')
    generate_dictionary(config)
Example #14
        },  # show links instead of texts
        {
            r'[ \t]*<[^<]*?/?>': u''
        },  # remove remaining tags
        {
            r'^\s+': u''
        }  # remove spaces at the beginning
    ]
    for rule in rules:
        for (k, v) in rule.items():
            regex = re.compile(k)
            text = regex.sub(v, text)
        text = text.strip()
    return text.lower()


if __name__ == '__main__':
    config = get_config('7.dbp')
    ds = Data_Utility(config)
    ds.load()
    pdb.set_trace()
    dt = ds.get_dataloader(mode='train')
    for batch in dt:
        if min(batch.inp_lengths) <= 0:
            break
    dt = ds.get_dataloader(mode='test')
    for batch in dt:
        if min(batch.inp_lengths) <= 0:
            break
Example #15
def get_logs_from_config_id(config_id):
    config = get_config(config_id=config_id, should_make_dir=False)
    bootstrap(config)
    log_file_path = config.log.file_path
    logs = parse_log_file(log_file_path=log_file_path)
    return logs
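
A usage sketch tying this back to Example #1 (the "sample_config" id is borrowed from Example #2; the log layout is inferred from how Example #1 consumes it):

logs = get_logs_from_config_id("sample_config")
print(logs["config"])         # config entries recorded in the log file
print(sorted(logs["train"]))  # per-metric series such as "loss"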