コード例 #1
0
    def __init__(
            self, directory, pattern, vocabs, src_vectorizer, tgt_vectorizer, last_turn_only=False,
            distribute=True, shuffle=True, record_keys=None
    ):
        """Set up a file-backed dataset and determine how many samples it contains.

        The sample count is read from ``<directory>/md.yml`` when that file exists. Otherwise
        every line of every file matching ``pattern`` is counted as one sample, and the total is
        written to ``md.yml`` so later constructions can skip the scan.

        :param directory: Directory holding the data files and the ``md.yml`` count cache.
        :param pattern: Glob pattern, relative to ``directory``, selecting the data files.
        :param vocabs: Vocabularies; stored as ``self.vocab``.
        :param src_vectorizer: Source-side vectorizer; stored as-is.
        :param tgt_vectorizer: Target-side vectorizer; stored as-is.
        :param last_turn_only: Stored as ``self.last_turn_only``.
        :param distribute: When true and ``torch.distributed`` is initialized, ``rank`` and
            ``world_size`` are taken from the process group; otherwise they stay ``0`` and ``1``.
        :param shuffle: Stored as ``self.shuffle``.
        :param record_keys: Stored as ``self.record_keys``. ``None`` (the default) means a fresh
            empty list for this instance.
        """
        super().__init__()
        # A fresh list per instance: a ``[]`` default would be one list object shared by every
        # instance constructed without ``record_keys``.
        self.record_keys = [] if record_keys is None else record_keys
        self.src_vectorizer = src_vectorizer
        self.tgt_vectorizer = tgt_vectorizer
        self.pattern = pattern
        self.directory = directory
        self.vocab = vocabs
        self.samples = 0
        self.rank = 0
        self.world_size = 1
        self.shuffle = shuffle
        self.last_turn_only = last_turn_only
        self.distribute = distribute
        if torch.distributed.is_initialized() and distribute:
            self.rank = torch.distributed.get_rank()
            self.world_size = torch.distributed.get_world_size()

        md_file = f"{directory}/md.yml"
        if os.path.exists(md_file):
            # Reuse the cached count instead of re-scanning every data file.
            self.samples = read_yaml(md_file)['num_samples']
        else:
            # One sample per line, summed over all matching files; then cache the total.
            files = glob.glob(f"{directory}/{self.pattern}")
            pg = create_progress_bar(len(files))
            for file in pg(files):
                with open(file) as rf:
                    self.samples += sum(1 for _ in rf)
            write_yaml({'num_samples': self.samples}, md_file)
コード例 #2
0
    def __init__(self, directory, pattern, vocabs, vectorizers, nctx):
        """Configure the dataset over ``directory`` and record its total sample count.

        ``vectorizers`` must provide ``'src'`` and ``'tgt'`` entries. The count comes from
        ``<directory>/md.yml`` when that file exists. Otherwise each line of each file matching
        ``pattern`` is counted, and the total is written to ``md.yml`` for next time.
        """
        super().__init__()
        self.src_vectorizer, self.tgt_vectorizer = vectorizers['src'], vectorizers['tgt']
        self.pattern = pattern
        self.nctx = nctx
        self.directory = directory
        self.vocab = vocabs
        self.samples = 0

        # Single-process defaults, overridden by the process group's values when one is active.
        self.rank, self.world_size = 0, 1
        if torch.distributed.is_initialized():
            self.rank = torch.distributed.get_rank()
            self.world_size = torch.distributed.get_world_size()

        cache_path = f"{directory}/md.yml"
        if os.path.exists(cache_path):
            self.samples = read_yaml(cache_path)['num_samples']
            return

        matches = glob.glob(f"{directory}/{self.pattern}")
        progress = create_progress_bar(len(matches))
        for path in progress(matches):
            with open(path) as handle:
                self.samples += sum(1 for _ in handle)
        write_yaml({'num_samples': self.samples}, cache_path)
コード例 #3
0
ファイル: main.py プロジェクト: Interactions-AI/odin
def _job_def(id_):
    """Build the ``JobDefinition`` for job ``id_`` from its ``main.yml``.

    Each entry under ``tasks`` gets an ``id`` of ``'<id_>--<task name>'``. A ``mount`` key is
    renamed to ``mounts`` when ``mounts`` is absent. The entry is then passed to
    ``TaskDefinition`` as keyword arguments.

    :param id_: Job identifier; used as the job's ``name`` and ``id``.
    :return: The ``JobDefinition``, or ``None`` if any step raises. The failure is logged together
        with its traceback.
    """
    try:
        job_loc = _get_job_file(id_, 'main.yml')
        job_config = read_yaml(job_loc)
        tasks = []
        for task in job_config['tasks']:
            # Accept the singular ``mount`` key as an alias for ``mounts``; a task with neither
            # key raises KeyError here and is handled below.
            if 'mounts' not in task:
                task['mounts'] = task.pop('mount')
            task['id'] = f"{id_}--{task['name']}"
            tasks.append(TaskDefinition(**task))
        return JobDefinition(tasks=tasks, location=job_loc, name=id_, id=id_, configs=_get_job_files(job_loc))
    except Exception:
        # Best-effort lookup: callers receive None, but keep the traceback so the cause is
        # diagnosable (logging only ``str(e)`` discarded it).
        LOGGER.exception("Failed to load job definition for %s", id_)
        return None
コード例 #4
0
def test_read_yaml_strict():
    """``read_yaml(..., strict=True)`` raises ``IOError`` when the file does not exist."""
    pytest.importorskip("yaml")
    # Resolve the missing file inside the test-data directory, like
    # ``test_read_yaml_default_value``, so the result does not depend on the current working
    # directory (``os.path.join`` with a single argument is a no-op).
    with pytest.raises(IOError):
        read_yaml(os.path.join(data_loc, "not_there.yml"), strict=True)
コード例 #5
0
def test_read_yaml_given_default():
    """``read_yaml`` returns the caller-supplied default when the file does not exist."""
    pytest.importorskip("yaml")
    gold_default = "default"
    # Resolve the missing file inside the test-data directory, like
    # ``test_read_yaml_default_value``, so the result does not depend on the current working
    # directory (``os.path.join`` with a single argument is a no-op).
    data = read_yaml(os.path.join(data_loc, "not_there.yml"), gold_default)
    assert data == gold_default
コード例 #6
0
def test_read_yaml_default_value():
    """With no default supplied, ``read_yaml`` yields an empty dict for a missing file."""
    pytest.importorskip("yaml")
    missing_path = os.path.join(data_loc, "not_there.yml")
    result = read_yaml(missing_path)
    assert result == {}
コード例 #7
0
def test_read_yaml(gold_data):
    """``read_yaml`` parses ``test_yaml.yml`` under ``data_loc`` into ``gold_data``."""
    pytest.importorskip("yaml")
    yaml_path = os.path.join(data_loc, "test_yaml.yml")
    loaded = read_yaml(yaml_path)
    assert loaded == gold_data
コード例 #8
0
def get_num_samples(sample_md):
    """Return the ``num_samples`` value stored in the YAML metadata file ``sample_md``."""
    return read_yaml(sample_md)['num_samples']