def train(self):
    set_random_seeds(self.random_seed)
    all_langs = self.data_loader.all_langs
    num_langs = len(all_langs)
    idx = list(range(num_langs))
    random.shuffle(idx)
    num_langs_per_fold = (num_langs + self.k_fold - 1) // self.k_fold
    accum_metrics = Metrics()
    for fold in range(self.k_fold):
        # Get train-dev split.
        start_idx = fold * num_langs_per_fold
        end_idx = start_idx + num_langs_per_fold if fold < self.k_fold - 1 else num_langs
        dev_langs = [all_langs[idx[i]] for i in range(start_idx, end_idx)]
        logging.imp(f'dev_langs: {sorted(dev_langs)}')
        train_langs = [all_langs[idx[i]] for i in range(num_langs) if i < start_idx or i >= end_idx]
        assert len(set(dev_langs + train_langs)) == num_langs
        self.trainer.reset()
        best_mse = self.trainer.train(self.evaluator, train_langs, dev_langs, fold)
        # Aggregate every fold.
        accum_metrics += best_mse
    logging.info(accum_metrics.get_table())
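# A minimal standalone sketch (hypothetical helper, not part of the trainer) of the
# fold arithmetic above: ceil-division sizes the first k - 1 folds and the last fold
# absorbs the remainder, so every language lands in exactly one dev fold.
def fold_bounds(num_items: int, k_fold: int) -> list:
    per_fold = (num_items + k_fold - 1) // k_fold  # ceil(num_items / k_fold)
    bounds = []
    for fold in range(k_fold):
        start = fold * per_fold
        end = start + per_fold if fold < k_fold - 1 else num_items
        bounds.append((start, end))
    return bounds

# With 10 languages and 3 folds: [(0, 4), (4, 8), (8, 10)].
print(fold_bounds(10, 3))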
def train(self, evaluator: evaluator.Evaluator, train_langs: List[str], dev_langs: List[str],
          fold_idx: int) -> Metric:
    # Reset parameters.
    self._init_params(init_matrix=True, init_vector=True, init_higher_tensor=True)
    self.optimizer = optim.Adam(self.model.parameters(), self.learning_rate)
    # Main training loop.
    accum_metrics = Metrics()
    best_mse = None
    while not self.tracker.is_finished:
        # Get data first.
        metrics = self.train_loop(train_langs)
        accum_metrics += metrics
        self.tracker.update()
        self.check_metrics(accum_metrics)
        if self.track % self.save_interval == 0:
            self.save(dev_langs, f'{fold_idx}.latest')
            dev_metrics = evaluator.evaluate(dev_langs)
            logging.info(dev_metrics.get_table(title='dev'))
            if best_mse is None or dev_metrics.mse.mean < best_mse:
                best_mse = dev_metrics.mse.mean
                logging.imp(f'Updated best mse score: {best_mse:.3f}')
                self.save(dev_langs, f'{fold_idx}.best')
    return Metric('best_mse', best_mse, 1)
def _save(self, path: Path):
    to_save = {
        'model': self.model.state_dict(),
        'g': g.state_dict(),
    }
    torch.save(to_save, path)
    logging.imp(f'Model saved to {path}.')
def save(self, eval_metrics: Metrics):
    new_value = eval_metrics.loss.mean
    self.save_to(g.log_dir / 'saved.latest')
    if self.tracker.update('best_loss', value=new_value):
        out_path = g.log_dir / 'saved.best'
        logging.imp(f'Best model updated: new best is {self.tracker.best_loss:.3f}')
        self.save_to(out_path)
def _evaluate_one_data_loader(self, dl: ContinuousTextDataLoader, stage: str) -> Metrics:
    task = dl.task
    accum_metrics = Metrics()
    # Get all metrics from batches.
    dfs = list()
    total_num_samples = 0
    for batch in dl:
        if g.eval_max_num_samples and total_num_samples + batch.batch_size > g.eval_max_num_samples:
            logging.imp(
                f'Stopping at {total_num_samples} < {g.eval_max_num_samples} evaluated examples from {task}.')
            break
        model_ret = self.model(batch)
        batch_metrics, batch_df = self.predict(model_ret, batch)
        accum_metrics += batch_metrics
        # accum_metrics += self.analyzer.analyze(model_ret, batch)
        total_num_samples += batch.batch_size
        dfs.append(batch_df)
    df = pd.concat(dfs, axis=0)
    # Write the predictions to file.
    out_path = g.log_dir / 'predictions' / f'{task}.{stage}.tsv'
    out_path.parent.mkdir(exist_ok=True, parents=True)
    df.to_csv(out_path, index=False, sep='\t')
    # Compute P/R/F scores.
    accum_metrics += get_prf_scores(accum_metrics)
    return accum_metrics
def step(self):
    """Copied from https://pytorch.org/docs/stable/_modules/torch/optim/lr_scheduler.html#ReduceLROnPlateau."""
    for param_group in self.optimizer.param_groups:
        old_lr = float(param_group['lr'])
        new_lr = old_lr * self._factor
        param_group['lr'] = new_lr
        logging.imp(f'Learning rate is now {new_lr:.4f}.')
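# A self-contained sketch (plain torch, assumed decay factor) of the same multiplicative
# decay: every param_group's learning rate is scaled in place, mirroring step() above.
import torch.nn as nn
import torch.optim as optim

model = nn.Linear(4, 2)
optimizer = optim.Adam(model.parameters(), lr=1e-3)
factor = 0.5  # hypothetical stand-in for self._factor.

for param_group in optimizer.param_groups:
    param_group['lr'] = float(param_group['lr']) * factor

print(optimizer.param_groups[0]['lr'])  # 0.0005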
def _get_model(self, dl=None):
    phono_feat_mat = special_ids = None
    if g.use_phono_features:
        phono_feat_mat = get_tensor(self.src_abc.pfm)
        special_ids = get_tensor(self.src_abc.special_ids)
    phono_kwargs = {
        'phono_feat_mat': phono_feat_mat,
        'special_ids': special_ids,
    }
    if g.use_rl:
        end_state = self.env.end
        agent_cls = VanillaPolicyGradient if g.agent == 'vpg' else A2C
        model = agent_cls(len(self.tgt_abc), self.env, end_state, **phono_kwargs)
    else:
        model = OnePairModel(len(self.src_abc), len(self.tgt_abc), **phono_kwargs)
    if g.saved_model_path is not None:
        model.load_state_dict(torch.load(g.saved_model_path, map_location=torch.device('cpu')))
        logging.imp(f'Loaded from {g.saved_model_path}.')
    if has_gpus():
        model.cuda()
    logging.info(model)
    return model
def prepare_dataset(self, setting: Setting) -> OnePairDataset:
    pair = (setting.src_lang, setting.tgt_lang)
    if pair not in self._pair2dfs:
        raise RuntimeError(f'Pair {pair} not added.')

    # Check alphabets first.
    def check_abc(lang: Lang):
        if lang not in self._lang2abc:
            raise RuntimeError(f'Alphabet for {lang} has not been prepared.')

    check_abc(setting.src_lang)
    check_abc(setting.tgt_lang)

    # Get relevant data frames.
    src_df, tgt_df = self._pair2dfs[pair]
    src_df = setting.split.select(src_df)
    tgt_df = setting.split.select(tgt_df)
    pair_df = pd.merge(src_df, tgt_df, left_index=True, right_index=True, suffixes=('_src', '_tgt'))
    if setting.keep_ratio is not None:
        logging.imp(f'keep_ratio is {setting.keep_ratio}.')
        num = int(len(pair_df) * setting.keep_ratio)
        pair_df = pair_df.sample(num, random_state=g.random_seed)

    # Downweight cognate ids that have multiple variants.
    vc = pair_df.index.value_counts()
    vc.name = 'num_variants'
    pair_df = pd.merge(pair_df, vc, left_index=True, right_index=True)
    pair_df['sample_weight'] = 1 / vc

    logging.info(f'Total number of cognate pairs for {pair} for {setting.split}: {len(pair_df)}.')
    return OnePairDataset(setting, pair_df)
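# A standalone pandas sketch (toy data) of the sample-weight logic above: cognate ids
# with several variants are downweighted so that each id contributes equally overall.
import pandas as pd

pair_df = pd.DataFrame({'form': ['a1', 'a2', 'b', 'c']},
                       index=pd.Index([0, 0, 1, 2], name='cognate_id'))
vc = pair_df.index.value_counts()
vc.name = 'num_variants'
pair_df = pd.merge(pair_df, vc, left_index=True, right_index=True)
pair_df['sample_weight'] = 1 / pair_df['num_variants']
print(pair_df)
# cognate_id 0 has two variants, each weighted 0.5; ids 1 and 2 keep weight 1.0.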
def save_alignment(self):
    to_save = {
        'unit_aligner': self.model.g2p.unit_aligner.state_dict(),
    }
    path = g.log_dir / f'saved.{self.stage}.almt'
    torch.save(to_save, path)
    logging.imp(f'Alignment saved to {path}.')
def save(self, eval_metrics: Metrics):
    self.save_to(g.log_dir / 'saved.latest')
    # self.tracker.update('best_loss', value=eval_metrics.dev_total_loss.mean)
    if self.tracker.update('best_f1', value=eval_metrics.dev_prf_f1.value):
        out_path = g.log_dir / 'saved.best'
        logging.imp(f'Best model updated: new best is {self.tracker.best_f1:.3f}')
        self.save_to(out_path)
def save(self, eval_metrics: Metrics):
    self.save_to(g.log_dir / f'saved.{self.stage}.latest')
    if eval_metrics is not None:
        if self.tracker.update('best_f1', value=eval_metrics.prf_exact_span_f1.value):
            out_path = g.log_dir / f'saved.{self.stage}.best'
            logging.warning('Do NOT use this number since this f1 is compared against ground truths.')
            logging.imp(f'Best model updated: new best is {self.tracker.best_f1:.3f}')
            self.save_to(out_path)
def save(self, dev_langs: List[str], suffix: str):
    out_path = self.log_dir / f'saved.{suffix}'
    to_save = {
        'model': self.model.state_dict(),
        'g': g.state_dict(),
        'dev_langs': dev_langs,
    }
    torch.save(to_save, out_path)
    logging.imp(f'Model saved to {out_path}.')
def save(self):
    super().save()
    if self.track % self.save_interval == 0:
        metrics = self.evaluator.evaluate()
        logging.info(f'New evaluation loss is {metrics.loss.mean:.3f}.')
        if self.best_metrics is None or metrics.loss.mean < self.best_metrics.loss.mean:
            self.best_metrics = metrics
            out_path = self.log_dir / 'saved.best'
            logging.imp(f'Best model updated: new best is {self.best_metrics.loss.mean:.3f}.')
            self._save(out_path)
def load(self, path: Path, load_lm_model: bool = False,
         load_optimizer_state: bool = False, load_phi_scorer: bool = False):
    saved = torch.load(path)
    smsd = saved['model']
    if not load_lm_model:
        smsd = {k: v for k, v in smsd.items() if not k.startswith('lm_model')}
    if not load_phi_scorer:
        smsd = {k: v for k, v in smsd.items() if not k.startswith('phi_scorer')}
    self.model.load_state_dict(smsd, strict=False)
    if load_optimizer_state:
        self.optimizer.load_state_dict(saved['optimizer'])
    logging.imp(f'Loaded model from {path}.')
def should_terminate(self, eval_metrics: Metrics) -> bool:
    if not self._update_best_score(eval_metrics):
        logging.imp('eval_reward has not improved.')
        self.tracker.update('tolerance')
        if g.improved_player_only:
            logging.imp('Loading old state dict.')
            self.agent.load_state_dict(self._old_state)
    else:
        self.tracker.reset('tolerance')
        self.metric_writer.add_metrics(self.best_metrics, self.tracker['step'].value)
    return self.tracker.is_finished('tolerance')
def __init__(self, lu_size: int, ku_size: int):
    """`lu_size`: number of lost units, `ku_size`: number of known units."""
    super().__init__()
    self.unit_aligner = nn.Embedding(lu_size, ku_size)
    logging.imp('Unit aligner initialized to 0.')
    self.unit_aligner.weight.data.fill_(0.0)
    self.conv = nn.Conv1d(g.dim, g.dim, g.g2p_window_size, padding=g.g2p_window_size // 2)
    self.dropout = nn.Dropout(g.dropout)
def load(self):
    ckpt = torch.load(self.saved_path)

    def try_load(name):
        src = ckpt[name]
        dest = getattr(self, name)
        try:
            dest.load_state_dict(src)
        except RuntimeError as e:
            logging.error(e)

    try_load('model')
    try_load('optimizer')
    try_load('tracker')
    try_load('flow')
    logging.imp(f'Loaded saved states from {self.saved_path}.')
def evaluate(self, stage: str) -> Metrics:
    segments = list()
    predictions = list()
    ground_truths = list()
    matched_segments = list()
    total_num_samples = 0
    analyzed_metrics = Metrics()
    for batch in pbar(self.dl, desc='eval_batch'):
        if g.eval_max_num_samples and total_num_samples + batch.batch_size > g.eval_max_num_samples:
            logging.imp(
                f'Stopping at {total_num_samples} < {g.eval_max_num_samples} evaluated examples.')
            break
        ret = self.model(batch)
        analyzed_metrics += self.analyzer.analyze(ret, batch)
        segments.extend(list(batch.segments))
        segmentations, _matched_segments = self._get_segmentations(ret, batch)
        predictions.extend(segmentations)
        matched_segments.extend(_matched_segments)
        ground_truths.extend([segment.to_segmentation() for segment in batch.segments])
        total_num_samples += batch.batch_size
    df = _get_df(segments, ground_truths, predictions, matched_segments,
                 columns=('segment', 'ground_truth', 'prediction', 'matched_segment'))
    out_path = g.log_dir / 'predictions' / f'extract.{stage}.tsv'
    out_path.parent.mkdir(exist_ok=True, parents=True)
    df.to_csv(out_path, index=False, sep='\t')
    matching_stats = get_matching_stats(predictions, ground_truths)
    prf_scores = get_prf_scores(matching_stats)
    return analyzed_metrics + matching_stats + prf_scores
def add_file(self, lang: Lang, path: Path) -> DF:
    """Read the file at `path` for language `lang`."""
    if lang in self._lang2abc:
        raise RuntimeError(
            f'An alphabet has already been prepared for language {lang}. No more files can be added.')

    # Always use the resolved path as the key.
    path = path.resolve()
    # Skip this file if it has already been added.
    if lang in self._lang2dfs and path in self._lang2dfs[lang]:
        logging.warning(f'File at {path} has already been added for language {lang}.')
    else:
        # Actually add this file.
        df = pd.read_csv(str(path), sep='\t', keep_default_na=True)
        df['cognate_id'] = range(len(df))
        if g.input_format == 'wikt':
            df = df.copy()
            df['tokens'] = df['tokens'].str.split('|')
            df = df.explode('tokens')
        # Add a pre_unit_seq column to store the preprocessed data.
        col = 'tokens' if g.input_format == 'wikt' else 'parsed_tokens'  # Use parsed tokens if possible.
        df['pre_unit_seq'] = df[col].str.split().apply(_preprocess)
        # NOTE(j_luo) Add noise to the target tokens by randomly duplicating one character.
        if g.noise_level > 0.0:
            logging.imp(f'Adding noise to the target tokens with level {g.noise_level}.')
            random.seed(g.random_seed)

            def add_noise(token):
                if random.random() < g.noise_level:
                    pos = random.randint(0, len(token) - 1)
                    token = token[:pos] + [token[pos]] + token[pos:]
                return token

            df['pre_unit_seq'] = df['pre_unit_seq'].apply(add_noise)
        df = df.set_index('cognate_id')
        self._lang2dfs[lang][path] = df
        logging.info(f'File at {path} has been added for language {lang}.')
    return self._lang2dfs[lang][path]
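# A standalone sketch (toy token, forced probability) of the duplication noise above:
# with probability noise_level, one randomly chosen unit of the token is doubled.
import random

random.seed(0)
noise_level = 1.0  # force the noise for illustration.

def add_noise(token):
    if random.random() < noise_level:
        pos = random.randint(0, len(token) - 1)
        token = token[:pos] + [token[pos]] + token[pos:]
    return token

print(add_noise(['k', 'a', 't']))  # one unit is duplicated, e.g. ['k', 'a', 'a', 't'].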
def evaluate(self, *args) -> Metrics:
    # HACK(j_luo) *args is used just to comply with the BaseTrainer function signature.
    with torch.no_grad():
        self.model.eval()
        all_metrics = Metrics()
        total_num_samples = 0
        for batch in self.data_loader:
            if g.eval_max_num_samples and total_num_samples + batch.batch_size > g.eval_max_num_samples:
                logging.imp(
                    f'Stopping at {total_num_samples} < {g.eval_max_num_samples} evaluated examples.')
                break
            scores = self.model.score(batch)
            try:
                metrics = self.analyzer.analyze(scores.distr)
            except AttributeError:
                metrics = self.analyzer.analyze(scores)
            all_metrics += metrics
            total_num_samples += batch.batch_size
        return all_metrics
def _check_acc(self, flow):
    preds = flow.get_best(nonzero=True)
    # Check the lost side.
    acc = sum(has_cognate(w, self.known_lang) for w in preds.keys())
    rate = acc / len(preds)
    logging.imp(f'Accuracy on the lost side: {acc} / {len(preds)} = {rate:.3f}.')
    # Check the known side.
    acc = sum(has_cognate(w, self.lost_lang) for w in preds.values())
    rate = acc / len(preds)
    logging.imp(f'Accuracy on the known side: {acc} / {len(preds)} = {rate:.3f}.')
    # Check lost-known pairs.
    acc = sum(is_cognate(w1, w2) for w1, w2 in preds.items())
    rate = acc / len(preds)
    logging.imp(f'Accuracy for lost-known pairs: {acc} / {len(preds)} = {rate:.3f}.')
def ins_del_cost(self, value):
    logging.imp(f'Setting ins_del_cost to {value}.')
def __init__(self):
    all_tgt, self.cog_reg, self.src_abc, self.tgt_abc = self.prepare_raw_data()

    # Get stats for unseen units.
    stats = self.tgt_abc.stats
    _, test_tgt_path = get_paths(g.data_path, g.src_lang, g.tgt_lang)
    mask = (stats.sum() == stats.loc[test_tgt_path])
    unseen = mask[mask].index.tolist()
    total = len(stats.loc[test_tgt_path].dropna())
    logging.info(f'Unseen units ({len(unseen)}/{total}) for {g.tgt_lang} are: {unseen}.')

    # Get language-to-id mappings. Used only for the targets (i.e., the decoder side).
    self.lang2id = lang2id = {tgt: i for i, tgt in enumerate(all_tgt)}

    # Get all data loaders.
    self.dl_reg = DataLoaderRegistry()

    def create_setting(name: str, tgt_lang: str, split: Split, for_training: bool,
                       keep_ratio: Optional[float] = None, tgt_sot: bool = False) -> Setting:
        return Setting(name, 'one_pair', split, g.src_lang, tgt_lang, for_training,
                       keep_ratio=keep_ratio, tgt_sot=tgt_sot)

    test_setting = create_setting(f'test@{g.tgt_lang}', g.tgt_lang, Split('all'), False,
                                  keep_ratio=g.test_keep_ratio)
    settings: List[Setting] = [test_setting]

    # Get the training languages.
    for train_tgt_lang in g.train_tgt_langs:
        if g.input_format == 'ielex':
            train_split = Split('train', [1, 2, 3, 4])  # Use the first four folds for training.
            dev_split = Split('dev', [5])  # Use the last fold for dev.
        else:
            train_split = Split('train')
            dev_split = Split('dev')
        train_setting = create_setting(f'train@{train_tgt_lang}', train_tgt_lang, train_split, True,
                                       keep_ratio=g.keep_ratio)
        train_e_setting = create_setting(f'train@{train_tgt_lang}_e', train_tgt_lang, train_split, False,
                                         keep_ratio=g.keep_ratio)
        dev_setting = create_setting(f'dev@{train_tgt_lang}', train_tgt_lang, dev_split, False)
        test_setting = create_setting(f'test@{train_tgt_lang}', train_tgt_lang, Split('test'), False)
        settings.extend([train_setting, train_e_setting, dev_setting, test_setting])

    for setting in settings:
        self.dl_reg.register_data_loader(setting, self.cog_reg, lang2id=lang2id)

    phono_feat_mat = special_ids = None
    if g.use_phono_features:
        phono_feat_mat = get_tensor(self.src_abc.pfm)
        special_ids = get_tensor(self.src_abc.special_ids)

    self.model = OneToManyModel(len(self.src_abc), len(self.tgt_abc),
                                len(g.train_tgt_langs) + 1, lang2id[g.tgt_lang],
                                lang2id=lang2id,
                                phono_feat_mat=phono_feat_mat,
                                special_ids=special_ids)
    if g.saved_model_path is not None:
        self.model.load_state_dict(torch.load(g.saved_model_path, map_location=torch.device('cpu')))
        logging.imp(f'Loaded from {g.saved_model_path}.')
    if has_gpus():
        self.model.cuda()
    logging.info(self.model)

    metric_writer = MetricWriter(g.log_dir, flush_secs=5)

    # NOTE(j_luo) Evaluate on every loader that is not for training.
    eval_dls = self.dl_reg.get_loaders_by_name(lambda name: 'train' not in name or '_e' in name)
    self.evaluator = Evaluator(self.model, eval_dls, self.tgt_abc, metric_writer=metric_writer)

    if not g.evaluate_only:
        train_names = [f'train@{train_tgt_lang}' for train_tgt_lang in g.train_tgt_langs]
        train_settings = [self.dl_reg.get_setting_by_name(name) for name in train_names]
        self.trainer = Trainer(self.model, train_settings, [1.0] * len(train_settings), 'step',
                               stage_tnames=['step'],
                               evaluator=self.evaluator,
                               check_interval=g.check_interval,
                               eval_interval=g.eval_interval,
                               save_interval=g.save_interval,
                               metric_writer=metric_writer)
        if g.saved_model_path is None:
            # self.trainer.init_params('uniform', -0.1, 0.1)
            self.trainer.init_params('xavier_uniform')
        optim_cls = Adam if g.optim_cls == 'adam' else SGD
        self.trainer.set_optimizer(optim_cls, lr=g.learning_rate)
def run(self):
    metric_writer = MetricWriter(g.log_dir, flush_secs=5)

    def get_trainer(model, train_name, evaluator, metric_writer, **kwargs):
        if g.use_rl:
            # if g.use_mcts:
            trainer_cls = MctsTrainer
            # else:
            #     trainer_cls = PolicyGradientTrainer
        else:
            trainer_cls = Trainer
        trainer = trainer_cls(model,
                              [self.dl_reg.get_setting_by_name(train_name)],
                              [1.0], 'step',
                              stage_tnames=['step'],
                              check_interval=g.check_interval,
                              evaluator=evaluator,
                              eval_interval=g.eval_interval,
                              save_interval=g.save_interval,
                              metric_writer=metric_writer,
                              **kwargs)
        if g.saved_model_path is None:
            # trainer.init_params('uniform', -0.1, 0.1)
            trainer.init_params('xavier_uniform')
        optim_cls = Adam if g.optim_cls == 'adam' else SGD
        optim_kwargs = dict()
        if optim_cls == SGD:
            optim_kwargs['momentum'] = 0.9
        if not g.use_rl or g.use_mcts or (g.agent == 'a2c' and g.value_steps == 0):
            trainer.set_optimizer(optim_cls, lr=g.learning_rate,
                                  weight_decay=g.weight_decay, **optim_kwargs)
        else:
            trainer.set_optimizer(optim_cls, name='policy', mod=model.policy_net,
                                  lr=g.learning_rate, **optim_kwargs)  # , weight_decay=1e-4)
            if g.agent == 'a2c':
                trainer.set_optimizer(optim_cls, name='value', mod=model.value_net,
                                      lr=g.value_learning_rate, **optim_kwargs)  # , weight_decay=1e-4)
        return trainer

    def run_once(train_name, dev_name, test_name):
        train_e_dl = self.dl_reg[f'{train_name}_e']
        dev_dl = self.dl_reg[dev_name]
        test_dl = self.dl_reg[test_name]
        model = self._get_model()
        evaluator = Evaluator(model,
                              {train_name: train_e_dl, dev_name: dev_dl, test_name: test_dl},
                              self.tgt_abc,
                              metric_writer=metric_writer)
        if g.evaluate_only:
            # TODO(j_luo) Load global_step from the saved model.
            evaluator.evaluate('evaluate_only', 0)
        else:
            trainer = get_trainer(model, train_name, evaluator, metric_writer)
            trainer.train(self.dl_reg)

    if g.use_rl:
        # mcts.set_logging_options(g.mcts_verbose_level, g.mcts_log_to_file)
        evaluator = MctsEvaluator(self.mcts, metric_writer=metric_writer)
        if g.evaluate_only:
            evaluator.evaluate('evaluate_only', 0)
            return
        else:
            trainer = get_trainer(self.model, 'rl', evaluator, metric_writer, mcts=self.mcts)
            # else:
            #     collector = TrajectoryCollector(g.batch_size,
            #                                     max_rollout_length=g.max_rollout_length,
            #                                     truncate_last=True)
            #     trainer = get_trainer(model, 'rl', None, None, env=self.env, collector=collector)
            trainer.train(self.dl_reg)
    elif g.input_format == 'wikt':
        run_once('train', 'dev', 'test')
    else:
        for fold in range(5):
            logging.imp(f'Cross-validation, fold number {fold}.')
            run_once(f'train@{fold}', f'dev@{fold}', 'test')
def threshold(self, value):
    logging.imp(f'Setting threshold to {value}.')
def load(self, path: Path):
    saved = torch.load(path)
    smsd = saved['model']
    self.model.load_state_dict(smsd)
    logging.imp(f'Loaded model from {path}.')
def __init__(self, lang: str, contents: List[List[str]], sources: Union[str, List[str]],
             dist_mat: Optional[NDA] = None,
             edges: Optional[List[Tuple[str, str]]] = None,
             cl_map: Optional[Dict[str, str]] = None,
             gb_map: Optional[Dict[str, str]] = None):
    if sources is not None:
        if isinstance(sources, str):
            sources = [sources] * len(contents)
        else:
            assert len(contents) == len(sources)
    else:
        sources = ['unknown'] * len(contents)

    cnt = defaultdict(Counter)
    for content, source in zip(contents, sources):
        for u in content:
            cnt[u][source] += 1

    # Merge symbols with identical phonological features if needed.
    if not g.use_mcts and not g.use_duplicate_phono and g.use_phono_features:
        t2u = defaultdict(list)  # tuple-to-units
        for u in cnt:
            t = tuple(self.get_pfv(u).numpy())
            t2u[t].append(u)
        u2u = dict()  # unit-to-unit. This finds the standardized unit.
        for units in t2u.values():
            lengths = [len(u) for u in units]
            min_i = lengths.index(min(lengths))
            std_u = units[min_i]
            for u in units:
                u2u[u] = std_u
        merged_cnt = defaultdict(Counter)
        for u, std_u in u2u.items():
            merged_cnt[std_u].update(cnt[u])
        logging.imp(
            f'Symbols are merged based on phonological features: from {len(cnt)} to {len(merged_cnt)}.')
        cnt = merged_cnt
        self._u2u = u2u

    units = sorted(cnt.keys())
    base_n = len(units)
    # Expand the vowel set by adding stress variants.
    processor = FeatureProcessor()
    for u in list(units):
        seg = processor.process(u)
        if isinstance(seg, Nphthong) or seg.is_vowel():
            units.append(u + '{+}')
            units.append(u + '{-}')

    self.special_units = [SOT, EOT, PAD, ANY, EMP, SYL_EOT, ANY_S, ANY_UNS, NULL]
    self.special_ids = [SOT_ID, EOT_ID, PAD_ID, ANY_ID, EMP_ID, SYL_EOT_ID, ANY_S_ID, ANY_UNS_ID, NULL_ID]
    special_n = len(self.special_ids)
    self._id2unit = self.special_units + units
    self._unit2id = dict(zip(self.special_units, self.special_ids))
    self._unit2id.update({c: i for i, c in enumerate(units, len(self.special_units))})

    # Get vowel info.
    n = len(self._id2unit)
    self.is_vowel = np.zeros(n, dtype=bool)
    self.is_consonant = np.zeros(n, dtype=bool)
    self.unit_stress = np.zeros(n, dtype='int32')
    self.unit2base = np.arange(n, dtype='uint16')
    self.unit2stressed = np.arange(n, dtype='uint16')
    self.unit2unstressed = np.arange(n, dtype='uint16')
    self.unit_stress.fill(mcts_cpp.PyNoStress)
    self.unit_stress[ANY_S_ID] = mcts_cpp.PyStressed
    self.unit_stress[ANY_UNS_ID] = mcts_cpp.PyUnstressed
    self.unit2base[ANY_S_ID] = ANY_ID
    self.unit2base[ANY_UNS_ID] = ANY_ID
    self.unit2stressed[ANY_ID] = ANY_S_ID
    self.unit2unstressed[ANY_ID] = ANY_UNS_ID
    for u in self._id2unit:
        if u.endswith('{+}') or u.endswith('{-}'):
            base = u[:-3]
            base_id = self._unit2id[base]
            i = self._unit2id[u]
            self.is_vowel[base_id] = True
            self.is_vowel[i] = True
            self.unit2base[i] = base_id
            self.unit_stress[i] = mcts_cpp.PyStressed if u[-2] == '+' else mcts_cpp.PyUnstressed
            if u.endswith('{+}'):
                self.unit2stressed[base_id] = i
            else:
                self.unit2unstressed[base_id] = i
    for u in units:
        if not u.endswith('{+}') and not u.endswith('{-}'):
            self.is_consonant[self._unit2id[u]] = True

    self.stats: pd.DataFrame = pd.DataFrame.from_dict(cnt)

    self.dist_mat = self.edges = self.cl_map = self.gb_map = None
    if dist_mat is not None:
        # Pad dist_mat for the special units.
        self.dist_mat = np.full([len(self), len(self)], 99999, dtype='float32')
        # NOTE(j_luo) Special ids should have zero cost if matched.
        for i in range(special_n):
            self.dist_mat[i, i] = 0
        # NOTE(j_luo) The new dist_mat should account for both the base units
        # and the expanded vowels with stress.
        orig_units = contents[0]
        orig_u2i = {u: i for i, u in enumerate(orig_units)}
        new_ids = np.asarray([self[u] for u in orig_units] + [self[u] for u in units[base_n:]])
        orig_ids = np.asarray(list(range(len(orig_units))) +
                              [orig_u2i[self[self.unit2base[self[u]]]] for u in units[base_n:]])
        self.dist_mat[new_ids.reshape(-1, 1), new_ids] = dist_mat[orig_ids.reshape(-1, 1), orig_ids]
        self.edges = edges
        self.cl_map = cl_map
        self.gb_map = gb_map

    logging.info(f'Alphabet for {lang}, size {len(self._id2unit)}: {self._id2unit}.')
    self.lang = lang
    assert len(self) < SENTINEL_ID
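# A small numpy sketch of the advanced-indexing trick used for dist_mat above:
# indexing with ids.reshape(-1, 1) against ids addresses the full submatrix over
# `ids`, so one assignment copies every pairwise distance between selected units.
import numpy as np

old = np.arange(9, dtype='float32').reshape(3, 3)  # original 3x3 distances.
new = np.full((5, 5), 99999, dtype='float32')      # padded matrix with 2 special ids.
new_ids = np.asarray([2, 3, 4])                    # positions in the padded matrix.
orig_ids = np.asarray([0, 1, 2])                   # positions in the original matrix.
new[new_ids.reshape(-1, 1), new_ids] = old[orig_ids.reshape(-1, 1), orig_ids]
print(np.array_equal(new[2:, 2:], old))  # True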