Example #1
0
    def analyze(self, log_probs, almt_distr, words, lost_lengths):
        """Project character log-probs onto the vocabulary and compute an alignment regularizer.

        Args:
            log_probs: tensor unpacked as ``(tl, nc, bs)``. ``len(log_probs)`` must
                equal ``self._eff_max_length`` and ``nc`` must equal the size of the
                charset for ``self.lang``.
            almt_distr: alignment distribution whose last dimension has size ``sl``.
                It is presumably ``bs x tl x sl``; TODO confirm against the caller.
            words: forwarded to ``self._sample`` after the cache is cleared.
            lost_lengths: per-example lengths, reshaped to a ``bs x 1`` column.

        Returns:
            Map with ``reg_loss`` (scalar sum of the weighted penalty) and
            ``valid_log_probs`` (``bs x V``).
        """
        self.clear_cache()
        self._sample(words)
        assert self._eff_max_length == len(log_probs)

        _, num_chars, batch_size = log_probs.shape
        assert num_chars == len(get_charset(self.lang))

        # Flatten to (tl * nc) x bs, project with the effective weight, then
        # transpose so that the batch dimension comes first.
        flat_log_probs = log_probs.view(-1, batch_size)
        valid_log_probs = self._eff_weight.matmul(flat_log_probs).t()

        # Expected source position at each target step under the alignment.
        src_len = almt_distr.shape[-1]
        positions = get_tensor(torch.arange(src_len).float(), requires_grad=False)
        expected = (positions * almt_distr).sum(dim=-1)
        # Prepend a virtual position of -1 so the first step is measured from it.
        start = get_zeros(batch_size, 1, requires_grad=False).fill_(-1.0)
        expected = torch.cat([start, expected], dim=-1)
        prev_pos = expected[:, :-1]
        next_pos = expected[:, 1:]

        # Per-step weight in [0, 1]. It drops to 0 once the previous expected
        # position reaches lost_lengths - 1.
        weight = lost_lengths.float().view(-1, 1) - 1.0 - prev_pos
        weight.clamp_(0.0, 1.0)

        # Penalize any step that does not advance by exactly one position.
        step_err = (next_pos - prev_pos) - 1
        penalty = (step_err != 0).float() * (step_err ** 2)
        reg_loss = (penalty * weight).sum()

        return Map(reg_loss=reg_loss, valid_log_probs=valid_log_probs)
Example #2
0
 def process(self, word):
     """Build one feature Map per character of ``word``.

     Each entry starts as a deep copy of ``Map(self._feat_dict)``. Its
     ``'char'`` entry is then set as follows:

     * to the character itself, if it is a key of ``self._char2id``;
     * otherwise to the lowercased character, with ``'capitalization'`` set
       to True, if the lowercased form is a key;
     * otherwise to ``''``.

     Returns:
         list of feature Maps, aligned with the characters of ``word``.
     """
     feats = [copy.deepcopy(Map(self._feat_dict)) for _ in range(len(word))]
     for idx, ch in enumerate(word):
         feat = feats[idx]
         if ch in self._char2id:
             feat.update({'char': ch})
             continue
         lowered = ch.lower()
         if lowered not in self._char2id:
             # Unknown even after lowercasing: blank out the character.
             feat.update({'char': ''})
             continue
         feat.update({'char': lowered, 'capitalization': True})
     return feats
