def __init__(
    self,
    directory,
    pattern,
    vocabs,
    src_vectorizer,
    tgt_vectorizer,
    last_turn_only=False,
    distribute=True,
    shuffle=True,
    record_keys=None,
):
    """Initialize a sharded-file dataset reader.

    Counts the total number of samples (one sample per line across every
    file matching ``pattern`` under ``directory``) and caches that count in
    ``{directory}/md.yml`` so subsequent constructions skip the full scan.

    :param directory: Directory containing the data shards.
    :param pattern: Glob pattern selecting shard files inside ``directory``.
    :param vocabs: Vocabulary object(s) stored on ``self.vocab``.
    :param src_vectorizer: Vectorizer applied to source text.
    :param tgt_vectorizer: Vectorizer applied to target text.
    :param last_turn_only: Flag stored for downstream use — presumably
        restricts multi-turn records to the final turn (TODO confirm in reader).
    :param distribute: If True and ``torch.distributed`` is initialized,
        record this worker's rank/world size for sharding.
    :param shuffle: Flag stored for downstream use by the iterator.
    :param record_keys: Optional list of record keys; ``None`` means an
        empty list. (A ``None`` sentinel avoids the shared mutable-default
        pitfall of ``record_keys=[]``.)
    """
    super().__init__()
    # Fresh list per instance — never share a mutable default across objects.
    self.record_keys = [] if record_keys is None else record_keys
    self.src_vectorizer = src_vectorizer
    self.tgt_vectorizer = tgt_vectorizer
    self.pattern = pattern
    self.directory = directory
    self.vocab = vocabs
    self.samples = 0
    # Defaults for the non-distributed case.
    self.rank = 0
    self.world_size = 1
    self.shuffle = shuffle
    self.last_turn_only = last_turn_only
    self.distribute = distribute
    if torch.distributed.is_initialized() and distribute:
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
    md_file = f"{directory}/md.yml"
    if os.path.exists(md_file):
        # Fast path: reuse the cached sample count.
        md = read_yaml(md_file)
        self.samples = md['num_samples']
    else:
        # Slow path: count lines in every shard once, then cache the total.
        files = list(glob.glob(f"{directory}/{self.pattern}"))
        pg = create_progress_bar(len(files))
        for file in pg(files):
            with open(file) as rf:
                self.samples += sum(1 for _ in rf)
        write_yaml({'num_samples': self.samples}, md_file)
def __init__(self, directory, pattern, vocabs, vectorizers, nctx):
    """Initialize a sharded-file dataset reader with a context window.

    The total sample count (one sample per line across all files matching
    ``pattern`` in ``directory``) is cached in ``{directory}/md.yml`` so the
    expensive scan happens only once.

    :param directory: Directory containing the data shards.
    :param pattern: Glob pattern selecting shard files inside ``directory``.
    :param vocabs: Vocabulary object(s) stored on ``self.vocab``.
    :param vectorizers: Dict with ``'src'`` and ``'tgt'`` vectorizers.
    :param nctx: Context length stored for downstream use.
    """
    super().__init__()
    self.src_vectorizer = vectorizers['src']
    self.tgt_vectorizer = vectorizers['tgt']
    self.pattern = pattern
    self.nctx = nctx
    self.directory = directory
    self.vocab = vocabs
    self.samples = 0
    # Single-process defaults; overwritten below when running distributed.
    self.rank = 0
    self.world_size = 1
    if torch.distributed.is_initialized():
        self.rank = torch.distributed.get_rank()
        self.world_size = torch.distributed.get_world_size()
    if os.path.exists(f"{directory}/md.yml"):
        # Cached metadata present — skip the full line count.
        md = read_yaml(f"{directory}/md.yml")
        self.samples = md['num_samples']
    else:
        # Count every line in every shard, showing progress, then cache.
        shard_files = list(glob.glob(f"{directory}/{self.pattern}"))
        progress = create_progress_bar(len(shard_files))
        for shard in progress(shard_files):
            with open(shard) as rf:
                self.samples += sum(1 for _ in rf)
        write_yaml({'num_samples': self.samples}, f"{directory}/md.yml")
def _job_def(id_):
    """Load a job definition by id from its ``main.yml``.

    Builds one ``TaskDefinition`` per task entry, assigning each task the id
    ``'{id_}--{name}'``, and wraps them in a ``JobDefinition``.

    :param id_: Job identifier used to locate the job's ``main.yml``.
    :return: A ``JobDefinition``, or ``None`` if the job cannot be read or
        parsed (callers rely on ``None`` signalling failure).
    """
    try:
        job_loc = _get_job_file(id_, 'main.yml')
        # NOTE: don't shadow the builtin `object` (original did).
        job_spec = read_yaml(job_loc)
        tasks = []
        for t in job_spec['tasks']:
            # Accept the legacy singular 'mount' key by renaming it in place.
            if 'mounts' not in t:
                t['mounts'] = t.pop('mount')
            name = t['name']
            t['id'] = f'{id_}--{name}'
            tasks.append(TaskDefinition(**t))
        return JobDefinition(
            tasks=tasks,
            location=job_loc,
            name=id_,
            id=id_,
            configs=_get_job_files(job_loc),
        )
    except Exception as e:
        # Log with traceback so malformed jobs are diagnosable, then signal
        # failure to the caller with None (deliberate best-effort contract).
        LOGGER.exception(e)
        return None
def test_read_yaml_strict():
    """strict=True on a missing file raises IOError instead of defaulting."""
    pytest.importorskip("yaml")
    missing = os.path.join("not_there.yml")
    with pytest.raises(IOError):
        read_yaml(missing, strict=True)
def test_read_yaml_given_default():
    """A missing file returns the caller-supplied default value."""
    pytest.importorskip("yaml")
    gold_default = "default"
    missing = os.path.join("not_there.yml")
    assert read_yaml(missing, gold_default) == gold_default
def test_read_yaml_default_value():
    """A missing file falls back to the built-in default, an empty dict."""
    pytest.importorskip("yaml")
    missing = os.path.join(data_loc, "not_there.yml")
    assert read_yaml(missing) == {}
def test_read_yaml(gold_data):
    """An existing YAML file parses to the expected fixture data."""
    pytest.importorskip("yaml")
    loaded = read_yaml(os.path.join(data_loc, "test_yaml.yml"))
    assert loaded == gold_data
def get_num_samples(sample_md):
    """Return the ``num_samples`` count recorded in a metadata YAML file.

    :param sample_md: Path to the metadata YAML (e.g. ``md.yml``).
    :return: The value stored under the ``'num_samples'`` key.
    """
    return read_yaml(sample_md)['num_samples']