class Encoder(nn.Module):

    add_argument('window_size', default=3, dtype=int, msg='window size for the cnn kernel')
    add_argument('dense_input', default=False, dtype=bool, msg='flag to densify input feature matrices')

    def __init__(self, dropout: float = 0.0, hidden_size: int = 10, dim: int = 10):
        super().__init__()
        self.dropout = dropout
        self.hidden_size = hidden_size
        self.dim = dim

        emb_cls = DenseFeatEmbedding if g.dense_input else FeatEmbedding
        self.feat_embedding = emb_cls('feat_emb', 'chosen_feat_group', 'char_emb', dim=dim)
        self.cat_dim = self.dim * self.feat_embedding.effective_num_feature_groups
        self._get_core_layers()

    def _get_core_layers(self):
        # IDEA(j_luo) should I define a Rename layer?
        self.conv_layers = nn.Sequential(
            nn.Conv1d(self.cat_dim, self.cat_dim, g.window_size, padding=g.window_size // 2))
        self.linear = nn.Linear(self.cat_dim, self.hidden_size)

    def forward(self, feat_matrix: LT, pos_to_predict: LT, source_padding: BT) -> FT:
        bs = source_padding.size('batch')
        l = source_padding.size('length')
        batch_i = get_range(bs, 1, 0)
        feat_emb = self.feat_embedding(feat_matrix, source_padding, masked_positions=pos_to_predict)
        feat_emb = feat_emb.align_to('batch', 'char_emb', 'length')
        output = self.conv_layers(feat_emb.rename(None))
        output = output.refine_names('batch', 'char_conv_repr', 'length')  # size: bs x D x l
        output = self.linear(output.align_to(..., 'char_conv_repr'))  # size: bs x l x n_hid
        output = output.refine_names('batch', 'length', 'hidden_repr')
        output = nn.functional.leaky_relu(output, negative_slope=0.1)
        # NOTE(j_luo) This is actually quite wasteful because we are discarding all the
        # irrelevant information, which is computed anyway. This is equivalent to
        # training on ngrams.
        with NoName(output, pos_to_predict):
            h = output[batch_i, pos_to_predict]
        h = h.refine_names('batch', 'hidden_repr')  # size: bs x n_hid
        return h
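
# --------------------------------------------------------------------------- #
# A minimal, self-contained sketch (plain tensors, made-up sizes; not the
# project's named-tensor API) of what `Encoder.forward` does: a same-length
# Conv1d over character embeddings, a linear projection, then picking out the
# hidden state at the single masked position of each sequence.
import torch
import torch.nn as nn

bs, length, dim, hidden, window = 4, 7, 12, 10, 3  # window matches the default window_size

conv = nn.Conv1d(dim, dim, window, padding=window // 2)  # padding keeps the length unchanged
linear = nn.Linear(dim, hidden)

emb = torch.randn(bs, dim, length)                 # batch x char_emb x length
out = conv(emb)                                    # still batch x dim x length
out = linear(out.transpose(1, 2))                  # batch x length x hidden
out = nn.functional.leaky_relu(out, negative_slope=0.1)

pos_to_predict = torch.tensor([0, 2, 5, 6])        # one masked position per sequence
h = out[torch.arange(bs), pos_to_predict]          # batch x hidden
assert h.shape == (bs, hidden)
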
class DecipherTrainer(BaseTrainer):

    add_argument('score_per_word', default=1.0, dtype=float, msg='score added for each word')
    add_argument('concentration', default=1e-2, dtype=float, msg='concentration hyperparameter')
    add_argument('supervised', dtype=bool, default=False, msg='supervised mode')
    # add_argument('mode', default='local-supervised', dtype=str,
    #              choices=['local-supervised', 'global-supervised'], msg='training mode')
    add_argument('mlm_coeff', dtype=float, default=0.05, msg='Coefficient for the mlm loss.')
    add_argument('warmup_updates', dtype=int, default=4000, msg='Number of warmup updates for Adam.')

    model: DecipherModel

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.analyzer = DecipherAnalyzer()
        self.set_optimizer()

    def set_optimizer(self):
        super().set_optimizer(AdamInverseSqrtWithWarmup,
                              lr=g.learning_rate,
                              betas=(0.9, 0.98),
                              warmup_updates=g.warmup_updates)

    def add_trackables(self):
        self.tracker.add_trackable('total_step', total=g.num_steps)
        self.tracker.add_max_trackable('best_f1')

    def train_one_step(self, dl: ContinuousTextDataLoader) -> Metrics:
        self.model.train()
        self.optimizer.zero_grad()
        batch = dl.get_next_batch()
        ret = self.model(batch)
        metrics = self.analyzer.analyze(ret, batch)
        metrics.total_loss.mean.backward()
        # NOTE Clip gradients before the update so that clipping actually takes effect.
        grad_norm = clip_grad_norm_(self.model.parameters(), 5.0)
        self.optimizer.step()
        weight = (~batch.source_padding).sum()
        metrics += Metric('grad_norm', grad_norm * weight, weight)
        return metrics.with_prefix_('decipher')

    def load(self, path: Path, load_lm_model: bool = False,
             load_optimizer_state: bool = False, load_phi_scorer: bool = False):
        saved = torch.load(path)
        smsd = saved['model']
        if not load_lm_model:
            smsd = {k: v for k, v in smsd.items() if not k.startswith('lm_model')}
        if not load_phi_scorer:
            smsd = {k: v for k, v in smsd.items() if not k.startswith('phi_scorer')}
        self.model.load_state_dict(smsd, strict=False)
        if load_optimizer_state:
            self.optimizer.load_state_dict(saved['optimizer'])
        logging.imp(f'Loaded model from {path}.')

    def save(self, eval_metrics: Metrics):
        self.save_to(g.log_dir / 'saved.latest')
        # self.tracker.update('best_loss', value=eval_metrics.dev_total_loss.mean)
        if self.tracker.update('best_f1', value=eval_metrics.dev_prf_f1.value):
            out_path = g.log_dir / 'saved.best'
            logging.imp(f'Best model updated: new best is {self.tracker.best_f1:.3f}')
            self.save_to(out_path)
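
# --------------------------------------------------------------------------- #
# `AdamInverseSqrtWithWarmup` is defined elsewhere; as a hedge, the sketch
# below shows the usual (fairseq-style) inverse-sqrt schedule such an optimizer
# presumably follows: linear warmup to the peak learning rate over
# `warmup_updates` steps, then decay proportional to 1/sqrt(step).
def inverse_sqrt_lr(step: int, peak_lr: float = 2e-3, warmup_updates: int = 4000) -> float:
    if step < warmup_updates:
        return peak_lr * step / warmup_updates  # linear warmup
    return peak_lr * (warmup_updates / step) ** 0.5  # inverse-sqrt decay

assert abs(inverse_sqrt_lr(4000) - 2e-3) < 1e-12   # peak is reached at the end of warmup
assert abs(inverse_sqrt_lr(16000) - 1e-3) < 1e-12  # halved after 4x as many steps
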
class BaseTrainer(BaseTrainerDev, metaclass=ABCMeta):  # pylint: disable=abstract-method

    add_argument('num_steps', default=10, dtype=int, msg='number of steps to train')
    add_argument('learning_rate', default=2e-3, dtype=float, msg='learning rate')
    add_argument('check_interval', default=2, dtype=int, msg='check metrics after this many steps')
    add_argument('eval_interval', default=500, dtype=int, msg='evaluate models after this many steps')
    add_argument('save_interval', default=0, dtype=int, msg='save models after this many steps')

    def save_to(self, path: Path):
        to_save = {
            'model': self.model.state_dict(),
            'g': g.state_dict(),
            'optimizer': self.optimizer.state_dict(),
        }
        torch.save(to_save, path)
        logging.imp(f'Model saved to {path}.')
class ExtractManager(BaseManager):

    # IDEA(j_luo) when to put this in manager/trainer? what about scheduler?
    # annealing? restarting? Probably all in trainer -- you need to track them with pbars.
    add_argument('optim_cls', default='adam', dtype=str, choices=['adam', 'adagrad', 'sgd'], msg='Optimizer class.')
    add_argument('anneal_factor', default=0.5, dtype=float, msg='Multiplicative factor for annealing.')
    add_argument('min_threshold', default=0.01, dtype=float, msg='Min value for the threshold.')

    _name2cls = {'adam': Adam, 'adagrad': Adagrad, 'sgd': SGD}

    def __init__(self):
        task = ExtractTask()
        self.dl_reg = DataLoaderRegistry()
        self.dl_reg.register_data_loader(task, g.data_path)
        lu_size = None
        if g.input_format == 'text':
            lu_size = self.dl_reg[task].dataset.unit_vocab_size
        self.model = ExtractModel(lu_size=lu_size)
        if has_gpus():
            self.model.cuda()
        self.evaluator = ExtractEvaluator(self.model, self.dl_reg[task])
        self.trainer = ExtractTrainer(self.model, [task], [1.0], 'total_step',
                                      stage_tnames=['round', 'total_step'],
                                      evaluator=self.evaluator,
                                      check_interval=g.check_interval,
                                      eval_interval=g.eval_interval,
                                      save_interval=g.save_interval)
        if g.saved_model_path:
            self.trainer.load(g.saved_model_path)
            # self.trainer.set_optimizer(Adam, lr=g.learning_rate)

    def run(self):
        optim_cls = self._name2cls[g.optim_cls]
        self.trainer.threshold = g.init_threshold
        self.trainer.set_optimizer(optim_cls, lr=g.learning_rate)  # , momentum=0.9, nesterov=True)

        # Save init parameters.
        out_path = g.log_dir / 'saved.init'
        self.trainer.save_to(out_path)

        while self.trainer.threshold > g.min_threshold:
            self.trainer.reset()
            self.trainer.set_optimizer(optim_cls, lr=g.learning_rate)
            self.trainer.train(self.dl_reg)
            self.trainer.tracker.update('round')
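
# --------------------------------------------------------------------------- #
# The annealing step itself lives in the trainer (not shown here). Assuming the
# obvious reading of `anneal_factor` and `min_threshold`, the outer loop in
# `ExtractManager.run` walks the threshold down a multiplicative schedule:
threshold = 0.99        # hypothetical g.init_threshold
rounds = 0
while threshold > 0.01:         # g.min_threshold default
    rounds += 1                 # one trainer round per threshold value
    threshold *= 0.5            # g.anneal_factor default
assert rounds == 7              # 0.99, 0.495, ..., 0.0155; stops once <= 0.01
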
class DecipherManager(BaseManager):

    add_argument('dev_data_path', dtype='path', msg='Path to dev data.')
    add_argument('aux_train_data_path', dtype='path', msg='Path to aux train data.')
    add_argument('in_domain_dev_data_path', dtype='path', msg='Path to in-domain dev data.')
    add_argument('saved_path', dtype='path')
    add_argument('saved_model_path', dtype='path', msg='Path to a saved model, skipping the local training phase.')
    add_argument('train_phi', dtype=bool, default=False,
                 msg='Flag to train the phi scorer. Used only with supervised mode.')
    add_argument('fix_phi', dtype=bool, default=False, msg='Flag to fix the phi scorer.')
    # add_argument('use_mlm_loss', dtype=bool, default=False, msg='Flag to use mlm loss.')

    def __init__(self):
        self.model = DecipherModel()
        if has_gpus():
            self.model.cuda()

        train_task = DecipherTask('train')
        dev_task = DecipherTask('dev')
        self.dl_reg = DataLoaderRegistry()
        eval_tasks = [train_task, dev_task]
        if g.in_domain_dev_data_path:
            in_domain_dev_task = DecipherTask('in_domain_dev')
            self.dl_reg.register_data_loader(in_domain_dev_task, g.in_domain_dev_data_path)
            eval_tasks.append(in_domain_dev_task)
        train_tasks = [train_task]
        if g.aux_train_data_path:
            aux_train_task = DecipherTask('aux_train')
            self.dl_reg.register_data_loader(aux_train_task, g.aux_train_data_path)
            train_tasks.append(aux_train_task)

        self.dl_reg.register_data_loader(train_task, g.data_path)
        self.dl_reg.register_data_loader(dev_task, g.dev_data_path)
        self.evaluator = DecipherEvaluator(self.model, self.dl_reg, eval_tasks)
        self.trainer = DecipherTrainer(self.model, train_tasks, [1.0] * len(train_tasks), 'total_step',
                                       evaluator=self.evaluator,
                                       check_interval=g.check_interval,
                                       eval_interval=g.eval_interval)
        if g.train_phi:
            freeze(self.model.self_attn_layers)
            freeze(self.model.positional_embedding)
            freeze(self.model.emb_for_label)
            freeze(self.model.label_predictor)
        if g.saved_model_path:
            self.trainer.load(g.saved_model_path, load_phi_scorer=True)
            if g.fix_phi:
                freeze(self.model.phi_scorer)
            # freeze(self.model.self_attn_layers)
            # freeze(self.model.positional_embedding)
            # freeze(self.model.emb_for_label)
            # freeze(self.model.label_predictor)
        self.trainer.set_optimizer()

    def run(self):
        self.trainer.train(self.dl_reg)
class LMTrainer(BaseTrainer):

    add_argument('feat_groups', default='pcvdst', dtype=str,
                 msg='what to include during training: p(type), c(onsonant), v(owel), d(iacritics), s(tress) and t(one).')

    model: LM
    analyzer_cls: ClassVar = LMAnalyzer

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # IDEA(j_luo) Preparing the trainer should be handled by the manager, not by __init__ call.
        logging.warning('Init model.')
        for p in get_trainable_params(self.model, named=False):
            if p.ndim == 2:
                torch.nn.init.xavier_uniform_(p)
        self.set_optimizer(optim.Adam, lr=g.learning_rate)
        self.analyzer = self.analyzer_cls()

    def add_trackables(self):
        self.tracker.add_trackable('total_step', total=g.num_steps)
        self.tracker.add_min_trackable('best_loss')

    def train_one_step(self, dl: IpaDataLoader) -> Metrics:
        self.model.train()
        self.optimizer.zero_grad()
        batch = dl.get_next_batch()
        ret = self.model.score(batch)
        # for idx, segment in enumerate(batch.segments):
        #     if str(segment).startswith('e-s-t-a-n'):
        #         break
        # from xib.ipa import Name
        # name = Name('Ptype', 'camel')
        # print(torch.stack([ret.distr[name][0], ret.distr_noise[name][0]], new_name='tmp')[idx])
        # import time; time.sleep(1)
        metrics = self.analyzer.analyze(ret)
        metrics.loss.mean.backward()
        grad_norm = get_grad_norm(self.model)
        grad_norm = Metric('grad_norm', grad_norm * len(batch), len(batch))
        metrics += grad_norm
        self.optimizer.step()
        return metrics

    def save(self, eval_metrics: Metrics):
        new_value = eval_metrics.loss.mean
        self.save_to(g.log_dir / 'saved.latest')
        if self.tracker.update('best_loss', value=new_value):
            out_path = g.log_dir / 'saved.best'
            logging.imp(f'Best model updated: new best is {self.tracker.best_loss:.3f}')
            self.save_to(out_path)
class BaseIpaDataLoader(BaseDataLoader, metaclass=ABCMeta):

    add_argument('data_path', dtype='path', msg='path to the feat data in tsv format.')
    add_argument('num_workers', default=0, dtype=int, msg='number of workers for the data loader')
    add_argument('char_per_batch', default=500, dtype=int, msg='batch size in characters per batch')
    add_argument('new_style', default=False, dtype=bool, msg='flag to use new style ipa annotations')

    dataset_cls: Type[Dataset]

    def __init__(self, data_path: Path, task: Task):
        dataset = type(self).dataset_cls(data_path)
        batch_sampler = BatchSampler(dataset, shuffle=True)
        super().__init__(dataset, task, batch_sampler=batch_sampler,
                         pin_memory=True, num_workers=g.num_workers, collate_fn=collate_fn)

    @abstractmethod
    def _prepare_batch(self, collate_return: CollateReturn) -> BaseBatch:
        pass

    def __iter__(self):
        for collate_return in super().__iter__():
            batch = self._prepare_batch(collate_return)
            yield batch.cuda()
    def _prepare_batch(self, collate_return: CollateReturn) -> IpaBatch:
        batch_cls = type(self).batch_cls
        return batch_cls(collate_return.segments, collate_return.lengths, collate_return.matrices)


class DenseIpaDataLoader(IpaDataLoader):
    batch_cls = DenseIpaBatch

    add_argument('max_segment_length', default=10, dtype=int,
                 msg='Max length for segments. Longer ones will be broken down into moving windows.')
    add_argument('broken_words', default=False, dtype=bool, msg='Flag to break words down.')


class UnbrokenIpaDataset(IpaDataset):
    cache_suffix = 'unbroken.cache'

    def load_data(self, data_path: Path):
        segment_dict = self._get_segment_dict(data_path)
class LM(nn.Module):

    add_argument('weighted_loss', default='', dtype=str, choices=['', 'mr', 'ot'],
                 msg='what type of weighted loss to use')
    add_argument('use_cbow_encoder', dtype=bool, default=True, msg='Flag to use cbow encoder.')

    def _get_encoder(self, dropout: float = None, hidden_size: int = None, dim: int = None):
        dropout = dropout or g.dropout
        hidden_size = hidden_size or g.hidden_size
        dim = dim or g.dim
        encoder_cls = CbowEncoder if g.use_cbow_encoder else Encoder
        return encoder_cls(dropout=dropout, hidden_size=hidden_size, dim=dim)

    def __init__(self):
        super().__init__()
        self.encoder = self._get_encoder()
        self.predictor = Predictor()
        # if weighted_loss and not new_style:
        #     raise ValueError('Must use new_style if using weighted loss')

    def forward(self, batch: IpaBatch) -> Dict[Cat, FT]:
        """First encode the `feat_matrix` into a vector `h`, then based on it
        predict the distributions of features."""
        h = self.encoder(batch.feat_matrix, batch.pos_to_predict, batch.source_padding)
        distr = self.predictor(h)
        return distr

    def score(self, batch) -> Dict[Cat, FT]:
        distr = self(batch)
        return self.score_distr(distr, batch)

    def score_distr(self, distr: Dict[Cat, FT], batch: IpaBatch) -> Dict[Cat, FT]:
        scores = dict()
        for name, output in distr.items():
            i = get_index(name, new_style=g.new_style)
            target = batch.target_feat[:, i]
            weight = batch.target_weight[:, i]

            if g.weighted_loss == '':
                # log_probs = gather(output, target)
                log_probs = output.gather(name.value, target)
                score = -log_probs
            else:
                e = get_new_style_enum(i)
                mat = get_tensor(e.get_distance_matrix())
                mat = mat[target.rename(None)]
                if g.weighted_loss == 'mr':
                    mat_exp = torch.where(mat > 0, (mat + 1e-8).log(), get_zeros(mat.shape).fill_(-99.9))
                    logits = mat_exp + output
                    # NOTE(j_luo) For the categories except Ptype, the sums of probs are
                    # not 1.0 (they are conditioned on certain values of Ptype). As a
                    # result, we need to incur penalties based on the remaining prob
                    # mass as well. Specifically, the remaining prob mass will result
                    # in a penalty of 1.0, which is e^(0.0).
                    none_probs = (1.0 - output.exp().sum(dim=-1, keepdims=True)).clamp(min=0.0)
                    none_penalty = (1e-8 + none_probs).log().align_as(output)
                    logits = torch.cat([logits, none_penalty], dim=-1)
                    score = torch.logsumexp(logits, dim=-1).exp()
                elif g.weighted_loss == 'ot':
                    if not self.training:
                        raise RuntimeError('Can only use OT during training.')
                    probs = output.exp()
                    # We have to incur penalties based on the remaining prob mass as well.
                    none_probs = (1.0 - probs.sum(dim=-1, keepdims=True)).clamp(min=0.0)
                    mat = torch.cat([mat, get_tensor(torch.ones_like(none_probs.rename(None)))], dim=-1)
                    probs = torch.cat([probs, none_probs], dim=-1)
                    score = (mat * probs).sum(dim=-1)
                else:
                    raise ValueError(f'Cannot recognize {g.weighted_loss}.')
            scores[name] = (score, weight)
        return scores

    @not_supported_argument_value('new_style', True)
    def predict(self, batch, k=-1) -> Dict[Cat, Tuple[FT, LT, np.ndarray]]:
        """Predict the top k results for each feature group.

        If k == -1, everything is sorted and returned; otherwise only the top k are taken.
        """
        ret = dict()
        distr = self(batch)
        for cat, log_probs in distr.items():
            e = get_enum_by_cat(cat)
            name = cat.name.lower()
            max_k = log_probs.size(name)
            this_k = max_k if k == -1 else min(max_k, k)
            top_values, top_indices = log_probs.topk(this_k, dim=-1)
            top_cats = np.asarray([e.get(i) for i in top_indices.view(-1).cpu().numpy()]).reshape(*top_indices.shape)
            ret[name] = (top_values, top_indices, top_cats)
        return ret
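
# --------------------------------------------------------------------------- #
# A toy check (plain tensors, no named dims) of the 'mr' branch above. The real
# code guards zero distances with torch.where; here the small epsilon suffices.
# logsumexp over log(distance) + log(prob), with the leftover probability mass
# appended at log-distance log(1.0) = 0, is the expected distance in which any
# mass outside the conditioned support pays a penalty of exactly 1.0.
import torch

probs = torch.tensor([0.5, 0.3, 0.1])           # sums to 0.9: 0.1 is the "none" mass
log_probs = probs.log()
dist = torch.tensor([0.0, 2.0, 4.0])            # distances to the target feature

logits = (dist + 1e-8).log() + log_probs        # log(d_i) + log(p_i)
none_prob = (1.0 - probs.sum()).clamp(min=0.0)
logits = torch.cat([logits, (none_prob + 1e-8).log().view(1)])
score = torch.logsumexp(logits, dim=-1).exp()

expected = (dist * probs).sum() + 1.0 * none_prob  # 0.6 + 0.4 + 0.1 = 1.1
assert torch.isclose(score, expected, atol=1e-4)
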
class AdaptLM(LM):

    add_argument('use_prior', dtype=bool, default=False, msg='Flag to use prior.')
    add_argument('prior_value', dtype=float, default=0.5, msg='Value for prior.')
    add_argument('use_moe', dtype=bool, default=False, msg='Flag to use MoE.')

    @not_supported_argument_value('new_style', True)
    def __init__(self):
        super().__init__()
        saved_dict = torch.load(g.lm_model_path)
        try:
            self.load_state_dict(saved_dict['model'])
        except RuntimeError as e:
            logging.error(str(e))
            # NOTE(j_luo) We have to map normal feature embeddings to dense feature embeddings.
            old_weights = saved_dict['model']['encoder.feat_embedding.embed_layer.weight']
            for cat in Category:
                try:
                    emb_param = self.encoder.feat_embedding.embed_layer[cat.name]
                    e = get_enum_by_cat(cat)
                    g_idx = [feat.value.g_idx for feat in e]
                    emb_param.data.copy_(old_weights[g_idx])
                except KeyError:
                    pass
        freeze(self.encoder)
        freeze(self.predictor)

        self.adapter = AdaptLayer()

        if g.use_prior or g.use_moe:
            noise_hs = 10
            noise_dim = 10
            self.noise_encoder = self._get_encoder(hidden_size=noise_hs, dim=noise_dim)
            self.noise_predictor = Predictor(hidden_size=noise_hs)
            if g.use_moe:
                self.moe_gate = nn.Linear(noise_hs + g.hidden_size, 2)

    def forward(self, batch: DenseIpaBatch) -> AdaptLMReturn:
        sfm_adapted = self.adapter(batch.dense_feat_matrix)
        h = self.encoder(sfm_adapted, batch.pos_to_predict, batch.source_padding)
        distr = self.predictor(h)
        if g.use_prior:
            if g.use_moe:
                h_noise = self.noise_encoder(batch.dense_feat_matrix, batch.pos_to_predict, batch.source_padding)
                distr_noise = self.noise_predictor(h_noise)
                gate_logits = self.moe_gate(torch.cat([h, h_noise], dim=-1))  # / self._temp
                gate_log_probs = gate_logits.log_softmax(dim=-1)
                return AdaptLMReturn(distr, gate_logits, distr_noise)
            else:
                h_noise = self.noise_encoder(batch.dense_feat_matrix, batch.pos_to_predict, batch.source_padding)
                distr_noise = self.noise_predictor(h_noise)
                for cat in distr:
                    d = distr[cat]
                    d_noise = distr_noise[cat]
                    stacked = torch.stack([d + math.log(g.prior_value),
                                           d_noise + math.log(1.0 - g.prior_value)],
                                          new_name='prior')
                    new_d = stacked.logsumexp(dim='prior')
                    distr[cat] = new_d
        return AdaptLMReturn(distr)

    def score(self, batch) -> AdaptLMReturn:
        ret = self(batch)
        return AdaptLMReturn(self.score_distr(ret.distr, batch),
                             ret.gate_logits,
                             self.score_distr(ret.distr_noise, batch))
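
# --------------------------------------------------------------------------- #
# A toy check (plain tensors) of the prior-mixing branch above: stacking the
# two log-distributions shifted by log(prior) and log(1 - prior) and taking
# logsumexp is the numerically stable form of the two-component mixture
# p(x) = prior * p_model(x) + (1 - prior) * p_noise(x).
import math
import torch

prior = 0.5
d = torch.tensor([0.7, 0.2, 0.1]).log()         # model log-probs
d_noise = torch.tensor([0.1, 0.1, 0.8]).log()   # noise log-probs

stacked = torch.stack([d + math.log(prior), d_noise + math.log(1.0 - prior)])
mixture = stacked.logsumexp(dim=0).exp()
assert torch.allclose(mixture, torch.tensor([0.4, 0.15, 0.45]))
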
from xib.data_loader import (ContinuousTextDataLoader, DataLoaderRegistry,
                             DenseIpaDataLoader, IpaDataLoader)
from xib.model.decipher_model import DecipherModel
from xib.model.extract_model import ExtractModel
from xib.model.lm_model import LM, AdaptLM
from xib.search.search_solver import SearchSolver
from xib.search.searcher import BruteForceSearcher
from xib.training.evaluator import (DecipherEvaluator, ExtractEvaluator,
                                    LMEvaluator, SearchSolverEvaluator)
from xib.training.task import (AdaptCbowTask, AdaptLMTask, CbowTask, DecipherTask,
                               ExtractTask, LMTask, MlmTask, TransferTask)
from xib.training.trainer import (AdaptLMTrainer, DecipherTrainer,
                                  ExtractTrainer, LMTrainer)

add_argument('task', default='lm', dtype=str,
             choices=['lm', 'cbow', 'adapt_lm', 'adapt_cbow', 'decipher', 'search', 'extract'],
             msg='which task to run')


class BaseManager(ABC):

    @abstractmethod
    def run(self):
        ...


class LMManager(BaseManager):

    model_cls = LM
    trainer_cls = LMTrainer
    task_cls = LMTask

    def __init__(self):
class DecipherModel(nn.Module):

    add_argument('adapt_mode', default='none', choices=['none'], dtype=str,
                 msg='how to adapt the features from one language to another')
    add_argument('num_self_attn_layers', default=2, dtype=int, msg='number of self attention layers')
    add_argument('num_samples', default=100, dtype=int, msg='number of samples per sequence')
    add_argument('num_heads', default=4, dtype=int, msg='Number of heads for self attention.')
    add_argument('lm_model_path', dtype='path', msg='path to a pretrained lm model')
    add_argument('dropout', default=0.0, dtype=float, msg='dropout rate')
    add_argument('sampling_temperature', default=1.0, dtype=float, msg='Sampling temperature')
    add_argument('vocab_path', dtype='path',
                 msg='Path to a vocabulary file which provides word-level features to the model.')
    add_argument('use_brute_force', dtype=bool, default=False, msg='Use the brute force searcher.')
    add_argument('n_times', dtype=int, default=5, msg='Number of neighbors.')

    @not_supported_argument_value('new_style', True)
    def __init__(self):
        super().__init__()
        self.lm_model = LM()
        saved_dict = torch.load(g.lm_model_path)
        self.lm_model.load_state_dict(saved_dict['model'])
        freeze(self.lm_model)

        # NOTE(j_luo) I'm keeping a separate embedding for label prediction.
        self.emb_for_label = FeatEmbedding('feat_emb_for_label', 'chosen_feat_group', 'char_emb')

        cat_dim = g.dim * self.emb_for_label.effective_num_feature_groups
        self.self_attn_layers = nn.ModuleList()
        for _ in range(g.num_self_attn_layers):
            self.self_attn_layers.append(TransformerLayer(cat_dim, g.num_heads, cat_dim, dropout=g.dropout))
        self.positional_embedding = PositionalEmbedding(512, cat_dim)

        self.label_predictor = nn.Sequential(
            nn.Linear(cat_dim, cat_dim),
            nn.LeakyReLU(negative_slope=0.1),
            nn.Linear(cat_dim, 3)  # BIO.
        )
        self.label_predictor[0].refine_names('weight', ['hid_repr', 'self_attn_repr'])
        self.label_predictor[2].refine_names('weight', ['label', 'hid_repr'])

        # Use vocab features if provided.
        self.vocab = None
        if g.vocab_path:
            with open(g.vocab_path, 'r', encoding='utf8') as fin:
                self.vocab = set(line.strip() for line in fin)

        searcher_cls = BruteForceSearcher if g.use_brute_force else BeamSearcher
        self.searcher = searcher_cls()

        self.phi_scorer = nn.Linear(5, 1)
        self.phi_scorer.refine_names('weight', ['score', 'feature'])

    def _adapt(self, packed_feat_matrix: LT) -> LT:
        if g.adapt_mode == 'none':
            return packed_feat_matrix
        else:
            raise NotImplementedError()

    def forward(self, batch: Union[ContinuousIpaBatch, IpaBatch]) -> DecipherModelReturn:
        # Get the samples of label sequences first.
        out = self.emb_for_label(batch.feat_matrix, batch.source_padding)

        positions = get_named_range(batch.feat_matrix.size('length'), name='length')
        pos_emb = self.positional_embedding(positions).align_as(out)
        out = out + pos_emb
        out = out.align_to('length', 'batch', 'char_emb')
        with NoName(out, batch.source_padding):
            for i, layer in enumerate(self.self_attn_layers):
                out = layer(out, src_key_padding_mask=batch.source_padding)
        state = out.refine_names('length', 'batch', ...)
        logits = self.label_predictor(state)
        label_log_probs = logits.log_softmax(dim='label')
        label_probs = label_log_probs.exp()

        # NOTE(j_luo) O is equivalent to None.
        mask = expand_as(batch.source_padding, label_probs)
        source = expand_as(get_tensor([0.0, 0.0, 1.0]).refine_names('label').float(), label_probs)
        label_probs = label_probs.rename(None).masked_scatter(mask.rename(None), source.rename(None))
        label_probs = label_probs.refine_names('length', 'batch', 'label')

        if not self.training or (g.supervised and not g.train_phi):
            probs = DecipherModelProbReturn(label_log_probs, None)
            return DecipherModelReturn(state, probs, None, None, None, None, None)

        # ------------------ More info during training ----------------- #

        # Get the lm score.
        gold_tag_seqs = batch.gold_tag_seqs if g.supervised and g.train_phi else None
        samples, sample_log_probs = self.searcher.search(batch.lengths, label_log_probs,
                                                         gold_tag_seqs=gold_tag_seqs)
        probs = DecipherModelProbReturn(label_log_probs, sample_log_probs)

        packed_words, scores = self._get_scores(samples, batch.segments, batch.lengths,
                                                batch.feat_matrix, batch.source_padding)

        if g.supervised and g.train_phi:
            return DecipherModelReturn(state, probs, packed_words, None, scores, None, None)

        # ------------------- Contrastive estimation ------------------- #

        ptb_segments = list()
        duplicates = list()
        for segment in batch.segments:
            _ptb_segments, _duplicates = segment.perturb_n_times(g.n_times)
            # NOTE(j_luo) Ignore the first one.
            ptb_segments.extend(_ptb_segments[1:])
            duplicates.extend(_duplicates[1:])
        # ptb_segments = [segment.perturb_n_times(5) for segment in batch.segments]
        ptb_feat_matrix = [segment.feat_matrix for segment in ptb_segments]
        ptb_feat_matrix = torch.nn.utils.rnn.pad_sequence(ptb_feat_matrix, batch_first=True)
        ptb_feat_matrix.rename_('batch', 'length', 'feat_group')
        samples = samples.align_to('batch', ...)
        with NoName(samples, batch.lengths, batch.source_padding):
            ptb_samples = torch.repeat_interleave(samples, g.n_times * 2, dim=0)
            ptb_lengths = torch.repeat_interleave(batch.lengths, g.n_times * 2, dim=0)
            ptb_source_padding = torch.repeat_interleave(batch.source_padding, g.n_times * 2, dim=0)
        ptb_samples.rename_(*samples.names)
        ptb_lengths.rename_('batch')
        ptb_source_padding.rename_('batch', 'length')

        ptb_packed_words, ptb_scores = self._get_scores(ptb_samples, ptb_segments, ptb_lengths,
                                                        ptb_feat_matrix, ptb_source_padding)

        ret = DecipherModelReturn(state, probs, packed_words, ptb_packed_words,
                                  scores, ptb_scores, duplicates)
        return ret

    def _get_scores(self, samples: LT, segments: Sequence[SegmentWindow], lengths: LT,
                    feat_matrix: LT, source_padding: BT) -> Tuple[PackedWords, DecipherModelScoreReturn]:
        bs = len(segments)

        segment_list = None
        if self.vocab is not None:
            segment_list = [segment.segment_list for segment in segments]
        packed_words = self.pack(samples, lengths, feat_matrix, segments, segment_list=segment_list)
        packed_words.word_feat_matrices = self._adapt(packed_words.word_feat_matrices)

        try:
            lm_batch = self._prepare_batch(packed_words)  # TODO(j_luo) This is actually continuous batching.
            scores = self._get_lm_scores(lm_batch)
            nlls = list()
            for cat, (nll, weight) in scores.items():
                if should_include(g.feat_groups, cat):
                    nlls.append(nll * weight)
            # nlls = sum(nlls)
            nlls = sum(nlls) / lm_batch.lengths
            bw = packed_words.word_lengths.size('batch_word')
            p = packed_words.word_positions.size('position')
            nlls = nlls.unflatten('batch', [('batch_word', bw), ('position', p)])
            nlls = nlls.sum(dim='position')
            lm_score, in_vocab_score = self._unpack(nlls, packed_words, bs)
        except EmptyPackedWords:
            lm_score = get_zeros(bs, packed_words.num_samples)
            in_vocab_score = get_zeros(bs, packed_words.num_samples)

        word_score = self._get_word_score(packed_words, bs)
        readable_score, unreadable_score = self._get_readable_scores(source_padding, samples)

        scores = [lm_score, word_score, in_vocab_score, readable_score, unreadable_score]
        features = torch.stack(scores, new_name='feature')
        phi_score = self.phi_scorer(features).squeeze('score')

        # if g.search:
        #     samples = samples.align_to('length', 'batch', 'sample')
        #     flat_samples = samples.flatten(['batch', 'sample'], 'batch_X_sample')
        #     flat_sample_embeddings = self.tag_embedding(flat_samples)
        #     bxs = flat_samples.size('batch_X_sample')
        #     h0 = get_zeros([1, bxs, 100])
        #     c0 = get_zeros([1, bxs, 100])
        #     with NoName(flat_sample_embeddings):
        #         output, (hn, _) = self.tag_lstm(flat_sample_embeddings, (h0, c0))
        #     tag_score = self.tag_scorer(hn).squeeze(dim=0).squeeze(dim=-1)
        #     tag_score = tag_score.view(samples.size('batch'), samples.size('sample'))
        #     ret['tag_score'] = tag_score.rename('batch', 'sample')

        scores = DecipherModelScoreReturn(lm_score, word_score, in_vocab_score,
                                          readable_score, unreadable_score, phi_score)
        return packed_words, scores

    @deprecated
    def _get_word_score(self, packed_words: PackedWords, batch_size: int) -> FT:
        with torch.no_grad():
            num_words = get_zeros(batch_size * packed_words.num_samples)
            bi = packed_words.batch_indices
            si = packed_words.sample_indices
            idx = (bi * packed_words.num_samples + si).rename(None)
            inc = get_zeros(packed_words.batch_indices.size('batch_word')).fill_(1.0)
            # TODO(j_luo) add scatter_add_ to named_tensor module
            num_words.scatter_add_(0, idx, inc)
            num_words = num_words.view(batch_size, packed_words.num_samples).refine_names('batch', 'sample')
        return num_words

    def _get_lm_scores(self, lm_batch: IpaBatch) -> Dict[Category, FT]:
        max_size = min(100000, lm_batch.batch_size)
        with torch.no_grad():
            batches = lm_batch.split(max_size)
            all_scores = [self.lm_model.score(batch) for batch in batches]
            cats = all_scores[0].keys()
            all_scores = {cat: list(zip(*[scores[cat] for scores in all_scores])) for cat in cats}
            for cat in cats:
                scores, weights = all_scores[cat]
                scores = torch.cat(scores, names='batch', new_name='batch')
                weights = torch.cat(weights, names='batch', new_name='batch')
                all_scores[cat] = (scores, weights)
        return all_scores

    @staticmethod
    def _prepare_batch(packed_words: PackedWords) -> IpaBatch:
        # TODO(j_luo) ugly
        try:
            return IpaBatch(None,
                            packed_words.word_lengths.rename(None),
                            packed_words.word_feat_matrices.rename(None),
                            batch_name='batch',
                            length_name='length').cuda()
        except RuntimeError:
            raise EmptyPackedWords()

    def pack(self, samples: LT, lengths: LT, feat_matrix: LT, segments: np.ndarray,
             segment_list: Optional[List[List[str]]] = None) -> PackedWords:
        with torch.no_grad():
            feat_matrix = feat_matrix.align_to('batch', 'length', 'feat_group')
            samples = samples.align_to('batch', 'sample', 'length').int()
            ns = samples.size('sample')
            lengths = lengths.align_to('batch', 'sample').expand(-1, ns).int()
            batch_indices, sample_indices, word_positions, word_lengths, is_unique = extract_words(
                samples.cpu().numpy(), lengths.cpu().numpy(), num_threads=4)

            in_vocab = np.zeros_like(batch_indices, dtype=bool)
            if self.vocab is not None:
                in_vocab = check_in_vocab(batch_indices, word_positions, word_lengths,
                                          segment_list, self.vocab, num_threads=4)
                in_vocab = get_tensor(in_vocab).refine_names('batch_word').bool()

            batch_indices = get_tensor(batch_indices).refine_names('batch_word').long()
            sample_indices = get_tensor(sample_indices).refine_names('batch_word').long()
            word_positions = get_tensor(word_positions).refine_names('batch_word', 'position').long()
            word_lengths = get_tensor(word_lengths).refine_names('batch_word').long()
            is_unique = get_tensor(is_unique).refine_names('batch', 'sample').bool()

            key = (batch_indices.align_as(word_positions).rename(None), word_positions.rename(None))
            word_feat_matrices = feat_matrix.rename(None)[key]
            word_feat_matrices = word_feat_matrices.refine_names('batch_word', 'position', 'feat_group')

            packed_words = PackedWords(word_feat_matrices,
                                       word_lengths,
                                       batch_indices,
                                       sample_indices,
                                       word_positions,
                                       is_unique,
                                       ns,
                                       segments,
                                       in_vocab=in_vocab)
            return packed_words

    def _unpack(self, nlls: FT, packed_words: PackedWords, batch_size: int) -> Tuple[FT, FT]:
        with torch.no_grad():
            lm_loss = get_zeros(batch_size * packed_words.num_samples)
            bi = packed_words.batch_indices
            si = packed_words.sample_indices
            idx = (bi * packed_words.num_samples + si).rename(None)
            # TODO(j_luo) ugly
            lm_loss.scatter_add_(0, idx, nlls.rename(None))
            lm_loss = lm_loss.view(batch_size, packed_words.num_samples).refine_names('batch', 'sample')

            in_vocab_score = get_zeros(batch_size * packed_words.num_samples)
            if self.vocab is not None:
                in_vocab_score.scatter_add_(0, idx, packed_words.in_vocab.float().rename(None))
            in_vocab_score = in_vocab_score.view(batch_size,
                                                 packed_words.num_samples).refine_names('batch', 'sample')

        return -lm_loss, in_vocab_score  # NOTE(j_luo) NLLs are losses, not scores.

    @deprecated
    def _sample(self, label_probs: FT, sampling_probs: FT, source_padding: FT,
                gold_tag_seqs: Optional[FT] = None) -> Tuple[LT, FT]:
        """Return samples based on `label_probs`."""
        # Ignore padded indices.
        label_probs = label_probs.align_to('batch', 'length', 'label')
        sampling_probs = sampling_probs.align_to('batch', 'length', 'label')
        source_padding = source_padding.align_to('batch', 'length')

        # Get packed batches.
        label_distr = Categorical(probs=sampling_probs.rename(None))
        label_samples = label_distr.sample([g.num_samples]).refine_names('sample', 'batch', 'length')
        label_samples = label_samples.align_to('batch', 'sample', 'length')
        # Add the ground truth if needed.
        if gold_tag_seqs is not None:
            gold_tag_seqs = gold_tag_seqs.align_as(label_samples)
            all_other_tag_seqs = torch.full_like(gold_tag_seqs, O)
            label_samples = torch.cat([gold_tag_seqs, all_other_tag_seqs, label_samples], dim='sample')
        batch_idx = get_named_range(label_samples.size('batch'), 'batch').align_as(label_samples).rename(None)
        length_idx = get_named_range(label_samples.size('length'), 'length').align_as(label_samples).rename(None)
        label_sample_probs = label_probs.rename(None)[batch_idx, length_idx, label_samples.rename(None)]
        label_sample_probs = label_sample_probs.refine_names(*label_samples.names)
        label_sample_log_probs = (1e-8 + label_sample_probs).log()
        label_sample_log_probs = ((~source_padding).align_as(label_sample_log_probs).float()
                                  * label_sample_log_probs).sum(dim='length')
        return label_samples, label_sample_log_probs

    def _get_readable_scores(self, source_padding: BT, samples: LT) -> Tuple[FT, FT]:
        samples = samples.align_to('batch', 'sample', 'length')
        source_padding = source_padding.align_as(samples)
        is_part_of_word = ((samples == B) | (samples == I)) & ~source_padding
        not_part_of_word = (samples == O) & ~source_padding
        readable_score = is_part_of_word.float().sum(dim='length')
        unreadable_score = not_part_of_word.float().sum(dim='length')
        return readable_score, unreadable_score
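
# --------------------------------------------------------------------------- #
# A toy check of `_get_readable_scores`, assuming the conventional label ids
# B=0, I=1, O=2 (the real B/I/O constants are imported elsewhere in the
# project): readable counts B/I positions, unreadable counts O positions, and
# padded positions contribute to neither.
import torch

B, I, O = 0, 1, 2
samples = torch.tensor([[B, I, O, O, B]])
padding = torch.tensor([[False, False, False, False, True]])  # last position is padding

is_part_of_word = ((samples == B) | (samples == I)) & ~padding
not_part_of_word = (samples == O) & ~padding
assert is_part_of_word.float().sum().item() == 2.0   # B and I; the final B is padded out
assert not_part_of_word.float().sum().item() == 2.0  # the two O's
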
add_argument('accum_gradients', default=1, dtype=int, msg='Accumulate this many steps of gradients.')


class ExtractTrainer(BaseTrainer):

    model: ExtractModel

    add_argument('reg_hyper', default=1.0, dtype=float, msg='Hyperparameter for alignment regularization.')
    add_argument('save_alignment', default=False, dtype=bool, msg='Flag to save alignment every step.')

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.analyzer = ExtractAnalyzer()
        self.ins_del_cost = g.init_ins_del_cost
        if g.save_alignment:
            self.add_callback('total_step', 1, self.save_alignment)

    def save_alignment(self):
        to_save = {
            'unit_aligner': self.model.g2p.unit_aligner.state_dict(),
        }
        path = g.log_dir / f'saved.{self.stage}.almt'
        torch.save(to_save, path)
        logging.imp(f'Alignment saved to {path}.')

    @global_property
    def ins_del_cost(self):
        pass

    @ins_del_cost.setter
    def ins_del_cost(self, value):
        logging.imp(f'Setting ins_del_cost to {value}.')

    @global_property
    def threshold(self):
        pass

    @threshold.setter
    def threshold(self, value):
        logging.imp(f'Setting threshold to {value}.')

    def add_trackables(self):
        self.tracker.add_trackable('round', endless=True)
        self.tracker.add_trackable('total_step', total=g.num_steps)
        self.tracker.add_max_trackable('best_f1')

    def reset(self):
        """Reset the tracker, but keep `best_f1` since it is tied to evaluation."""
        self.tracker.reset('total_step')

    def load(self, path: Path):
        saved = torch.load(path)
        smsd = saved['model']
        self.model.load_state_dict(smsd)
        logging.imp(f'Loaded model from {path}.')

    def save(self, eval_metrics: Metrics):
        self.save_to(g.log_dir / f'saved.{self.stage}.latest')
        if eval_metrics is not None:
            if self.tracker.update('best_f1', value=eval_metrics.prf_exact_span_f1.value):
                out_path = g.log_dir / f'saved.{self.stage}.best'
                logging.warning('Do NOT use this number since this f1 is compared against ground truths.')
                logging.imp(f'Best model updated: new best is {self.tracker.best_f1:.3f}')
                self.save_to(out_path)

    def should_terminate(self):
        return self.tracker.is_finished('total_step')

    def train_one_step(self, dl: ContinuousTextDataLoader) -> Metrics:
        self.model.train()
        self.optimizer.zero_grad()
        accum_metrics = Metrics()
        for _ in pbar(range(g.accum_gradients), desc='accum_gradients'):
            batch = dl.get_next_batch()
            ret = self.model(batch)
            metrics = self.analyzer.analyze(ret, batch)
            loss = -metrics.ll.mean
            try:
                loss = loss + metrics.reg.mean * g.reg_hyper
            except AttributeError:
                pass
            loss_per_split = loss / g.accum_gradients
            loss_per_split.backward()
            accum_metrics += metrics
        grad_norm = clip_grad_norm_(self.model.parameters(), 5.0)
        self.optimizer.step()
        accum_metrics += Metric('grad_norm', grad_norm * batch.batch_size, batch.batch_size)
        return accum_metrics
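
# --------------------------------------------------------------------------- #
# A minimal sketch of the accumulation pattern in `train_one_step` above:
# dividing each split's loss by the number of splits before backward makes the
# accumulated gradient match a single backward pass over the mean loss.
import torch

w = torch.tensor(2.0, requires_grad=True)
splits = [torch.tensor(1.0), torch.tensor(3.0)]

for x in splits:                                   # grad += d(loss / n_splits)/dw
    ((w * x) ** 2 / len(splits)).backward()
accum_grad = w.grad.clone()

w.grad = None
(sum((w * x) ** 2 for x in splits) / len(splits)).backward()
assert torch.allclose(accum_grad, w.grad)          # identical gradients
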