Example #3
0
def collate_fn(batch):
    """Collate a sequence of dataset items into a single length-sorted batch.

    Args:
        batch: sequence of items. ``word``, ``form``, ``char_seq`` and
            ``id_seq`` are read via ``_get_item``. ``lang`` is read as an
            attribute of the first item only.

    Returns:
        Map with ``words``, ``forms``, ``char_seqs``, ``id_seqs`` (a padded
        long tensor trimmed to the longest sequence), ``lengths`` (a long
        tensor) and ``lang``.
    """
    words = _get_item('word', batch)
    forms = _get_item('form', batch)
    char_seqs = _get_item('char_seq', batch)
    id_seqs = _get_item('id_seq', batch)
    # sort_all returns the lengths first, followed by the four inputs.
    # Presumably they are reordered by length; see sort_all.
    lengths, words, forms, char_seqs, id_seqs = sort_all(words, forms, char_seqs, id_seqs)
    lengths = get_tensor(lengths, dtype='l')
    # Use the tensor's own reduction. The builtin max() would iterate the
    # tensor element by element in Python, creating a 0-d tensor per element.
    max_len = lengths.max().item()
    # Pad to a dense array, then trim the columns to the longest sequence.
    id_seqs = pad_to_dense(id_seqs, dtype='l')
    id_seqs = get_tensor(id_seqs[:, :max_len])

    # Only the first item's language is read; the batch is assumed to be
    # single-language.
    lang = batch[0].lang
    return Map(
        words=words, forms=forms, char_seqs=char_seqs, id_seqs=id_seqs, lengths=lengths, lang=lang)
Example #4
0
 def __getitem__(self, idx):
     """Return the word at position ``idx`` packaged as a Map.

     The Map carries the word object itself, its ``form``, ``char_seq`` and
     ``id_seq``, and this dataset's ``lang``.
     """
     entry = self._words[idx]
     return Map(
         word=entry,
         form=entry.form,
         lang=self.lang,
         char_seq=entry.char_seq,
         id_seq=entry.id_seq,
     )
Example #5
0
 def entire_batch(self):
     """Return the entire dataset as one batch.

     The result pairs the ``entire_batch`` of the known-language dataset with
     that of the lost-language dataset. The order of words in the batch
     should be persistent.
     """
     known_batch = self.datasets[self.known_lang].entire_batch
     lost_batch = self.datasets[self.lost_lang].entire_batch
     return Map(known=known_batch, lost=lost_batch)
Example #6
0
 def __iter__(self):
     """Yield one Map per batch produced by the parent iterator.

     Each batch from the parent is treated as the known-language batch. It is
     paired with the lost-language dataset's ``entire_batch``, looked up
     afresh for every yielded item. ``num_samples`` records how many words
     the known batch holds.
     """
     for batch in super().__iter__():
         yield Map(
             lost=self.datasets[self.lost_lang].entire_batch,
             known=batch,
             num_samples=len(batch.words),
         )
