Example #1
0
class DecipherTrainer(LMTrainer):

    add_argument('score_per_word', default=1.0, dtype=float, msg='score added for each word')
    add_argument('concentration', default=1e-2, dtype=float, msg='concentration hyperparameter')

    def __init__(self, model: 'a', train_data_loader: 'a', num_steps, learning_rate, check_interval, save_interval, log_dir, feat_groups, score_per_word: 'p', concentration: 'p'):
        super().__init__(model, train_data_loader, num_steps, learning_rate, check_interval, save_interval, log_dir, feat_groups)

    def train_loop(self) -> Metrics:
        self.model.train()
        self.optimizer.zero_grad()
        batch = next(self.iterator)
        ret = self.model(batch)
        bs = batch.feat_matrix.size('batch')
        modified_log_probs = ret['sample_log_probs'] * self.concentration + (~ret['is_unique']).float() * (-999.9)
        sample_probs = modified_log_probs.log_softmax(dim='sample').exp()
        final_ret = ret['lm_score'] + ret['word_score'] * self.score_per_word
        score = (sample_probs * final_ret).sum()
        lm_score = Metric('lm_score', ret['lm_score'].sum(), bs)
        word_score = Metric('word_score', ret['word_score'].sum(), bs)
        score = Metric('score', score, bs)
        metrics = Metrics(score, lm_score, word_score)
        loss = -score.mean
        loss.backward()
        self.optimizer.step()
        return metrics
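
A minimal standalone sketch of the sample reweighting in `train_loop` above, using plain tensors and `dim=-1` in place of the named 'sample' dimension; the shapes and the `is_unique` mask are made up for illustration:

import torch

# Hypothetical (batch, sample) log-probs and a mask marking unique samples.
sample_log_probs = torch.log_softmax(torch.randn(2, 5), dim=-1)
is_unique = torch.tensor([[True, True, False, True, False],
                          [True, False, True, True, True]])
concentration = 1e-2

# Scale the log-probs by the concentration, push duplicates to ~zero probability,
# then renormalize over the sample dimension.
modified_log_probs = sample_log_probs * concentration + (~is_unique).float() * (-999.9)
sample_probs = modified_log_probs.log_softmax(dim=-1).exp()

assert torch.allclose(sample_probs.sum(dim=-1), torch.ones(2))
assert (sample_probs[~is_unique] < 1e-6).all()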
Example #2
0
class MetricLearningDataLoader(PandasDataLoader):

    add_argument('family_file_path', dtype='path', msg='path to the family file')
    add_argument('num_lang_pairs', dtype=int, default=10, msg='number of language pairs per batch')

    def __init__(self, data_path, num_workers, feat_groups: 'p', family_file_path: 'p', num_lang_pairs: 'p', data=None):
        if data is None:
            data = _get_metric_data(data_path, feat_groups, family_file_path)
        self.all_langs = sorted(set(data['lang1']))
        self.cats = [cat.name for cat in Category if should_include(feat_groups, cat)] + ['avg']
        super().__init__(data, batch_size=num_lang_pairs, num_workers=num_workers)

    def __iter__(self) -> MetricLearningBatch:
        for df in super().__iter__():
            lang1 = df['lang1'].values
            lang2 = df['lang2'].values
            normalized_score = get_tensor(df[self.cats].values.astype('float32'))
            dist = get_tensor(df['dist'].values.astype('float32'))
            yield MetricLearningBatch(lang1, lang2, normalized_score, dist).cuda()

    def select(self, langs1: Sequence[str], langs2: Sequence[str]) -> 'MetricLearningDataLoader':
        all_langs1 = set(langs1)
        all_langs2 = set(langs2)
        data = self.dataset.data
        mask = (data['lang1'].isin(all_langs1)) & (data['lang2'].isin(all_langs2))
        data = data[mask].reset_index(drop=True)
        return MetricLearningDataLoader(data=data)
Example #3
0
class BaseTrainer(Trainer, metaclass=ABCMeta):

    add_argument('num_steps', default=10, dtype=int, msg='number of steps to train')
    add_argument('learning_rate', default=2e-3, dtype=float, msg='learning rate')
    add_argument('check_interval', default=2, dtype=int, msg='check metrics after this many steps')
    add_argument('save_interval', default=500, dtype=int, msg='save models after this many steps')

    def __init__(self, model: 'a', train_data_loader: 'a', num_steps, learning_rate, check_interval, save_interval, log_dir, feat_groups):
        super().__init__()
        self.tracker.add_track('step', update_fn='add', finish_when=num_steps)
        self.optimizer = optim.Adam(get_trainable_params(self.model, named=False), learning_rate)

        self.init_params()

        # Prepare batch iterator.
        self.iterator = self._next_batch_iterator()

    def init_params(self):
        self._init_params()

    @log_this(log_level='IMP')
    def _init_params(self, init_matrix=True, init_vector=False, init_higher_tensor=False):
        for name, p in get_trainable_params(self.model, named=True):
            if p.dim() == 2 and init_matrix:
                nn.init.xavier_uniform_(p)
            elif p.dim() == 1 and init_vector:
                nn.init.uniform_(p, 0.01)
            elif init_higher_tensor:
                nn.init.uniform_(p, 0.01)

    def _next_batch_iterator(self):
        while True:
            yield from self.train_data_loader

    @property
    @abstractmethod
    def track(self):
        pass

    def check_metrics(self, accum_metrics: Metrics):
        if self.track % self.check_interval == 0:
            logging.info(accum_metrics.get_table(f'Step: {self.track}'))
            accum_metrics.clear()

    def save(self):
        if self.track % self.save_interval == 0:
            out_path = self.log_dir / 'saved.latest'
            self._save(out_path)

    def _save(self, path: Path):
        to_save = {
            'model': self.model.state_dict(),
            'g': g.state_dict()
        }
        torch.save(to_save, path)
        logging.imp(f'Model saved to {path}.')
Example #4
0
class MetricLearningModel(nn.Module):

    add_argument('num_layers',
                 default=1,
                 dtype=int,
                 msg='number of trainable layers.')

    def __init__(self, hidden_size, feat_groups, num_layers):
        super().__init__()
        effective_num_feat_groups = len(get_effective_c_idx(
            feat_groups)) + 1  # NOTE(j_luo) +1 due to 'avg' score.
        if num_layers == 1:
            self.regressor = nn.Linear(effective_num_feat_groups, 1)
        else:
            modules = [
                nn.Linear(effective_num_feat_groups, hidden_size),
                nn.LeakyReLU(negative_slope=0.1)
            ]
            for _ in range(num_layers - 2):
                modules.append(nn.Linear(hidden_size, hidden_size))
                modules.append(nn.LeakyReLU(negative_slope=0.1))
            modules.append(nn.Linear(hidden_size, 1))
            self.regressor = nn.Sequential(*modules)

    def forward(self, batch: MetricLearningBatch) -> torch.FloatTensor:
        output = self.regressor(batch.normalized_score.rename(None)).view(-1)
        return output
Example #5
0
class LMTrainer(BaseLMRunner, BaseTrainer):

    add_argument('feat_groups', default='pcvdst', dtype=str,
                 msg='what to include during training: p(type), c(onsonant), v(vowel), d(iacritics), s(tress) and t(one).')

    def __init__(self, model: 'a', train_data_loader: 'a', evaluator: 'a', num_steps, learning_rate, check_interval, save_interval, log_dir, feat_groups):
        BaseTrainer.__init__(self, model, train_data_loader)
        self.best_metrics: Metrics = None

    def train_loop(self) -> Metrics:
        self.model.train()
        self.optimizer.zero_grad()
        batch = next(self.iterator)
        scores = self.model.score(batch)
        metrics = self.analyze_scores(scores)
        metrics.loss.mean.backward()
        grad_norm = get_grad_norm(self.model)
        grad_norm = Metric('grad_norm', grad_norm * len(batch), len(batch))
        metrics += grad_norm
        self.optimizer.step()
        return metrics

    def save(self):
        super().save()
        if self.track % self.save_interval == 0:
            metrics = self.evaluator.evaluate()
            logging.info(f'New evaluation loss is {metrics.loss.mean:.3f}.')
            if self.best_metrics is None or metrics.loss.mean < self.best_metrics.loss.mean:
                self.best_metrics = metrics
                out_path = self.log_dir / 'saved.best'
                logging.imp(f'Best model updated: new best is {self.best_metrics.loss.mean:.3f}.')
                self._save(out_path)
Example #6
0
class BaseIpaDataLoader(DataLoader, metaclass=ABCMeta):

    add_argument('data_path', dtype='path', msg='path to the feat data in tsv format.')
    add_argument('num_workers', default=5, dtype=int, msg='number of workers for the data loader')
    add_argument('char_per_batch', default=500, dtype=int, msg='number of characters per batch')
    add_argument('new_style', default=False, dtype=bool, msg='flag to use new style ipa annotations')

    def __init__(self, data_path: 'p', char_per_batch: 'p', num_workers, feat_groups: 'p'):
        dataset = IpaDataset(data_path)
        batch_sampler = BatchSampler(dataset, char_per_batch, shuffle=True)
        cls = type(self)
        super().__init__(dataset, batch_sampler=batch_sampler,
                         num_workers=num_workers, collate_fn=collate_fn)

    @abstractmethod
    def _prepare_batch(self, collate_return):
        pass

    def __iter__(self):
        for collate_return in super().__iter__():
            batch = self._prepare_batch(collate_return)
            yield batch
Example #7
0
class Encoder(nn.Module):

    add_argument('window_size', default=3, dtype=int, msg='window size for the cnn kernel')
    add_argument('dense_input', default=False, dtype=bool, msg='flag to use dense input feature matrices')

    def __init__(self, num_features, dim, window_size, hidden_size, feat_groups, dense_input):
        super().__init__()

        emb_cls = DenseFeatEmbedding if dense_input else FeatEmbedding
        self.feat_embedding = emb_cls('feat_emb', 'chosen_feat_group', 'char_emb')
        self.cat_dim = dim * self.feat_embedding.effective_num_feature_groups
        # IDEA(j_luo) should I define a Rename layer?
        self.conv_layers = nn.Sequential(
            nn.Conv1d(self.cat_dim, self.cat_dim, window_size, padding=window_size // 2)
        )
        self.linear = nn.Linear(self.cat_dim, hidden_size)

    def forward(self, feat_matrix, pos_to_predict, source_padding):
        bs = source_padding.size('batch')
        l = source_padding.size('length')
        feat_emb = self.feat_embedding(feat_matrix, source_padding)
        feat_emb = feat_emb.align_to('batch', 'char_emb', 'length')
        # feat_emb = self.feat_embeddings(feat_matrix).view(bs, l, -1).transpose(1, 2)  # size: bs x D x l
        batch_i = get_range(bs, 1, 0)
        # TODO(j_luo) ugly
        feat_emb.rename(None)[batch_i, :, pos_to_predict.rename(None)] = 0.0
        output = self.conv_layers(feat_emb.rename(None))
        output = output.refine_names('batch', 'char_conv_repr', 'length')  # size: bs x D x l
        output = self.linear(output.align_to(..., 'char_conv_repr'))  # size: bs x l x n_hid
        output = output.refine_names('batch', 'length', 'hidden_repr')
        output = leaky_relu(output, negative_slope=0.1)
        # NOTE(j_luo) This is actually quite wasteful because we are discarding all the irrelevant information, which is computed anyway. This is equivalent to training on ngrams.
        # TODO(j_luo) ugly
        h = output.rename(None)[batch_i, pos_to_predict.rename(None)]
        h = h.refine_names('batch', 'hidden_repr')  # size: bs x n_hid
        return h
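
A small sketch of the advanced-indexing trick used twice in `forward` above (zeroing the target position, then reading the conv output at that same position); plain unnamed tensors are used, and `torch.arange(bs)` is assumed to play the role of `get_range(bs, 1, 0)`:

import torch

bs, D, l = 3, 4, 6
feat_emb = torch.randn(bs, D, l)            # (batch, char_emb, length)
pos_to_predict = torch.tensor([1, 4, 2])    # one target position per example
batch_i = torch.arange(bs)                  # assumed stand-in for get_range(bs, 1, 0)

# Zero out the embedding at the position to predict for every example.
feat_emb[batch_i, :, pos_to_predict] = 0.0
assert (feat_emb[batch_i, :, pos_to_predict] == 0).all()

# The same index pair later picks out one hidden vector per example.
output = torch.randn(bs, l, 8)              # (batch, length, hidden_repr)
h = output[batch_i, pos_to_predict]         # (batch, hidden_repr)
assert h.shape == (bs, 8)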
Example #8
0
class MetricLearningManager(Manager):

    add_argument('k_fold', default=10, dtype=int, msg='number of folds for cross validation')

    def __init__(self, k_fold, random_seed, data_path, feat_groups, family_file_path):
        self.model = MetricLearningModel()
        self.data_loader = MetricLearningDataLoader()
        if os.environ.get('CUDA_VISIBLE_DEVICES', False):
            self.model.cuda()
        self.trainer = MetricLearningTrainer(self.model, self.data_loader)
        self.evaluator = Evaluator(self.model, self.data_loader)

    def train(self):
        set_random_seeds(self.random_seed)
        all_langs = self.data_loader.all_langs
        num_langs = len(all_langs)
        idx = list(range(num_langs))
        random.shuffle(idx)

        num_langs_per_fold = (num_langs + self.k_fold - 1) // self.k_fold

        accum_metrics = Metrics()
        for fold in range(self.k_fold):
            # Get train-dev split.
            start_idx = fold * num_langs_per_fold
            end_idx = start_idx + num_langs_per_fold if fold < self.k_fold - 1 else num_langs
            dev_langs = [all_langs[idx[i]] for i in range(start_idx, end_idx)]
            logging.imp(f'dev_langs: {sorted(dev_langs)}')
            train_langs = [all_langs[idx[i]] for i in range(num_langs) if i < start_idx or i >= end_idx]
            assert len(set(dev_langs + train_langs)) == num_langs

            self.trainer.reset()
            best_mse = self.trainer.train(
                self.evaluator,
                train_langs,
                dev_langs,
                fold)

            # Aggregate every fold.
            accum_metrics += best_mse

        logging.info(accum_metrics.get_table())
Example #9
0
class GraphPredictor(nn.Module):
    """Based on the word-level representations, predict the edge types and the edge strengths (or edge norms)."""

    add_argument(
        'edge_norm_agg',
        default='sum',
        choices=['sum', 'mean'],
        dtype=str,
        msg='how to aggregate the attention scores to get an edge norm.')

    def __init__(self, emb_dim, dropout, num_relations, edge_norm_agg):
        super().__init__()
        self.norm_attn = MultiHeadAttention(8, emb_dim, dropout=dropout)
        self.type_attn = MultiHeadAttention(8, emb_dim, dropout=dropout)
        self.type_proj = nn.Linear(emb_dim, num_relations)

    def forward(self, h_in, word_mask):
        # Get norms first.
        norm_output, norm_attn_weights = self.norm_attn(h_in,
                                                        word_mask,
                                                        return_weights=True)
        # NOTE(j_luo) Aggregate attention scores to get an edge norm.
        if self.edge_norm_agg == 'sum':
            norm_attn_weights = norm_attn_weights.sum(dim=1)
        else:
            norm_attn_weights = norm_attn_weights.mean(dim=1)

        norms = norm_attn_weights + norm_attn_weights.transpose(
            1, 2)  # Symmetrize attention weights.

        # Get types now.
        type_output, type_attn_weights = self.type_attn(h_in,
                                                        word_mask,
                                                        return_weights=True)
        # I'm decomposing the prediction.
        type_logits = self.type_proj(type_output)  # bs x wl x nr
        type_logits = type_logits.unsqueeze(dim=2) + type_logits.unsqueeze(
            dim=1)
        type_probs = torch.log_softmax(type_logits, dim=-1).exp()

        return norms, type_probs
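
A short sketch of the pairwise decomposition in `forward` above: per-word type logits are expanded into pairwise logits by an outer sum, which keeps the prediction cheap and symmetric in the two word positions. The sizes `bs`, `wl`, `nr` are made up for illustration:

import torch

bs, wl, nr = 2, 5, 3                       # batch, word length, number of relations
type_logits = torch.randn(bs, wl, nr)      # one logit vector per word

# Outer sum over the two word positions gives pairwise logits of shape (bs, wl, wl, nr).
pair_logits = type_logits.unsqueeze(dim=2) + type_logits.unsqueeze(dim=1)
pair_probs = torch.log_softmax(pair_logits, dim=-1).exp()

assert pair_logits.shape == (bs, wl, wl, nr)
# Symmetric in the two word positions, just like the symmetrized edge norms.
assert torch.allclose(pair_logits, pair_logits.transpose(1, 2))
assert torch.allclose(pair_probs.sum(dim=-1), torch.ones(bs, wl, wl))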
Example #10
0
class ContinuousRGCN(RGCNConv):
    """Similar to RGCN but with continous `edge_type` and `edge_norm`."""

    add_argument('num_bases',
                 default=5,
                 dtype=int,
                 msg='number of bases for RGCN')

    def __init__(self, emb_dim, num_bases, num_relations):
        super().__init__(emb_dim, emb_dim, num_bases, num_relations)
        self.basis = nn.Parameter(self.basis.transpose(
            0, 1).contiguous())  # NOTE(j_luo) This would help in message.

    def message(self, x_j, edge_index_j, edge_type, edge_norm):
        """The original algorithm consumes way too much memory. But we can use simple arithmetics to help.

        Let each basic be B(j), where j in {1..M}, and each relation r has its relation weight a(r, j), where r in {1...R}.
        Each edge e has its own relation distribution p(e, r) for e in E.

        Now for each relation r, the projection weight W(r) = sum_{j=1-J} a(r, j) * B(j).
        For each edge e, the projection weight is W(e) = sum_{r=1-R} p(e, r) * W(r).
        The output now is :
                    h'(e) = h(e) @ W(e)
                          = h(e) @ (sum_{r=1-R} p(e, r) * W(r))
                          = h(e) @ (sum_{r=1-R, j=1-J} p(e, r) * a(r, j) * B(j))
                          = sum_{r=1-R, j=1-J} [p(e, r) * a(r, j)] * [h(e) @ B(j)]
                          = sum_{j=1-J} c(e, j) * [h(e) @ B(j)],
        where:
                  c(e, j) = sum_{r=1-R} p(e, r) * a(r, j)

        """
        E, _ = x_j.shape
        h_e_basis = x_j @ self.basis.view(self.in_channels,
                                          self.num_bases * self.out_channels)
        h_e_basis = h_e_basis.view(E, self.num_bases, self.out_channels)

        weight = edge_type @ self.att  # size: E x nr @ nr x nb -> E x nb
        # size: E x 1 x nb @ E x nb x n_out -> E x 1 x n_out -> E x n_out
        out = (weight.unsqueeze(dim=1) @ h_e_basis).squeeze(dim=-2)
        return out * edge_norm.view(-1, 1)
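
A standalone sketch (with made-up sizes) checking the identity from the docstring of `message`: mixing the per-basis projections with c(e, j) = sum_r p(e, r) * a(r, j) matches building each per-edge weight matrix W(e) explicitly:

import torch

E, R, J, D_in, D_out = 7, 4, 3, 5, 6             # edges, relations, bases, dims
h = torch.randn(E, D_in)                          # h(e)
basis = torch.randn(J, D_in, D_out)               # B(j)
a = torch.randn(R, J)                             # a(r, j), cf. self.att
p = torch.softmax(torch.randn(E, R), dim=-1)      # p(e, r), cf. edge_type

# Naive route: W(r) = sum_j a(r, j) B(j), W(e) = sum_r p(e, r) W(r), then project.
W_r = torch.einsum('rj,jio->rio', a, basis)
W_e = torch.einsum('er,rio->eio', p, W_r)
naive = torch.einsum('ei,eio->eo', h, W_e)

# Cheap route used in `message`: project through each basis once, then mix.
h_basis = torch.einsum('ei,jio->ejo', h, basis)   # h(e) @ B(j)
c = p @ a                                         # c(e, j)
cheap = torch.einsum('ej,ejo->eo', c, h_basis)

assert torch.allclose(naive, cheap, atol=1e-5)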
Example #11
0
def add_main_arguments():
    """
    Generate a parameters parser.
    """
    # main parameters
    add_argument("dump_path",
                 dtype=str,
                 default="./dumped/",
                 msg="Experiment dump path")
    add_argument("exp_name", dtype=str, default="", msg="Experiment name")
    add_argument(
        "save_periodic_epoch",
        dtype=int,
        default=0,
        msg="Save the model periodically every few epochs (0 to disable)")
    add_argument(
        "save_periodic_step",
        dtype=int,
        default=0,
        msg="Save the model periodically every few steps (0 to disable)")
    add_argument("eval_interval",
                 dtype=int,
                 default=0,
                 msg="evaluate the model every few steps (0 to disable)")
    add_argument("exp_id", dtype=str, default="", msg="Experiment ID")
    add_argument("log_level", dtype=str, default="INFO", msg="log level")

    # float16 / AMP API
    add_argument("fp16",
                 dtype=bool,
                 default=False,
                 msg="Run model with float16")
    add_argument(
        "amp",
        dtype=int,
        default=-1,
        msg=
        "Use AMP wrapper for float16 / distributed / gradient accumulation. Level of optimization. -1 to disable."
    )

    # only use an encoder (use a specific decoder for machine translation)
    add_argument("encoder_only",
                 dtype=bool,
                 default=True,
                 msg="Only use an encoder")

    # model parameters
    add_argument("emb_dim", dtype=int, default=512, msg="Embedding layer size")
    add_argument("n_layers",
                 dtype=int,
                 default=4,
                 msg="Number of Transformer layers")
    add_argument("n_heads",
                 dtype=int,
                 default=8,
                 msg="Number of Transformer heads")
    add_argument("dropout", dtype=float, default=0, msg="Dropout")
    add_argument("attention_dropout",
                 dtype=float,
                 default=0,
                 msg="Dropout in the attention layer")
    add_argument("gelu_activation",
                 dtype=bool,
                 default=False,
                 msg="Use a GELU activation instead of ReLU")
    add_argument("share_inout_emb",
                 dtype=bool,
                 default=True,
                 msg="Share input and output embeddings")
    add_argument("sinusoidal_embeddings",
                 dtype=bool,
                 default=False,
                 msg="Use sinusoidal embeddings")
    add_argument("use_lang_emb",
                 dtype=bool,
                 default=True,
                 msg="Use language embedding")
    add_argument("use_graph",
                 dtype=bool,
                 default=False,
                 msg="Use a graph formulation on top of transformer encoder")

    # adaptive softmax
    add_argument("asm", dtype=bool, default=False, msg="Use adaptive softmax")
    add_argument("asm_cutoffs",
                 dtype=str,
                 default="8000,20000",
                 msg="Adaptive softmax cutoffs")
    add_argument("asm_div_value",
                 dtype=float,
                 default=4,
                 msg="Adaptive softmax cluster sizes ratio")

    # causal language modeling task parameters
    add_argument(
        "context_size",
        dtype=int,
        default=0,
        msg=
        "Context size (0 means that the first elements in sequences won't have any context)"
    )

    # masked language modeling task parameters
    add_argument(
        "word_pred",
        dtype=float,
        default=0.15,
        msg="Fraction of words for which we need to make a prediction")
    add_argument(
        "sample_alpha",
        dtype=float,
        default=0,
        msg=
        "Exponent for transforming word counts to probabilities (~word2vec sampling)"
    )
    add_argument(
        "word_mask_keep_rand",
        dtype=str,
        default="0.8,0.1,0.1",
        msg=
        "Fraction of words to mask out / keep / randomize, among the words to predict"
    )

    # input sentence noise
    add_argument("word_shuffle",
                 dtype=float,
                 default=0,
                 msg="Randomly shuffle input words (0 to disable)")
    add_argument("word_dropout",
                 dtype=float,
                 default=0,
                 msg="Randomly dropout input words (0 to disable)")
    add_argument("word_blank",
                 dtype=float,
                 default=0,
                 msg="Randomly blank input words (0 to disable)")

    # data
    add_argument('input_format',
                 dtype=str,
                 default='plain',
                 choices=['plain', 'eat', 'neo_linear'])
    add_argument("data_path", dtype='path', default="", msg="Data path")
    add_argument("lgs",
                 dtype=str,
                 default="",
                 msg="Languages (lg1-lg2-lg3 .. ex: en-fr-es-de)")
    add_argument("max_vocab",
                 dtype=int,
                 default=-1,
                 msg="Maximum vocabulary size (-1 to disable)")
    add_argument("min_count",
                 dtype=int,
                 default=0,
                 msg="Minimum vocabulary count")
    add_argument("lg_sampling_factor",
                 dtype=float,
                 default=-1,
                 msg="Language sampling factor")

    # batch parameters
    add_argument("bptt", dtype=int, default=256, msg="Sequence length")
    add_argument("max_len",
                 dtype=int,
                 default=100,
                 msg="Maximum length of sentences (after BPE)")
    add_argument("group_by_size",
                 dtype=bool,
                 default=True,
                 msg="Sort sentences by size during the training")
    add_argument("batch_size",
                 dtype=int,
                 default=32,
                 msg="Number of sentences per batch")
    add_argument(
        "max_batch_size",
        dtype=int,
        default=0,
        msg=
        "Maximum number of sentences per batch (used in combination with tokens_per_batch, 0 to disable)"
    )
    add_argument("tokens_per_batch",
                 dtype=int,
                 default=-1,
                 msg="Number of tokens per batch")

    # training parameters
    add_argument("split_data",
                 dtype=bool,
                 default=False,
                 msg="Split data across workers of a same node")
    add_argument("optimizer",
                 dtype=str,
                 default="adam,lr=0.0001",
                 msg="Optimizer (SGD / RMSprop / Adam, etc.)")
    add_argument("clip_grad_norm",
                 dtype=float,
                 default=5,
                 msg="Clip gradients norm (0 to disable)")
    add_argument(
        "epoch_size",
        dtype=int,
        default=100000,
        msg="Epoch size / evaluation frequency (-1 for parallel data size)")
    add_argument("max_epoch",
                 dtype=int,
                 default=100000,
                 msg="Maximum epoch size")
    add_argument(
        "stopping_criterion",
        dtype=str,
        default="",
        msg=
        "Stopping criterion, and number of non-increase before stopping the experiment"
    )
    add_argument("validation_metrics",
                 dtype=str,
                 default="",
                 msg="Validation metrics")
    add_argument(
        "accumulate_gradients",
        dtype=int,
        default=1,
        msg=
        "Accumulate model gradients over N iterations (N times larger batch sizes)"
    )

    # training coefficients
    add_argument("lambda_mlm",
                 dtype=str,
                 default="1",
                 msg="Prediction coefficient (MLM)")
    add_argument("lambda_clm",
                 dtype=str,
                 default="1",
                 msg="Causal coefficient (LM)")
    add_argument("lambda_pc", dtype=str, default="1", msg="PC coefficient")
    add_argument("lambda_ae", dtype=str, default="1", msg="AE coefficient")
    add_argument("lambda_mt", dtype=str, default="1", msg="MT coefficient")
    add_argument("lambda_bt", dtype=str, default="1", msg="BT coefficient")
    add_argument("lambda_ep", dtype=str, default="1", msg="EP coefficient")

    # training steps
    add_argument("clm_steps",
                 dtype=str,
                 default="",
                 msg="Causal prediction steps (CLM)")
    add_argument("mlm_steps",
                 dtype=str,
                 default="",
                 msg="Masked prediction steps (MLM / TLM)")
    add_argument("mt_steps",
                 dtype=str,
                 default="",
                 msg="Machine translation steps")
    add_argument("ae_steps",
                 dtype=str,
                 default="",
                 msg="Denoising auto-encoder steps")
    add_argument("bt_steps",
                 dtype=str,
                 default="",
                 msg="Back-translation steps")
    add_argument("ep_steps",
                 dtype=str,
                 default="",
                 msg="EAT-plain reconstruction steps")
    add_argument("pc_steps",
                 dtype=str,
                 default="",
                 msg="Parallel classification steps")

    # reload pretrained embeddings / pretrained model / checkpoint
    add_argument("reload_emb",
                 dtype=str,
                 default="",
                 msg="Reload pretrained word embeddings")
    add_argument("reload_model",
                 dtype=str,
                 default="",
                 msg="Reload a pretrained model")
    add_argument("reload_checkpoint",
                 dtype=str,
                 default="",
                 msg="Reload a checkpoint")

    # beam search (for MT only)
    add_argument("beam_size",
                 dtype=int,
                 default=1,
                 msg="Beam size, default = 1 (greedy decoding)")
    add_argument(
        "length_penalty",
        dtype=float,
        default=1,
        msg=
        "Length penalty, values < 1.0 favor shorter sentences, while values > 1.0 favor longer ones."
    )
    add_argument(
        "early_stopping",
        dtype=bool,
        default=False,
        msg=
        "Early stopping, stop as soon as we have `beam_size` hypotheses, although longer ones may have better scores."
    )

    # evaluation
    add_argument("eval_bleu",
                 dtype=bool,
                 default=False,
                 msg="Evaluate BLEU score during MT training")
    add_argument("eval_only",
                 dtype=bool,
                 default=False,
                 msg="Only run evaluations")

    # debug
    add_argument("debug_train",
                 dtype=bool,
                 default=False,
                 msg="Use valid sets for train sets (faster loading)")
    add_argument("debug_slurm",
                 dtype=bool,
                 default=False,
                 msg="Debug multi-GPU / multi-node within a SLURM job")
    add_argument("debug",
                 msg="Enable all debug flags",
                 dtype=bool,
                 default=False)

    # multi-gpu / multi-node
    add_argument("local_rank",
                 dtype=int,
                 default=-1,
                 msg="Multi-GPU - Local rank")
    add_argument("master_port",
                 dtype=int,
                 default=-1,
                 msg="Master port (for multi-node SLURM jobs)")

    # Add registry.
    add_registry(reg)
Example #12
0
class LM(nn.Module):

    add_argument('weighted_loss',
                 default='',
                 dtype=str,
                 choices=['', 'mr', 'ot'],
                 msg='what type of weighted loss to use')

    def __init__(self, new_style: 'p', weighted_loss: 'p'):
        super().__init__()
        self.encoder = Encoder()
        self.predictor = Predictor()
        # if weighted_loss and not new_style:
        #     raise ValueError('Must use new_style if using weighted loss')

    def forward(self, batch: IpaBatch) -> Dict[Cat, FT]:
        """
        First encode the `feat_matrix` into a vector `h`, then based on it predict the distributions of features.
        """
        h = self.encoder(batch.feat_matrix, batch.pos_to_predict,
                         batch.source_padding)
        distr = self.predictor(h)
        return distr

    def score(self, batch) -> Dict[Cat, FT]:
        distr = self(batch)
        scores = dict()
        for name, output in distr.items():
            i = get_index(name, new_style=self.new_style)
            target = batch.target_feat[:, i]
            weight = batch.target_weight[:, i]

            if self.weighted_loss == '':
                log_probs = gather(output, target)
                score = -log_probs
            else:
                e = get_new_style_enum(i)
                mat = get_tensor(e.get_distance_matrix())
                mat = mat[target.rename(None)]
                if self.weighted_loss == 'mr':
                    mat_exp = torch.where(mat > 0, (mat + 1e-8).log(),
                                          get_zeros(mat.shape).fill_(-99.9))
                    logits = mat_exp + output
                    # NOTE(j_luo) For all categories except Ptype, the sums of probs are not 1.0 (they are conditioned on certain values of Ptype).
                    # As a result, we need to incur penalties based on the remaining prob mass as well.
                    # Specifically, the remaining prob mass will result in a penalty of 1.0, which is e^(0.0).
                    none_probs = (
                        1.0 -
                        output.exp().sum(dim=-1, keepdims=True)).clamp(min=0.0)
                    none_penalty = (1e-8 + none_probs).log().align_as(output)
                    logits = torch.cat([logits, none_penalty], dim=-1)
                    score = torch.logsumexp(logits, dim=-1).exp()
                elif self.weighted_loss == 'ot':
                    if not self.training:
                        raise RuntimeError('OT can only be used in training mode.')

                    probs = output.exp()
                    # We have to incur penalties based on the remaining prob mass as well.
                    none_probs = (1.0 -
                                  probs.sum(dim=-1, keepdims=True)).clamp(
                                      min=0.0)
                    mat = torch.cat([
                        mat,
                        get_tensor(torch.ones_like(none_probs.rename(None)))
                    ],
                                    dim=-1)
                    probs = torch.cat([probs, none_probs], dim=-1)
                    score = (mat * probs).sum(dim=-1)
                else:
                    raise ValueError(f'Cannot recognize {self.weighted_loss}.')
            scores[name] = (score, weight)
        return scores

    @not_supported_argument_value('new_style', True)
    def predict(self, batch, k=-1) -> Dict[Cat, Tuple[FT, LT, np.ndarray]]:
        """
        Predict the top K results for each feature group.
        If k == -1, then everything would be sorted and returned, otherwise take the topk.
        """
        ret = dict()
        distr = self(batch)
        for cat, log_probs in distr.items():
            e = get_enum_by_cat(cat)
            name = cat.name.lower()
            max_k = log_probs.size(name)
            this_k = max_k if k == -1 else min(max_k, k)
            top_values, top_indices = log_probs.topk(this_k, dim=-1)
            top_cats = np.asarray([
                e.get(i) for i in top_indices.view(-1).cpu().numpy()
            ]).reshape(*top_indices.shape)
            ret[name] = (top_values, top_indices, top_cats)
        return ret
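
A simplified numeric sketch of the 'mr' branch in `score` above (made-up shapes, and without the special handling of exact-zero distances): the exponentiated logsumexp over distance-weighted logits plus the 'none' column equals the expected distance under the predicted distribution, with the leftover probability mass penalized at 1.0:

import torch

# Per-position log-probs that deliberately sum to 0.7, so 0.3 of the mass is "left over".
log_probs = torch.log_softmax(torch.randn(4, 6), dim=-1) + torch.tensor(0.7).log()
dist = torch.rand(4, 6)                    # distances from the target value

probs = log_probs.exp()
none_probs = (1.0 - probs.sum(dim=-1, keepdim=True)).clamp(min=0.0)

# Route used in `score`: work in log space and exponentiate the logsumexp.
logits = torch.cat([(dist + 1e-8).log() + log_probs, (1e-8 + none_probs).log()], dim=-1)
via_logsumexp = torch.logsumexp(logits, dim=-1).exp()

# Direct computation: expected distance plus a penalty of 1.0 per unit of leftover mass.
direct = (dist * probs).sum(dim=-1) + none_probs.squeeze(-1)

assert torch.allclose(via_logsumexp, direct, atol=1e-4)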
Example #13
0
        TestSet.MSCOCO_2017
    },
    'cs': {TestSet.FLICKR_2016}
}


@dataclass(frozen=True)
class Key:
    main: str
    lang: str


if __name__ == "__main__":
    initiate(logger=True, gpus=True, commit_id=True, random_seed=True)

    add_argument('langs', dtype=str, nargs=2)
    add_argument('codes', dtype=str)
    add_argument('seed', dtype=int, default=1234)
    add_argument('split_lines', dtype=int, nargs=2)
    add_argument('eat', dtype=str, default='', choices=['', 'eat', 'neo'])
    add_argument('dataset',
                 dtype=str,
                 default='multi30k',
                 choices=['multi30k', 'iwslt'])
    add_argument('linear', dtype=bool, default=False)
    add_argument('graph', dtype=bool, default=False, msg='for eat')
    add_argument('pair', dtype=str, default='')
    parse_args(show=True)

    # Use random seed to make sure train split is persistent.
    random.seed(g.seed)
Example #14
0
import sys
from pathlib import Path

import torch
from devlib import initiate
from arglib import add_argument, parse_args, g

from xib.ipa.process import (apply_all, clean_data, get_ipa_data,
                             get_pth_content, indexify, merge)

if __name__ == "__main__":
    initiate(logger=True)
    add_argument('in_path', dtype='path')
    add_argument('lang', dtype=str)
    parse_args()

    with g.in_path.open('r', encoding='utf8') as fin:
        cnt, total, df = get_ipa_data(fin, progress=True)
        print(f'Ignore {cnt} / {total} lines.')

    folder: Path = g.in_path.parent

    apply_all(df, progress=True)
    cleaned_df = clean_data(df, progress=True)

    cleaned_df.to_csv(folder / f'phones_{g.lang}.tsv', sep='\t', index=False)

    merged_df = merge(cleaned_df, progress=True)

    # Save intermediate merged results.
    merged_df.to_csv(folder / f'phones_merged_{g.lang}.tsv',
Example #15
0
class HashingMemory(nn.Module):

    MEM_VALUES_PARAMS = '.values.weight'
    VALUES = None
    EVAL_MEMORY = True
    _ids = itertools.count(0)

    # ------------------------ Register memory parameters ------------------------ #

    add_argument("use_memory",
                 dtype=bool,
                 default=False,
                 msg="Use an external memory")
    add_argument(
        "mem_enc_positions",
        dtype=str,
        default="",
        msg=
        "Memory positions in the encoder ('4' for inside layer 4, '7,10+' for inside layer 7 and after layer 10)"
    )
    add_argument(
        "mem_dec_positions",
        dtype=str,
        default="",
        msg=
        "Memory positions in the decoder. Same syntax as `mem_enc_positions`.")
    # memory implementation
    add_argument("mem_implementation",
                 dtype=str,
                 default="pq_fast",
                 msg="Memory implementation (flat, pq_default, pq_fast)")

    # optimization
    add_argument("mem_grouped_conv",
                 dtype=bool,
                 default=False,
                 msg="Use grouped convolutions in the query network")
    add_argument("mem_values_optimizer",
                 dtype=str,
                 default="adam,lr=0.001",
                 msg="Memory values optimizer ("
                 " for the same optimizer as the rest of the model)")
    add_argument("mem_sparse",
                 dtype=bool,
                 default=False,
                 msg="Perform sparse updates for the values")

    # global parameters
    add_argument("mem_input2d",
                 dtype=bool,
                 default=False,
                 msg="Convolutional query network")
    add_argument("mem_k_dim",
                 dtype=int,
                 default=256,
                 msg="Memory keys dimension")
    add_argument(
        "mem_v_dim",
        dtype=int,
        default=-1,
        msg="Memory values dimension (-1 for automatic output dimension)")
    add_argument("mem_heads",
                 dtype=int,
                 default=4,
                 msg="Number of memory reading heads")
    add_argument(
        "mem_knn",
        dtype=int,
        default=32,
        msg="Number of memory slots to read / update - k-NN to the query")
    add_argument("mem_share_values",
                 dtype=bool,
                 default=False,
                 msg="Share values across memories")
    add_argument("mem_shuffle_indices",
                 dtype=bool,
                 default=False,
                 msg="Shuffle indices for different heads")
    add_argument(
        "mem_shuffle_query",
        dtype=bool,
        default=False,
        msg=
        "Shuffle query dimensions (when the query network is the identity and there are multiple heads)"
    )
    add_argument(
        "mem_modulo_size",
        dtype=int,
        default=-1,
        msg=
        "Effective memory size: indices are taken modulo this parameter. -1 to disable."
    )

    # keys
    add_argument("mem_keys_dtype",
                 dtype=str,
                 default="uniform",
                 msg="Memory keys dtype (binary,gaussian,uniform)")
    add_argument("mem_n_keys", dtype=int, default=512, msg="Number of keys")
    add_argument("mem_keys_normalized_init",
                 dtype=bool,
                 default=False,
                 msg="Normalize keys at initialization")
    add_argument("mem_keys_learn", dtype=bool, default=True, msg="Learn keys")
    add_argument("mem_use_different_keys",
                 dtype=bool,
                 default=True,
                 msg="Use different keys for each head / product quantization")

    # queries
    add_argument("mem_query_detach_input",
                 dtype=bool,
                 default=False,
                 msg="Detach input")
    add_argument("mem_query_layer_sizes",
                 dtype=str,
                 default="0,0",
                 msg="Query MLP layer sizes ('', '0,0', '0,512,0')")
    add_argument("mem_query_kernel_sizes",
                 dtype=str,
                 default="",
                 msg="Query MLP kernel sizes (2D inputs only)")
    add_argument("mem_query_bias",
                 dtype=bool,
                 default=True,
                 msg="Query MLP bias")
    add_argument("mem_query_batchnorm",
                 dtype=bool,
                 default=False,
                 msg="Query MLP batch norm")
    add_argument("mem_query_net_learn",
                 dtype=bool,
                 default=True,
                 msg="Query MLP learn")
    add_argument("mem_query_residual",
                 dtype=bool,
                 default=False,
                 msg="Use a bottleneck with a residual layer in the query MLP")
    add_argument("mem_multi_query_net",
                 dtype=bool,
                 default=False,
                 msg="Use multiple query MLP (one for each head)")

    # values initialization
    add_argument("mem_value_zero_init",
                 dtype=bool,
                 default=False,
                 msg="Initialize values with zeros")

    # scoring
    add_argument("mem_normalize_query",
                 dtype=bool,
                 default=False,
                 msg="Normalize queries")
    add_argument("mem_temperature",
                 dtype=float,
                 default=1,
                 msg="Divide scores by a temperature")
    add_argument("mem_score_softmax",
                 dtype=bool,
                 default=True,
                 msg="Apply softmax on scores")
    add_argument("mem_score_subtract",
                 dtype=str,
                 default="",
                 msg="Subtract scores ('', min, mean, median)")
    add_argument("mem_score_normalize",
                 dtype=bool,
                 default=False,
                 msg="L1 normalization of the scores")

    # dropout
    add_argument("mem_input_dropout",
                 dtype=float,
                 default=0,
                 msg="Input dropout")
    add_argument("mem_query_dropout",
                 dtype=float,
                 default=0,
                 msg="Query dropout")
    add_argument("mem_value_dropout",
                 dtype=float,
                 default=0,
                 msg="Value dropout")

    def __init__(self, input_dim, output_dim, params):

        super().__init__()
        self.id = next(self._ids)

        # global parameters
        self.input2d = params.mem_input2d
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.size = params.mem_size
        self.modulo_size = params.mem_modulo_size
        self.n_indices = params.n_indices
        self.k_dim = params.mem_k_dim
        self.v_dim = params.mem_v_dim if params.mem_v_dim > 0 else output_dim
        self.heads = params.mem_heads
        self.knn = params.mem_knn
        self.shuffle_indices = params.mem_shuffle_indices
        self.keys_normalized_init = params.mem_keys_normalized_init
        self.product_quantization = params.mem_product_quantization
        assert self.modulo_size == -1 and self.size == self.n_indices or self.n_indices > self.size == self.modulo_size >= 1

        # keys / queries
        self.keys_type = params.mem_keys_type
        self.learn_keys = params.mem_keys_learn
        self.use_different_keys = params.mem_use_different_keys
        self.query_detach_input = params.mem_query_detach_input
        self.query_net_learn = params.mem_query_net_learn
        self.multi_query_net = params.mem_multi_query_net
        self.shuffle_query = params.mem_shuffle_query
        assert self.use_different_keys is False or self.keys_type in [
            'gaussian', 'uniform'
        ]
        assert self.use_different_keys is False or self.heads >= 2 or self.product_quantization
        assert self.multi_query_net is False or self.heads >= 2 or self.product_quantization
        assert self.shuffle_query is False or self.heads > 1 and params.mem_query_layer_sizes == ''
        assert self.shuffle_query is False or self.input_dim % (
            2**self.heads) == 0

        # scoring / re-scoring
        self.normalize_query = params.mem_normalize_query
        self.temperature = params.mem_temperature
        self.score_softmax = params.mem_score_softmax
        self.score_subtract = params.mem_score_subtract
        self.score_normalize = params.mem_score_normalize
        assert self.score_subtract in ['', 'min', 'mean', 'median']
        assert self.score_subtract == '' or self.knn >= 2
        assert not (self.score_normalize and self.score_softmax
                    and self.score_subtract == '')

        # dropout
        self.input_dropout = params.mem_input_dropout
        self.query_dropout = params.mem_query_dropout
        self.value_dropout = params.mem_value_dropout

        # initialize keys
        self.init_keys()

        # self.values = nn.Embedding(self.size, self.v_dim, sparse=params.mem_sparse)
        self.values = nn.EmbeddingBag(self.size,
                                      self.v_dim,
                                      mode='sum',
                                      sparse=params.mem_sparse)

        # optionally use the same values for all memories
        if params.mem_share_values:
            if HashingMemory.VALUES is None:
                HashingMemory.VALUES = self.values.weight
            else:
                self.values.weight = HashingMemory.VALUES

        # values initialization
        if params.mem_value_zero_init:
            nn.init.zeros_(self.values.weight)
        else:
            nn.init.normal_(self.values.weight, mean=0, std=self.v_dim**-0.5)

        # no query network
        if len(params.mem_query_layer_sizes) == 0:
            assert self.heads == 1 or self.use_different_keys or self.shuffle_query
            assert self.input_dim == self.k_dim
            self.query_proj = QueryIdentity(self.input_dim, self.heads,
                                            self.shuffle_query)

        # query network
        if len(params.mem_query_layer_sizes) > 0:
            assert not self.shuffle_query

            # layer sizes / number of features
            l_sizes = list(params.mem_query_layer_sizes)
            assert len(l_sizes) >= 2 and l_sizes[0] == l_sizes[-1] == 0
            l_sizes[0] = self.input_dim
            l_sizes[-1] = (self.k_dim //
                           2) if self.multi_query_net else (self.heads *
                                                            self.k_dim)

            # convolutional or feedforward
            if self.input2d:
                self.query_proj = QueryConv(
                    self.input_dim,
                    self.heads,
                    self.k_dim,
                    self.product_quantization,
                    self.multi_query_net,
                    l_sizes,
                    params.mem_query_kernel_sizes,
                    bias=params.mem_query_bias,
                    batchnorm=params.mem_query_batchnorm,
                    grouped_conv=params.mem_grouped_conv)
            else:
                assert params.mem_query_kernel_sizes == ''
                assert not params.mem_query_residual
                self.query_proj = QueryMLP(
                    self.input_dim,
                    self.heads,
                    self.k_dim,
                    self.product_quantization,
                    self.multi_query_net,
                    l_sizes,
                    bias=params.mem_query_bias,
                    batchnorm=params.mem_query_batchnorm,
                    grouped_conv=params.mem_grouped_conv)

        # shuffle indices for different heads
        if self.shuffle_indices:
            head_permutations = [
                torch.randperm(self.n_indices).unsqueeze(0)
                for i in range(self.heads)
            ]
            self.register_buffer('head_permutations',
                                 torch.cat(head_permutations, 0))

        # do not learn the query network
        if self.query_net_learn is False:
            for p in self.query_proj.parameters():
                p.requires_grad = False

    def forward(self, input):
        """
        Read from the memory.
        """
        # detach input
        if self.query_detach_input:
            input = input.detach()

        # input dimensions
        if self.input2d:
            assert input.shape[1] == self.input_dim
            n_images, _, height, width = input.shape
            prefix_shape = (n_images, width, height)
        else:
            assert input.shape[-1] == self.input_dim
            prefix_shape = input.shape[:-1]

        # compute query / store it
        bs = np.prod(prefix_shape)
        input = F.dropout(input, p=self.input_dropout,
                          training=self.training)  # input shape
        query = self.query_proj(input)  # (bs * heads, k_dim)
        query = F.dropout(query, p=self.query_dropout,
                          training=self.training)  # (bs * heads, k_dim)
        assert query.shape == (bs * self.heads, self.k_dim)

        # get indices
        scores, indices = self.get_indices(query,
                                           self.knn)  # (bs * heads, knn) ** 2

        # optionally shuffle indices for different heads
        if self.shuffle_indices:
            indices = indices.view(bs, self.heads, -1).chunk(self.heads, 1)
            indices = [
                p[idx] for p, idx in zip(self.head_permutations, indices)
            ]
            indices = torch.cat(indices, 1).view(bs * self.heads, -1)

        # take indices modulo the memory size
        if self.modulo_size != -1:
            indices = indices % self.modulo_size

        # re-scoring
        if self.temperature != 1:
            scores = scores / self.temperature  # (bs * heads, knn)
        if self.score_softmax:
            scores = F.softmax(scores.float(),
                               dim=-1).type_as(scores)  # (bs * heads, knn)
        if self.score_subtract != '':
            if self.score_subtract == 'min':
                to_sub = scores.min(1, keepdim=True)[0]  # (bs * heads, 1)
            if self.score_subtract == 'mean':
                to_sub = scores.mean(1, keepdim=True)  # (bs * heads, 1)
            if self.score_subtract == 'median':
                to_sub = scores.median(1, keepdim=True)[0]  # (bs * heads, 1)
            scores = scores - to_sub  # (bs * heads, knn)
        if self.score_normalize:
            scores = scores / scores.norm(p=1, dim=1,
                                          keepdim=True)  # (bs * heads, knn)

        # merge heads / knn (since we sum heads)
        indices = indices.view(bs, self.heads * self.knn)  # (bs, heads * knn)
        scores = scores.view(bs, self.heads * self.knn)  # (bs, heads * knn)

        # weighted sum of values
        # output = self.values(indices) * scores.unsqueeze(-1)                    # (bs * heads, knn, v_dim)
        # output = output.sum(1)                                                  # (bs * heads, v_dim)
        output = self.values(
            indices, per_sample_weights=scores.to(self.values.weight.data)).to(
                scores)  # (bs, v_dim)
        output = F.dropout(output,
                           p=self.value_dropout,
                           training=self.training)  # (bs, v_dim)

        # reshape output
        if self.input2d:
            output = output.view(
                n_images, width, height,
                self.v_dim)  # (n_images, width, height, v_dim)
            output = output.transpose(1, 3)  # (n_images, v_dim, height, width)
        else:
            if len(prefix_shape) >= 2:
                output = output.view(prefix_shape +
                                     (self.v_dim, ))  # (..., v_dim)

        # store indices / scores (eval mode only - for usage statistics)
        if not self.training and HashingMemory.EVAL_MEMORY:
            self.last_indices = indices.view(bs, self.heads,
                                             self.knn).detach().cpu()
            self.last_scores = scores.view(bs, self.heads,
                                           self.knn).detach().cpu().float()

        return output

    def init_keys(self):
        raise Exception("Not implemented!")

    def _get_indices(self, query, knn, keys):
        raise Exception("Not implemented!")

    def get_indices(self, query, knn):
        raise Exception("Not implemented!")

    @staticmethod
    def build(input_dim, output_dim, params):
        if params.mem_implementation == 'flat':
            M = HashingMemoryFlat
        elif params.mem_implementation == 'pq_default':
            M = HashingMemoryProduct
        elif params.mem_implementation == 'pq_fast':
            M = HashingMemoryProductFast
        else:
            raise Exception("Unknown memory implementation!")
        return M(input_dim, output_dim, params)

    @staticmethod
    def check_params(params):
        """
        Check and initialize memory parameters.
        """
        # memory
        assert params.mem_implementation in ['flat', 'pq_default', 'pq_fast']
        params.mem_product_quantization = params.mem_implementation != 'flat'

        # optimization
        assert params.mem_grouped_conv is False or params.mem_multi_query_net
        params.mem_values_optimizer = params.optimizer if params.mem_values_optimizer == '' else params.mem_values_optimizer
        params.mem_values_optimizer = params.mem_values_optimizer.replace(
            'adam',
            'sparseadam') if params.mem_sparse else params.mem_values_optimizer

        # even number of key dimensions for product quantization
        assert params.mem_k_dim >= 2
        assert params.mem_product_quantization is False or params.mem_k_dim % 2 == 0

        # memory type
        assert params.mem_keys_type in ['binary', 'gaussian', 'uniform']

        # number of indices
        if params.mem_keys_type == 'binary':
            assert params.mem_keys_normalized_init is False
            assert 1 << params.mem_k_dim == params.mem_n_keys
        if params.mem_product_quantization:
            params.n_indices = params.mem_n_keys**2
        else:
            params.n_indices = params.mem_n_keys

        # actual memory size
        if params.mem_modulo_size == -1:
            params.mem_size = params.n_indices
        else:
            assert 1 <= params.mem_modulo_size < params.n_indices
            params.mem_size = params.mem_modulo_size

        # different keys / different query MLP / shuffle hidden dimensions when no query network
        assert not params.mem_use_different_keys or params.mem_keys_type in [
            'gaussian', 'uniform'
        ]
        assert not params.mem_use_different_keys or params.mem_heads >= 2 or params.mem_product_quantization
        assert not params.mem_multi_query_net or params.mem_heads >= 2 or params.mem_product_quantization
        assert not params.mem_multi_query_net or params.mem_query_layer_sizes not in [
            '', '0,0'
        ]
        assert not params.mem_shuffle_query or params.mem_heads > 1 and params.mem_query_layer_sizes == ''

        # query network
        if params.mem_query_layer_sizes == '':
            assert params.mem_heads == 1 or params.mem_use_different_keys or params.mem_shuffle_query
        else:
            s = [
                int(x)
                for x in filter(None, params.mem_query_layer_sizes.split(','))
            ]
            assert len(s) >= 2 and s[0] == s[-1] == 0
            params.mem_query_layer_sizes = s
            assert not params.mem_query_residual or params.mem_input2d

        # convolutional query network kernel sizes
        if params.mem_query_kernel_sizes == '':
            assert not params.mem_input2d or params.mem_query_layer_sizes == ''
        else:
            assert params.mem_input2d
            s = [
                int(x)
                for x in filter(None, params.mem_query_kernel_sizes.split(','))
            ]
            params.mem_query_kernel_sizes = s
            assert all(ks % 2 == 1 for ks in s)
            assert len(params.mem_query_kernel_sizes) == len(
                params.mem_query_layer_sizes) - 1 >= 1

        # scoring
        assert params.mem_score_subtract in ['', 'min', 'mean', 'median']
        assert params.mem_score_subtract == '' or params.mem_knn >= 2
        assert not (params.mem_score_normalize and params.mem_score_softmax
                    and params.mem_score_subtract == '')

        # dropout
        assert 0 <= params.mem_input_dropout < 1
        assert 0 <= params.mem_query_dropout < 1
        assert 0 <= params.mem_value_dropout < 1

        # query batchnorm
        if params.mem_query_batchnorm:
            logger.warning(
                "WARNING: if you use batch normalization, be sure that you use batches of sentences with the same size at training time. Otherwise, the padding token will result in incorrect mean/variance estimations in the BatchNorm layer."
            )
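
A small sketch of the weighted value lookup in `HashingMemory.forward` above: `nn.EmbeddingBag` in 'sum' mode with `per_sample_weights` reproduces the commented-out gather-and-weighted-sum variant; sizes are made up for illustration:

import torch
import torch.nn as nn

size, v_dim, bs, knn = 10, 4, 3, 5
values = nn.EmbeddingBag(size, v_dim, mode='sum')
plain = nn.Embedding(size, v_dim)
plain.weight.data.copy_(values.weight.data)

indices = torch.randint(0, size, (bs, knn))     # (bs, heads * knn) in the real module
scores = torch.rand(bs, knn)                    # attention scores per retrieved slot

bagged = values(indices, per_sample_weights=scores)           # (bs, v_dim)
manual = (plain(indices) * scores.unsqueeze(-1)).sum(dim=1)   # (bs, v_dim)

assert torch.allclose(bagged, manual, atol=1e-6)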
Example #16
0
class TransformerModel(nn.Module):

    ATTRIBUTES = ['encoder', 'with_output', 'eos_index', 'pad_index', 'n_langs', 'n_words', 'dim', 'n_layers',
                  'n_heads', 'hidden_dim', 'dropout', 'attention_dropout', 'asm', 'asm_cutoffs', 'asm_div_value']

    add_argument('use_positional_embedding', default=True, dtype=bool,
                 msg='whether to use positional embedding or not.')

    def __init__(self, params, dico, is_encoder, with_output):
        """
        Transformer model (encoder or decoder).
        """
        super().__init__()

        # encoder / decoder, output layer
        self.is_encoder = is_encoder
        self.is_decoder = not is_encoder
        self.with_output = with_output

        # dictionary / languages
        self.n_langs = params.n_langs
        self.n_words = params.n_words
        self.eos_index = params.eos_index
        self.pad_index = params.pad_index
        self.dico = dico
        self.id2lang = params.id2lang
        self.lang2id = params.lang2id
        self.use_lang_emb = getattr(params, 'use_lang_emb', True)
        assert len(self.dico) == self.n_words
        assert len(self.id2lang) == len(self.lang2id) == self.n_langs

        # model parameters
        self.dim = params.emb_dim       # 512 by default
        self.hidden_dim = self.dim * 4  # 2048 by default
        self.n_heads = params.n_heads   # 8 by default
        self.n_layers = params.n_layers
        self.dropout = params.dropout
        self.attention_dropout = params.attention_dropout
        assert self.dim % self.n_heads == 0, 'transformer dim must be a multiple of n_heads'

        # embeddings
        self.position_embeddings = Embedding(N_MAX_POSITIONS, self.dim)
        if params.sinusoidal_embeddings:
            create_sinusoidal_embeddings(N_MAX_POSITIONS, self.dim, out=self.position_embeddings.weight)
        if params.n_langs > 1 and self.use_lang_emb:
            self.lang_embeddings = Embedding(self.n_langs, self.dim)
        self.embeddings = Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
        self.layer_norm_emb = nn.LayerNorm(self.dim, eps=1e-12)

        # transformer layers
        self.attentions = nn.ModuleList()
        self.layer_norm1 = nn.ModuleList()
        self.ffns = nn.ModuleList()
        self.layer_norm2 = nn.ModuleList()
        if self.is_decoder:
            self.layer_norm15 = nn.ModuleList()
            self.encoder_attn = nn.ModuleList()

        # memories
        self.memories = nn.ModuleDict()
        if getattr(params, 'use_memory', False):
            mem_positions = params.mem_enc_positions if is_encoder else params.mem_dec_positions
            for layer_id, pos in mem_positions:
                assert 0 <= layer_id <= params.n_layers - 1
                assert pos in ['in', 'after']
                self.memories['%i_%s' % (layer_id, pos)] = HashingMemory.build(self.dim, self.dim, params)

        for layer_id in range(self.n_layers):
            self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
            self.layer_norm1.append(nn.LayerNorm(self.dim, eps=1e-12))
            if self.is_decoder:
                self.layer_norm15.append(nn.LayerNorm(self.dim, eps=1e-12))
                self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
            if ('%i_in' % layer_id) in self.memories:
                self.ffns.append(None)
            else:
                self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim,
                                                dropout=self.dropout, gelu_activation=params.gelu_activation))
            self.layer_norm2.append(nn.LayerNorm(self.dim, eps=1e-12))

        # output layer
        if self.with_output:
            self.pred_layer = PredLayer(params)
            if params.share_inout_emb:
                self.pred_layer.proj.weight = self.embeddings.weight

        self.use_positional_embedding = params.use_positional_embedding

    def forward(self, mode, **kwargs):
        """
        Forward function with different forward modes.
        ### Small hack to handle PyTorch distributed.
        """
        if mode == 'fwd':
            return self.fwd(**kwargs)
        elif mode == 'predict':
            return self.predict(**kwargs)
        else:
            raise Exception("Unknown mode: %s" % mode)

    def fwd(self, x, lengths, causal, src_enc=None, src_len=None, positions=None, langs=None, cache=None):
        """
        Inputs:
            `x` LongTensor(slen, bs), containing word indices
            `lengths` LongTensor(bs), containing the length of each sentence
            `causal` Boolean, if True, the attention is only done over previous hidden states
            `positions` LongTensor(slen, bs), containing word positions
            `langs` LongTensor(slen, bs), containing language IDs
        """
        # lengths = (x != self.pad_index).float().sum(dim=1)
        # mask = x != self.pad_index

        # check inputs
        slen, bs = x.size()
        assert lengths.size(0) == bs
        assert lengths.max().item() <= slen
        x = x.transpose(0, 1)  # batch size as dimension 0
        assert (src_enc is None) == (src_len is None)
        if src_enc is not None:
            assert self.is_decoder
            assert src_enc.size(0) == bs

        # generate masks
        mask, attn_mask = get_masks(slen, lengths, causal)
        if self.is_decoder and src_enc is not None:
            src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]

        # positions
        if positions is None:
            positions = x.new(slen).long()
            positions = torch.arange(slen, out=positions).unsqueeze(0)
        else:
            assert positions.size() == (slen, bs)
            positions = positions.transpose(0, 1)

        # langs
        if langs is not None:
            assert langs.size() == (slen, bs)
            langs = langs.transpose(0, 1)

        # do not recompute cached elements
        if cache is not None:
            _slen = slen - cache['slen']
            x = x[:, -_slen:]
            positions = positions[:, -_slen:]
            if langs is not None:
                langs = langs[:, -_slen:]
            mask = mask[:, -_slen:]
            attn_mask = attn_mask[:, -_slen:]

        # embeddings
        tensor = self.embeddings(x)
        if self.use_positional_embedding:
            tensor = tensor + self.position_embeddings(positions).expand_as(tensor)
        if langs is not None and self.use_lang_emb:
            tensor = tensor + self.lang_embeddings(langs)
        tensor = self.layer_norm_emb(tensor)
        tensor = F.dropout(tensor, p=self.dropout, training=self.training)
        tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # transformer layers
        for i in range(self.n_layers):

            # self attention
            attn = self.attentions[i](tensor, attn_mask, cache=cache)
            attn = F.dropout(attn, p=self.dropout, training=self.training)
            tensor = tensor + attn
            tensor = self.layer_norm1[i](tensor)

            # encoder attention (for decoder only)
            if self.is_decoder and src_enc is not None:
                attn = self.encoder_attn[i](tensor, src_mask, kv=src_enc, cache=cache)
                attn = F.dropout(attn, p=self.dropout, training=self.training)
                tensor = tensor + attn
                tensor = self.layer_norm15[i](tensor)

            # FFN
            if ('%i_in' % i) in self.memories:
                tensor = tensor + self.memories['%i_in' % i](tensor)
            else:
                tensor = tensor + self.ffns[i](tensor)
            tensor = self.layer_norm2[i](tensor)

            # memory
            if ('%i_after' % i) in self.memories:
                tensor = tensor + self.memories['%i_after' % i](tensor)
            # TODO: add extra layer norm here?

            tensor *= mask.unsqueeze(-1).to(tensor.dtype)

        # update cache length
        if cache is not None:
            cache['slen'] += tensor.size(1)

        # move back sequence length to dimension 0
        tensor = tensor.transpose(0, 1)

        return tensor

    def predict(self, tensor, pred_mask, y, get_scores):
        """
        Given the last hidden state, compute word scores and/or the loss.
            `pred_mask` is a ByteTensor of shape (slen, bs), filled with 1 when
                we need to predict a word
            `y` is a LongTensor of shape (pred_mask.sum(),)
            `get_scores` is a boolean specifying whether we need to return scores
        """
        masked_tensor = tensor[pred_mask.unsqueeze(-1).expand_as(tensor)].view(-1, self.dim)
        scores, loss = self.pred_layer(masked_tensor, y, get_scores)
        return scores, loss

    def generate(self, src_enc, src_len, tgt_lang_id, max_len=200, sample_temperature=None):
        """
        Decode a sentence given initial start.
        `x`:
            - LongTensor(bs, slen)
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
        `lengths`:
            - LongTensor(bs) [5, 6]
        `positions`:
            - False, for regular "arange" positions (LM)
            - True, to reset positions from the new generation (MT)
        `langs`:
            - must be None if the model only supports one language
            - lang_id if only one language is involved (LM)
            - (lang_id1, lang_id2) if two languages are involved (MT)
        """

        # input batch
        bs = len(src_len)
        assert src_enc.size(0) == bs

        # generated sentences
        generated = src_len.new(max_len, bs)  # upcoming output
        generated.fill_(self.pad_index)       # fill upcoming output with <PAD>
        generated[0].fill_(self.eos_index)    # we use <EOS> for <BOS> everywhere

        # positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand(max_len, bs)

        # language IDs
        langs = src_len.new(max_len).long().fill_(tgt_lang_id)
        langs = langs.unsqueeze(1).expand(max_len, bs)

        # current position / max lengths / length of generated sentences / unfinished sentences
        cur_len = 1
        gen_len = src_len.clone().fill_(1)
        unfinished_sents = src_len.clone().fill_(1)

        # cache compute states
        cache = {'slen': 0}

        while cur_len < max_len:

            # compute word scores
            tensor = self.forward(
                'fwd',
                x=generated[:cur_len],
                lengths=gen_len,
                positions=positions[:cur_len],
                langs=langs[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                cache=cache
            )
            assert tensor.size() == (1, bs, self.dim), (cur_len, max_len, src_enc.size(), tensor.size(), (1, bs, self.dim))
            tensor = tensor.data[-1, :, :].type_as(src_enc)  # (bs, dim)
            scores = self.pred_layer.get_scores(tensor)      # (bs, n_words)

            # select next words: sample or greedy
            if sample_temperature is None:
                next_words = torch.topk(scores, 1)[1].squeeze(1)
            else:
                next_words = torch.multinomial(F.softmax(scores / sample_temperature, dim=1), 1).squeeze(1)
            assert next_words.size() == (bs,)

            # update generations / lengths / finished sentences / current length
            generated[cur_len] = next_words * unfinished_sents + self.pad_index * (1 - unfinished_sents)
            gen_len.add_(unfinished_sents)
            unfinished_sents.mul_(next_words.ne(self.eos_index).long())
            cur_len = cur_len + 1

            # stop when there is a </s> in each sentence, or if we exceed the maximum length
            if unfinished_sents.max() == 0:
                break

        # add <EOS> to unfinished sentences
        if cur_len == max_len:
            generated[-1].masked_fill_(unfinished_sents.bool(), self.eos_index)

        # sanity check
        assert (generated == self.eos_index).sum() == 2 * bs

        return generated[:cur_len], gen_len

    def generate_beam(self, src_enc, src_len, tgt_lang_id, beam_size, length_penalty, early_stopping, max_len=200):
        """
        Decode a sentence given initial start.
        `x`:
            - LongTensor(bs, slen)
                <EOS> W1 W2 W3 <EOS> <PAD>
                <EOS> W1 W2 W3   W4  <EOS>
        `lengths`:
            - LongTensor(bs) [5, 6]
        `positions`:
            - False, for regular "arange" positions (LM)
            - True, to reset positions from the new generation (MT)
        `langs`:
            - must be None if the model only supports one language
            - lang_id if only one language is involved (LM)
            - (lang_id1, lang_id2) if two languages are involved (MT)
        """

        # check inputs
        assert src_enc.size(0) == src_len.size(0)
        assert beam_size >= 1

        # batch size / number of words
        bs = len(src_len)
        n_words = self.n_words

        # expand to beam size the source latent representations / source lengths
        src_enc = src_enc.unsqueeze(1).expand(
            (bs, beam_size) + src_enc.shape[1:]).contiguous().view((bs * beam_size,) + src_enc.shape[1:])
        src_len = src_len.unsqueeze(1).expand(bs, beam_size).contiguous().view(-1)

        # generated sentences (batch with beam current hypotheses)
        generated = src_len.new(max_len, bs * beam_size)  # upcoming output
        generated.fill_(self.pad_index)                   # fill upcoming output with <PAD>
        generated[0].fill_(self.eos_index)                # we use <EOS> for <BOS> everywhere

        # generated hypotheses
        generated_hyps = [BeamHypotheses(beam_size, max_len, length_penalty, early_stopping) for _ in range(bs)]

        # positions
        positions = src_len.new(max_len).long()
        positions = torch.arange(max_len, out=positions).unsqueeze(1).expand_as(generated)

        # language IDs
        langs = positions.clone().fill_(tgt_lang_id)

        # scores for each sentence in the beam
        beam_scores = src_enc.new(bs, beam_size).fill_(0)
        beam_scores[:, 1:] = -1e9
        beam_scores = beam_scores.view(-1)

        # current position
        cur_len = 1

        # cache compute states
        cache = {'slen': 0}

        # done sentences
        done = [False for _ in range(bs)]

        while cur_len < max_len:

            # compute word scores
            tensor = self.forward(
                'fwd',
                x=generated[:cur_len],
                lengths=src_len.new(bs * beam_size).fill_(cur_len),
                positions=positions[:cur_len],
                langs=langs[:cur_len],
                causal=True,
                src_enc=src_enc,
                src_len=src_len,
                cache=cache
            )
            assert tensor.size() == (1, bs * beam_size, self.dim)
            tensor = tensor.data[-1, :, :]               # (bs * beam_size, dim)
            scores = self.pred_layer.get_scores(tensor)  # (bs * beam_size, n_words)
            scores = F.log_softmax(scores, dim=-1)       # (bs * beam_size, n_words)
            assert scores.size() == (bs * beam_size, n_words)

            # select next words with scores
            _scores = scores + beam_scores[:, None].expand_as(scores)  # (bs * beam_size, n_words)
            _scores = _scores.view(bs, beam_size * n_words)            # (bs, beam_size * n_words)

            next_scores, next_words = torch.topk(_scores, 2 * beam_size, dim=1, largest=True, sorted=True)
            assert next_scores.size() == next_words.size() == (bs, 2 * beam_size)

            # next batch beam content
            # list of (bs * beam_size) tuple(next hypothesis score, next word, current position in the batch)
            next_batch_beam = []

            # for each sentence
            for sent_id in range(bs):

                # if we are done with this sentence
                done[sent_id] = done[sent_id] or generated_hyps[sent_id].is_done(next_scores[sent_id].max().item())
                if done[sent_id]:
                    next_batch_beam.extend([(0, self.pad_index, 0)] * beam_size)  # pad the batch
                    continue

                # next sentence beam content
                next_sent_beam = []

                # next words for this sentence
                for idx, value in zip(next_words[sent_id], next_scores[sent_id]):

                    # get beam and word IDs
                    beam_id = idx // n_words
                    word_id = idx % n_words

                    # end of sentence, or next word
                    if word_id == self.eos_index or cur_len + 1 == max_len:
                        generated_hyps[sent_id].add(
                            generated[:cur_len, sent_id * beam_size + beam_id].clone(), value.item())
                    else:
                        next_sent_beam.append((value, word_id, sent_id * beam_size + beam_id))

                    # the beam for next step is full
                    if len(next_sent_beam) == beam_size:
                        break

                # update next beam content
                assert len(next_sent_beam) == (0 if cur_len + 1 == max_len else beam_size)
                if len(next_sent_beam) == 0:
                    next_sent_beam = [(0, self.pad_index, 0)] * beam_size  # pad the batch
                next_batch_beam.extend(next_sent_beam)
                assert len(next_batch_beam) == beam_size * (sent_id + 1)

            # sanity check / prepare next batch
            assert len(next_batch_beam) == bs * beam_size
            beam_scores = beam_scores.new([x[0] for x in next_batch_beam])
            beam_words = generated.new([x[1] for x in next_batch_beam])
            beam_idx = src_len.new([x[2] for x in next_batch_beam])

            # re-order batch and internal states
            generated = generated[:, beam_idx]
            generated[cur_len] = beam_words
            for k in cache.keys():
                if k != 'slen':
                    cache[k] = (cache[k][0][beam_idx], cache[k][1][beam_idx])

            # update current length
            cur_len = cur_len + 1

            # stop when we are done with each sentence
            if all(done):
                break

        # visualize hypotheses
        # print([len(x) for x in generated_hyps], cur_len)
        # globals().update( locals() );
        # !import code; code.interact(local=vars())
        # for ii in range(bs):
        #     for ss, ww in sorted(generated_hyps[ii].hyp, key=lambda x: x[0], reverse=True):
        #         print("%.3f " % ss + " ".join(self.dico[x] for x in ww.tolist()))
        #     print("")

        # select the best hypotheses
        tgt_len = src_len.new(bs)
        best = []

        for i, hypotheses in enumerate(generated_hyps):
            best_hyp = max(hypotheses.hyp, key=lambda x: x[0])[1]
            tgt_len[i] = len(best_hyp) + 1  # +1 for the <EOS> symbol
            best.append(best_hyp)

        # generate target batch
        decoded = src_len.new(tgt_len.max().item(), bs).fill_(self.pad_index)
        for i, hypo in enumerate(best):
            decoded[:tgt_len[i] - 1, i] = hypo
            decoded[tgt_len[i] - 1, i] = self.eos_index

        # sanity check
        assert (decoded == self.eos_index).sum() == 2 * bs

        return decoded, tgt_len
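
A hypothetical end-to-end decoding sketch; `encoder`, `decoder`, `x`, `lengths`, `langs` and `tgt_lang_id` are assumed to exist and are not defined in the original snippet.

# encode the source batch (fwd returns a (slen, bs, dim) tensor)
encoded = encoder('fwd', x=x, lengths=lengths, langs=langs, causal=False)
encoded = encoded.transpose(0, 1)  # generate/generate_beam expect batch-first encoder output
# greedy decoding
decoded, dec_len = decoder.generate(encoded, lengths, tgt_lang_id, max_len=100)
# beam search with 5 hypotheses per sentence
decoded, dec_len = decoder.generate_beam(encoded, lengths, tgt_lang_id, beam_size=5,
                                         length_penalty=1.0, early_stopping=False, max_len=100)
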
Example #17
class Graphormer(TransformerModel):

    add_argument(
        'ablation_mode',
        default='full',
        dtype=str,
        choices=['full', 'ffn', 'none', 'self_attn'],
        msg=
        'ablation mode: full means the full model, ffn replaces the RGCN with an FFN, none uses the assembler only, and self_attn stacks self-attention layers.'
    )
    add_argument("self_attn_layers",
                 default=0,
                 dtype=int,
                 msg='number of layers for self attention layers.')
    add_argument('num_relations',
                 default=5,
                 dtype=int,
                 msg='number of distinct edge types.')
    add_argument('mask_unconnected_vertices',
                 default=False,
                 dtype=bool,
                 msg='flag to mask unconnected vertices.')

    def __init__(self, params, dico, is_encoder, with_output):
        super().__init__(params, dico, is_encoder, with_output)
        self.assembler = Assembler()
        self.ablation_mode = params.ablation_mode
        self.num_relations = params.num_relations
        self.mask_unconnected_vertices = params.mask_unconnected_vertices

        if self.ablation_mode == 'full':
            self.graph_predictor = GraphPredictor()
            self.rgcn = ContinuousRGCN()
        elif self.ablation_mode == 'ffn':
            # NOTE(j_luo) This actually is about 10x smaller than 'full' mode.
            self.linear = nn.Linear(params.emb_dim, params.emb_dim)
        elif self.ablation_mode == 'none':
            pass
        else:
            if not params.self_attn_layers > 0:
                raise ValueError('self_attn_layers must be at least 1 for the self_attn ablation mode.')
            layers = [
                MultiHeadAttention(8, params.emb_dim, dropout=params.dropout)
                for _ in range(params.self_attn_layers)
            ]
            self.self_attn_layers = nn.Sequential(*layers)

    def fwd(self,
            x,
            lengths,
            causal,
            src_enc=None,
            src_len=None,
            positions=None,
            langs=None,
            cache=None,
            graph_info=None,
            return_graph_data=False,
            oracle_graph=None):
        assert graph_info is not None
        assert not return_graph_data or self.ablation_mode == 'full'

        h = super().fwd(x,
                        lengths,
                        causal,
                        src_enc=src_enc,
                        src_len=src_len,
                        positions=positions,
                        langs=langs,
                        cache=cache)
        h = h.transpose(0, 1)
        assembled_h = self.assembler(h, graph_info.bpe_mask,
                                     graph_info.word2bpe)
        if self.ablation_mode == 'full':
            if oracle_graph is None:
                norms, type_probs = self.graph_predictor(
                    assembled_h, graph_info.word_mask)
                # Prepare node_features, edge_index, edge_norm and edge_type.
                graph_data = self._prepare_for_geometry(
                    assembled_h, norms, type_probs)
            else:
                graph_data = self._prepare_for_geometry(
                    assembled_h, None, None, oracle_graph=oracle_graph)
            output = self.rgcn(x=graph_data.node_features,
                               edge_index=graph_data.edge_index,
                               edge_norm=graph_data.edge_norm,
                               edge_type=graph_data.edge_type)
            # Now reshape output for later usage. Note that the length dimension has changed to represent words instead of BPEs.
            bs, wl, _ = assembled_h.shape
            output[~graph_info.word_mask.view(-1)] = 0.0
            output = output.view(bs, wl, -1)
        elif self.ablation_mode == 'ffn':
            output = self.linear(assembled_h)
            output[~graph_info.word_mask] = 0.0
        elif self.ablation_mode == 'none':
            output = assembled_h
        else:
            output = assembled_h
            for layer in self.self_attn_layers:
                output = layer(output, graph_info.word_mask)
        if self.mask_unconnected_vertices and graph_data.connected_vertices is not None:
            output[~graph_data.connected_vertices] = 0.0

        output = output.transpose(0, 1)
        if return_graph_data:
            return output, graph_data
        else:
            return output

    def _prepare_for_geometry(self,
                              assembled_h,
                              norms,
                              type_probs,
                              oracle_graph: Optional[GraphData] = None):
        """
        inputs:
            assembled_h:    bs x wl x d
            norms:          bs x wl x wl
            type_probs:     bs x wl x wl x nr

        outputs:
            node_features:  V x d
            edge_index:     2 x E
            edge_norms:     E
            edge_types:     E x nr
        where V = bs * wl and E = bs * wl * wl.
        """
        connected_vertices = None
        if oracle_graph is not None:
            type_probs = oracle_graph.edge_type
            norms = oracle_graph.edge_norm
            bs, wl, _ = assembled_h.shape
            connected_vertices = oracle_graph.connected_vertices
        else:
            bs, wl, _, nr = type_probs.shape
        V = bs * wl
        E = bs * wl * wl

        # node_features is just a reshaped version of assembled_h.
        node_features = assembled_h.view(V, -1)

        # edge_index is a collection of fully connected graphs, each of which corresponds to one sentence.
        edge_index_offset = get_range(bs, 1, 0) * wl  # bs
        edge_index_i = get_range(wl, 2, 0).expand(wl, wl)  # wl x 1 -> wl x wl
        edge_index_i = edge_index_offset.view(bs, 1, 1) + edge_index_i
        edge_index_j = get_range(wl, 2, 1).expand(wl, wl)  # 1 x wl -> wl x wl
        edge_index_j = edge_index_offset.view(bs, 1, 1) + edge_index_j
        edge_index = torch.stack([edge_index_i, edge_index_j],
                                 dim=0).view(2, E)

        # edge_norms is just a reshaped version of norms.
        edge_norms = norms.view(E)

        # edge_types is similar.
        if oracle_graph is not None:
            edge_types = get_zeros(E, self.num_relations)
            edge_types.scatter_(1, oracle_graph.edge_type.view(-1, 1),
                                oracle_graph.edge_norm.view(-1, 1))
        else:
            edge_types = type_probs.view(E, nr)

        return GraphData(node_features, edge_index, edge_norms, edge_types,
                         connected_vertices)
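
A standalone sketch (plain torch, not the original `get_range` helper) of the fully connected per-sentence edge layout that `_prepare_for_geometry` builds.

import torch

bs, wl = 2, 3
offset = torch.arange(bs).view(bs, 1, 1) * wl           # shift node ids per sentence
edge_i = torch.arange(wl).view(wl, 1).expand(wl, wl)    # source node within a sentence
edge_j = torch.arange(wl).view(1, wl).expand(wl, wl)    # target node within a sentence
edge_index = torch.stack([offset + edge_i, offset + edge_j], dim=0).view(2, bs * wl * wl)
print(edge_index.shape)  # torch.Size([2, 18]), i.e. 2 x E with E = bs * wl * wl
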
Example #18
import logging
import os
import random

from arglib import add_argument, init_g_attr
from trainlib import Metrics, set_random_seeds
from xib.data_loader import (ContinuousTextDataLoader, DenseIpaDataLoader,
                             IpaDataLoader, MetricLearningDataLoader)
from xib.model.decipher_model import DecipherModel
from xib.model.lm_model import LM, AdaptedLM
from xib.model.metric_learning_model import MetricLearningBatch
from xib.training.evaluator import Evaluator, LMEvaluator
from xib.training.trainer import (AdaptLMTrainer, DecipherTrainer, LMTrainer,
                                  MetricLearningTrainer)

add_argument('task', default='lm', dtype=str, choices=['lm', 'decipher', 'metric', 'adapt'], msg='which task to run')


class Manager:

    data_loader_cls = IpaDataLoader
    trainer_cls = LMTrainer

    def __init__(self):
        self.model = self._get_model()
        if os.environ.get('CUDA_VISIBLE_DEVICES', False):
            self.model.cuda()
        self.train_data_loader = self.data_loader_cls()
        self.evaluator = LMEvaluator(self.model, self.train_data_loader)
        self.trainer = self.trainer_cls(self.model, self.train_data_loader, self.evaluator)
Example #19
class Verifier:

    add_argument(
        'ae_noise_graph_mode',
        dtype=str,
        default='keep',
        choices=['keep', 'change'],
        msg=
        'determines how we handle the graph when noise is added to the AE and graph supervision is used'
    )

    def __init__(self, data_path, lgs, supervised_graph,
                 ae_noise_graph_mode: 'n', ae_add_noise):
        super().__init__()
        src_lang, tgt_lang = lgs.split('-')
        if supervised_graph:
            self.graphs = dict()
            for lang in [src_lang, tgt_lang]:
                for split in ['train', 'dev', 'test']:
                    name_split = 'valid' if split == 'dev' else split
                    self.graphs[(lang, name_split)] = _read_graphs(
                        data_path / f'{split}.{lang}.tok.cvtx.neo.txt')
        self.dico = torch.load(data_path / f'valid.{src_lang}.pth')['dico']

        self.incomplete_bpe = {lang: set() for lang in [src_lang, tgt_lang]}
        self.incomplete_idx = dict()
        for lang in [src_lang, tgt_lang]:
            incomplete_idx = list()
            for bpe, idx in self.dico.word2id.items():
                if bpe.endswith('@@'):
                    self.incomplete_bpe[lang].add(bpe)
                    incomplete_idx.append(idx)
            idx = get_zeros(len(self.dico)).bool()
            idx[incomplete_idx] = True
            self.incomplete_idx[lang] = idx

        self.ae_noise_graph_mode = ae_noise_graph_mode

    def get_graph_info(self, data, lang: str) -> GraphInfo:
        """
        bpe_mask is the mask that marks the boundaries of the words.
        word_mask is the mask that marks whether or not a position is padded.
        word2bpe is the matrix that maps from word ids to bpe positions.
        """
        data = data.t()
        bs, l = data.shape

        data_off_by_one = torch.cat([get_zeros(bs, 1).long(), data[:, :-1]],
                                    dim=1)
        # A new word starts if the previous BPE is complete and the current token is not padding or <s>.
        new_word = ~self.incomplete_idx[lang][data_off_by_one] & (
            data != self.dico.pad_index) & (data != self.dico.bos_index)
        # Form distinct word ids by counting how many new words are formed up to now.
        word_ids = new_word.long().cumsum(dim=1)
        # bpe_mask: value is set to True iff both bpes belong to the same word.
        bpe_mask = (word_ids.unsqueeze(dim=1) == word_ids.unsqueeze(dim=2))

        # word_mask
        word_lengths, _ = word_ids.max(dim=1)
        max_len = max(word_lengths)
        word_mask = get_length_mask(word_lengths, max_len)  # size: bs x wl

        # word2bpe is computed by getting all the row and column indices correctly.
        word_idx = (word_ids - 1)  # size: bs x l
        bpe_idx = get_range(l, 2, 1)  # size: 1 x l
        batch_i = get_range(bs, 2, 0)  # size: bs x 1
        word2bpe = get_zeros(bs, max_len, l)
        word2bpe[batch_i, word_idx, bpe_idx] = 1.0

        return GraphInfo(bpe_mask, word_mask, word_lengths, word2bpe)

    def get_graph_target(self,
                         data: Tensor,
                         lengths: Tensor,
                         lang: str,
                         split: str,
                         indices: List[int],
                         permutations: List[np.ndarray] = None,
                         keep: np.ndarray = None) -> GraphData:
        # NOTE(j_luo)  If for some reason the first one is <s> or </s>, we need to offset the indices.
        max_len = max(lengths)
        if self.ae_noise_graph_mode == 'change':
            assert permutations is not None and keep is not None

        offsets = ((data[0] == self.dico.eos_index) |
                   (data[0] == self.dico.bos_index)).long()
        graphs = self.graphs[(lang, split)]
        graphs = [graphs[i] for i in indices]
        bs = len(graphs)
        if len(offsets) != bs:
            raise RuntimeError(f'Expected {bs} offsets (one per graph) but got {len(offsets)}.')

        ijkv = list()
        connected_vertices = get_zeros(len(graphs), max_len).bool()
        for batch_i, graph in enumerate(graphs):
            offset = offsets[batch_i].item()

            assert self.ae_noise_graph_mode != 'change', 'connected vertices cannot handle change for now.'
            vertices = np.asarray(graph.connected_vertices) + offset
            connected_vertices[batch_i, vertices] = True
            if offset > 0:
                connected_vertices[batch_i, 0] = True
            length = lengths[batch_i].item() - 1
            connected_vertices[batch_i, length] = True

            # Repeat the permutation and dropout processes and change the graph accordingly.
            if self.ae_noise_graph_mode == 'change':
                perm = permutations[batch_i].argsort()
                perm = np.arange(len(perm))[perm]
            for e in graph.edges:
                u = e.u + offset
                v = e.v + offset
                assert u < max_len and v < max_len
                if self.ae_noise_graph_mode == 'change':
                    u = perm[e.u]
                    v = perm[e.v]
                    if keep[u, batch_i] and keep[v, batch_i]:
                        ijkv.append((batch_i, u, v, e.t.value))
                else:
                    ijkv.append((batch_i, u, v, e.t.value))

        i, j, k, v = zip(*ijkv)
        v = get_tensor(v)
        edge_norm = get_zeros([bs, max_len, max_len])
        edge_type = get_zeros([bs, max_len, max_len]).long()
        # NOTE(j_luo) Edges are symmetric.
        edge_norm[i, j, k] = 1.0
        edge_norm[i, k, j] = 1.0
        edge_type[i, j, k] = v
        edge_type[i, k, j] = v
        edge_norm = edge_norm.view(-1)
        edge_type = edge_type.view(-1)
        return GraphData(None, None, edge_norm, edge_type, connected_vertices)

    def get_graph_loss(self, graph_data: GraphData, graph_target: GraphData,
                       lang: str) -> Tuple[Metric, Metric]:
        """
        Sizes for graph_data and graph_target:
                            graph_data      graph_target
            edge_norm:      E               E
            edge_type:      E x nr          E
        where E = bs x wl x wl.
        """
        # NOTE(j_luo) This determines whether it's an actual edge (in contrast to a padded edge) or not.
        assert len(graph_data.edge_norm) == len(graph_target.edge_norm)
        assert len(graph_data.edge_type) == len(graph_target.edge_type)

        edge_mask = graph_target.edge_norm

        edge_types_log_probs = (1e-8 + graph_data.edge_type).log()
        loss_edge_type = -edge_types_log_probs.gather(
            1, graph_target.edge_type.view(-1, 1)).view(-1)
        loss_edge_type = (loss_edge_type * edge_mask).sum()

        log_edge_norm = (graph_data.edge_norm + 1e-8).clamp(max=1.0).log()
        loss_edge_norm = -(log_edge_norm * edge_mask).sum()

        weight = edge_mask.sum()
        loss_edge_type = Metric(f'loss_edge_type_{lang}', loss_edge_type,
                                weight)
        loss_edge_norm = Metric(f'loss_edge_norm_{lang}', loss_edge_norm,
                                weight)
        return loss_edge_type, loss_edge_norm
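
A toy sketch (with a made-up BPE sequence) of the word-id bookkeeping behind `get_graph_info`: a new word starts only after a complete BPE piece, i.e. one that does not end in '@@'.

import torch

pieces = ['he@@', 'llo', 'world']                            # one sentence: "hello world"
incomplete = torch.tensor([p.endswith('@@') for p in pieces])
prev_incomplete = torch.cat([torch.tensor([False]), incomplete[:-1]])
new_word = ~prev_incomplete                                  # word starts at positions 0 and 2
word_ids = new_word.long().cumsum(dim=0)                     # tensor([1, 1, 2])
bpe_mask = word_ids.unsqueeze(0) == word_ids.unsqueeze(1)    # True iff two BPEs share a word
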
Example #20
class MetricLearningTrainer(BaseTrainer):

    add_argument('num_epochs', default=5, dtype=int, msg='number of epochs')

    def __init__(self, model: 'a', data_loader: 'a', num_epochs, learning_rate, check_interval, save_interval, log_dir):
        Trainer.__init__(self)
        self.tracker.add_track('epoch', update_fn='add', finish_when=num_epochs)

    def train(self,
              evaluator: evaluator.Evaluator,
              train_langs: List[str],
              dev_langs: List[str],
              fold_idx: int) -> Metrics:
        # Reset parameters.
        self._init_params(init_matrix=True, init_vector=True, init_higher_tensor=True)
        self.optimizer = optim.Adam(self.model.parameters(), self.learning_rate)
        # Main loop.
        accum_metrics = Metrics()
        best_mse = None
        while not self.tracker.is_finished:
            # Get data first.
            metrics = self.train_loop(train_langs)
            accum_metrics += metrics
            self.tracker.update()

            self.check_metrics(accum_metrics)

            if self.track % self.save_interval == 0:
                self.save(dev_langs, f'{fold_idx}.latest')
                dev_metrics = evaluator.evaluate(dev_langs)
                logging.info(dev_metrics.get_table(title='dev'))
                if best_mse is None or dev_metrics.mse.mean < best_mse:
                    best_mse = dev_metrics.mse.mean
                    logging.imp(f'Updated best mse score: {best_mse:.3f}')
                    self.save(dev_langs, f'{fold_idx}.best')
        return Metric('best_mse', best_mse, 1)

    def save(self, dev_langs: List[str], suffix: str):
        out_path = self.log_dir / f'saved.{suffix}'
        to_save = {
            'model': self.model.state_dict(),
            'g': g.state_dict(),
            'dev_langs': dev_langs,
        }
        torch.save(to_save, out_path)
        logging.imp(f'Model saved to {out_path}.')

    def reset(self):
        # HACK(j_luo) Need to improve the tracker API.
        self.tracker._attrs['epoch'] = 0

    def train_loop(self, train_langs: List[str]) -> Metrics:
        fold_data_loader = self.data_loader.select(train_langs, train_langs)
        metrics = Metrics()
        for batch_i, batch in enumerate(fold_data_loader):
            self.model.train()
            self.optimizer.zero_grad()

            output = self.model(batch)
            mse = (output - batch.dist) ** 2
            mse = Metric('mse', mse.sum(), len(batch))
            metrics += mse

            mse.mean.backward()
            self.optimizer.step()
        return metrics

    @property
    def track(self):
        return self.tracker.epoch
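
A hypothetical cross-validation driver around `MetricLearningTrainer.train`; the `trainer`, `evaluator` and `folds` objects are assumed, not taken from the original code.

best_metrics = Metrics()
for fold_idx, (train_langs, dev_langs) in enumerate(folds):
    trainer.reset()  # restart the epoch tracker for this fold
    best_metrics += Metrics(trainer.train(evaluator, train_langs, dev_langs, fold_idx))
logging.info(best_metrics.get_table(title='cross-validation'))
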
Example #21
    with torch.no_grad():
        for i in range(len(dico)):
            idx = word2id.get(dico[i], None)
            if idx is None:
                continue
            n_found += 1
            model.embeddings.weight[i] = embeddings[idx].cuda()
            model.pred_layer.proj.weight[i] = embeddings[idx].cuda()
    logger.info("Pretrained %i/%i words (%.3f%%)." %
                (n_found, len(dico), 100. * n_found / len(dico)))


add_argument(
    'old_data_paths',
    nargs=2,
    dtype=str,
    msg=
    'paths (first src, second tgt) to the old data to extract the vocabularies.'
)


def build_model(params, dico):
    """
    Build model.
    """

    if params.encoder_only:
        # build
        model = TransformerModel(params,
                                 dico,
                                 is_encoder=True,
Example #22
import torch

from arglib import add_argument

# IDEA(j_luo) typing!
LT = torch.LongTensor
FT = torch.FloatTensor
BT = torch.BoolTensor

add_argument('num_features',
             default=10,
             dtype=int,
             msg='total number of phonetic features')
add_argument('num_feature_groups',
             default=10,
             dtype=int,
             msg='total number of phonetic feature groups')
add_argument('dim',
             default=5,
             dtype=int,
             msg='dimensionality of feature embeddings')
add_argument('hidden_size', default=5, dtype=int, msg='hidden size')
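
A hypothetical sketch of how these arguments might feed a feature embedding module; the class below is illustrative and not taken from the original code.

import torch.nn as nn


class FeatureEmbedding(nn.Module):

    def __init__(self, num_features: int = 10, dim: int = 5):
        super().__init__()
        self.embed = nn.Embedding(num_features, dim)

    def forward(self, feat_matrix: LT) -> FT:
        # feat_matrix: batch x length x feature groups, holding feature indices
        return self.embed(feat_matrix).sum(dim=-2)  # sum the group embeddings per position
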