Example #7
0
def parse_args():
    """Register command-line options, parse them, and apply global setup.

    Options are added to the module-level ``parser``. It is not a plain
    ``argparse`` parser: it accepts a ``dtype=`` keyword, and its
    ``parse_args()`` result is ``**``-unpacked, so it must return a mapping.
    The config registry ``registry`` is attached before parsing, and the
    parsed options are wrapped in a ``Map``.

    Side effects:
        * when ``--gpu`` is given, calls ``torch.cuda.set_device`` and sets
          ``CUDA_VISIBLE_DEVICES``;
        * unless ``--random`` is set, seeds ``random``, ``numpy`` and
          ``torch`` with ``--seed``;
        * creates a logger writing to ``<log_dir>/log`` and logs the parsed
          options.

    NOTE(review): ``log_dir`` is not registered among the options below.
    Presumably it is provided by ``registry`` or by the parser itself; confirm.
    """

    # `parser` and `registry` are names defined outside this function.
    parser.add_argument('--num_rounds',
                        '-nr',
                        default=3,
                        dtype=int,
                        help='how many rounds of EM')
    parser.add_argument('--num_epochs_per_M_step',
                        '-nm',
                        default=5,
                        dtype=int,
                        help='how many epochs for each M step')
    parser.add_argument('--saved_path',
                        '-sp',
                        dtype=str,
                        help='path to the saved model (and other metadata)')
    parser.add_argument('--learning_rate',
                        '-lr',
                        dtype=float,
                        default=5e-3,
                        help='initial learning rate')
    parser.add_argument('--num_cognates',
                        '-nc',
                        dtype=int,
                        help='how many cognate pairs')
    parser.add_argument('--inc',
                        dtype=int,
                        default=50,
                        help='increment of cognate pairs after each round')
    parser.add_argument(
        '--warm_up_steps',
        '-wus',
        dtype=int,
        default=1,
        help='how many steps at the start of training without edit distance')
    parser.add_argument(
        '--capacity',
        default=(1, ),
        nargs='+',
        dtype=int,
        help='capacity for the edges. The first value will be used for E step.'
    )
    parser.add_argument('--save_all',
                        dtype=bool,
                        help='flag to save all models')
    parser.add_argument('--eval_interval',
                        '-ei',
                        default=250,
                        dtype=int,
                        help='evaluate once after this many steps')
    parser.add_argument('--check_interval',
                        '-ci',
                        default=50,
                        dtype=int,
                        help='check and print metrics after this many steps')
    parser.add_argument('--cog_path',
                        '-cp',
                        dtype=str,
                        help='path to the cognate file')
    parser.add_argument('--char_emb_dim',
                        '-ced',
                        default=250,
                        dtype=int,
                        help='dimensionality of character embeddings')
    parser.add_argument('--hidden_size',
                        '-hs',
                        default=250,
                        dtype=int,
                        help='hidden size')
    parser.add_argument('--num_layers',
                        '-nl',
                        default=1,
                        dtype=int,
                        help='number of layers for cipher model')
    parser.add_argument('--dropout',
                        default=0.5,
                        dtype=float,
                        help='dropout rate between layers')
    parser.add_argument('--universal_charset_size',
                        '-ucs',
                        default=50,
                        dtype=int,
                        help='size of the (universal) character inventory')
    parser.add_argument('--lost_lang',
                        '-l',
                        dtype=str,
                        help='lost language code')
    parser.add_argument('--known_lang',
                        '-k',
                        dtype=str,
                        help='known language code')
    parser.add_argument('--norms_or_ratios',
                        '-nor',
                        dtype=float,
                        nargs='+',
                        default=(1.0, 0.2),
                        help='norm or ratio values in control mode')
    parser.add_argument('--control_mode',
                        '-cm',
                        dtype=str,
                        default='relative',
                        help='norm control mode')
    # NOTE(review): with default=True, whether this flag can be switched off
    # depends on how the project parser handles dtype=bool; confirm.
    parser.add_argument('--residual',
                        dtype=bool,
                        default=True,
                        help='flag to use residual connection')
    parser.add_argument('--reg_hyper',
                        default=1.0,
                        dtype=float,
                        help='hyperparameter for regularization')
    parser.add_argument('--batch_size', '-bs', dtype=int, help='batch size')
    parser.add_argument('--momentum',
                        default=0.25,
                        dtype=float,
                        help='momentum for flow')
    parser.add_argument('--gpu', '-g', dtype=str, help='which gpu to choose')
    parser.add_argument('--random', dtype=bool, help='random, ignore seed')
    parser.add_argument('--seed', dtype=int, default=1234, help='random seed')
    parser.add_argument('--log_level',
                        default='INFO',
                        dtype=str,
                        help='log level')
    parser.add_argument('--n_similar',
                        dtype=int,
                        help='number of most similar source tokens to keep')
    # Attach the project config registry before parsing.
    parser.add_cfg_registry(registry)
    # parse_args() must return a mapping here, because it is **-unpacked.
    args = Map(**parser.parse_args())

    if args.gpu is not None:
        # NOTE(review): CUDA_VISIBLE_DEVICES is assigned after set_device has
        # already touched CUDA. CUDA usually reads this variable only at
        # initialization, so the assignment may not affect this process; confirm.
        torch.cuda.set_device(int(args.gpu))  # HACK
        os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu

    # Seed every RNG source used here unless --random was requested.
    if not args.random:
        random.seed(args.seed)
        np.random.seed(args.seed)
        torch.manual_seed(args.seed)

    create_logger(filepath=args.log_dir + '/log', log_level=args.log_level)
    log_pp(pformat(args))