def get_enc_mask(toks, structs, num_heads, mu_l, mu_r, lam, c_l, c_r):
    heads = torch.zeros(num_heads, dtype=torch.int,
                        device=ut.get_device())  # [num_heads]
    head_segs = [
        tree_utils.HEAD_CHILD_ID,
        tree_utils.HEAD_PARENT_ID,
        tree_utils.HEAD_SIB_ID,
        tree_utils.HEAD_OTHER_ID,
        tree_utils.HEAD_DESC_ID,
        tree_utils.HEAD_ANCE_ID,
        #
        tree_utils.HEAD_OTHER_ID,
        tree_utils.HEAD_ANCE_ID | tree_utils.HEAD_DESC_ID
        | tree_utils.HEAD_CHILD_ID | tree_utils.HEAD_PARENT_ID,
    ]

    #
    #tree_utils.HEAD_ANCE_ID | tree_utils.HEAD_PARENT_ID,
    #tree_utils.HEAD_DESC_ID | tree_utils.HEAD_CHILD_ID,]
    def seg(n):
        # Boundary index for splitting num_heads evenly across head_segs
        return (n * num_heads) // len(head_segs)

    for i, head_seg in enumerate(head_segs):
        heads[seg(i):seg(i + 1)] = head_seg
    heads = heads | tree_utils.HEAD_BASE_IDS
    masks = tree_utils.get_enc_mask(
        toks, structs, num_heads)  # [bsz, num_heads, src_len, src_len]
    masks.bitwise_and_(heads.unsqueeze(0).unsqueeze(2).unsqueeze(3))
    return torch.logical_not(masks)  #.transpose(2, 3)
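Note: the mask returned above is boolean with True at positions that must not be attended to. As a minimal, self-contained sketch of how a [bsz, num_heads, src_len, src_len] mask like this could be consumed (plain scaled dot-product attention; masked_attention is a hypothetical helper, not part of this repository):

import torch

def masked_attention(q, k, v, mask):
    # q, k, v: [bsz, num_heads, src_len, head_dim]
    # mask:    [bsz, num_heads, src_len, src_len], True = do not attend
    scores = q @ k.transpose(-2, -1) / (q.size(-1)**0.5)
    scores = scores.masked_fill(mask, float('-inf'))
    # NB: a row that is fully masked would produce NaNs after softmax;
    # real attention code guards against that case.
    return torch.softmax(scores, dim=-1) @ v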
Example #2
    def read_batches(self,
                     read_handler,
                     is_training=True,
                     num_preload=ac.DEFAULT_NUM_PRELOAD,
                     to_ids=False,
                     with_trg=True):
        # Lazily read num_preload lines at a time and yield prepared batches,
        # moving the input/target arrays to the current device as LongTensors.
        device = ut.get_device()
        while True:
            next_n_lines = list(itertools.islice(read_handler, num_preload))
            if not next_n_lines: break
            src_inputs, src_seq_lengths, src_structs, trg_inputs, trg_seq_lengths = self.process_n_batches(
                next_n_lines, to_ids=to_ids, with_trg=with_trg)
            batches = self.prepare_batches(src_inputs,
                                           src_seq_lengths,
                                           src_structs,
                                           trg_inputs,
                                           trg_seq_lengths,
                                           is_training=is_training,
                                           with_trg=with_trg)
            for original_idxs, src_inputs, src_structs, trg_inputs, trg_target in zip(
                    *batches):
                yield (original_idxs, torch.from_numpy(src_inputs).type(
                    torch.long).to(device), src_structs,
                       torch.from_numpy(trg_inputs).type(
                           torch.long).to(device),
                       torch.from_numpy(trg_target).type(
                           torch.long).to(device))
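A hedged usage sketch of the generator above (the file name and the data_manager instance are assumptions; the tuple layout follows the yield statement):

data_manager = DataManager(config, init_vocab=True)      # as in Example #9
with open('train.src') as read_handler:                  # hypothetical input file
    for original_idxs, src, src_structs, trg_in, trg_target in \
            data_manager.read_batches(read_handler, is_training=True):
        pass  # forward/backward pass would go here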
Example #3
def get_params(config):
  device = get_device()
  num_heads = config['num_enc_heads']
  return dict(
    attL = torch.zeros(num_heads, device=device),
    attR = torch.zeros(num_heads, device=device)
  )
Example #4
    def get_decoder_mask(self, size):
        # Cache a causal (strictly upper-triangular) mask and reuse a slice of
        # it whenever the requested size fits inside the cached one.
        if self.decoder_mask is None or self.decoder_mask.size()[-1] < size:
            self.decoder_mask = torch.triu(torch.ones((1, 1, size, size),
                                                      dtype=torch.bool,
                                                      device=ut.get_device()),
                                           diagonal=1)
            return self.decoder_mask
        else:
            return self.decoder_mask[:, :, :size, :size]
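For reference, torch.triu(..., diagonal=1) with dtype=torch.bool marks strictly-future positions as True (blocked), so position i can attend to positions 0..i. A quick check with an assumed size of 4, outside this class:

import torch

causal = torch.triu(torch.ones((1, 1, 4, 4), dtype=torch.bool), diagonal=1)
# causal[0, 0] ==
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])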
Example #5
def get_enc_mask(toks, structs, num_heads, attL, attR):
  bsz, src_len = toks.size()
  device = get_device()

  diagL = torch.diag(torch.ones(src_len, dtype=torch.float, device=device), diagonal=-1)[:src_len, :src_len]
  diagR = torch.diag(torch.ones(src_len, dtype=torch.float, device=device), diagonal=1)[:src_len, :src_len]
  maskL = (diagL.unsqueeze(0) * attL.reshape(-1, 1, 1)).unsqueeze(0)
  maskR = (diagR.unsqueeze(0) * attR.reshape(-1, 1, 1)).unsqueeze(0)
  return maskL + maskR, toks.type(torch.bool).unsqueeze(1).unsqueeze(2)
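The two diag terms above put the per-head attL and attR weights on each position's immediate left and right neighbours. A small sketch of that band structure, assuming src_len = 4 and unit weights purely for illustration:

import torch

src_len = 4
diagL = torch.diag(torch.ones(src_len), diagonal=-1)[:src_len, :src_len]
diagR = torch.diag(torch.ones(src_len), diagonal=1)[:src_len, :src_len]
# diagL[i, i - 1] == 1 (left neighbour), diagR[i, i + 1] == 1 (right neighbour)
# diagL + diagR ==
# tensor([[0., 1., 0., 0.],
#         [1., 0., 1., 0.],
#         [0., 1., 0., 1.],
#         [0., 0., 1., 0.]])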
Example #6
def get_params(config):
    embed_dim = config['embed_dim']
    return dict(
        mu_l=tree_utils.init_tensor(embed_dim, embed_dim),
        mu_r=tree_utils.init_tensor(embed_dim, embed_dim),
        lam=tree_utils.init_tensor(embed_dim),
        c_l=tree_utils.init_tensor(),
        c_r=tree_utils.init_tensor(),
        self_attn_weights=torch.zeros(len(tree_utils.HEAD_IDS[1:]),
                                      config['num_enc_heads'],
                                      dtype=torch.float,
                                      device=ut.get_device()),
    )
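Note: mu_l, mu_r, lam, c_l and c_r here line up with the arguments of the first get_enc_mask example above, and self_attn_weights allocates one row per entry of tree_utils.HEAD_IDS after the first.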
Example #7
def get_enc_mask(toks, structs, num_heads=1):
    bsz, src_len = toks.size()
    masks = torch.full((bsz, src_len, src_len),
                       HEAD_PAD_ID,
                       dtype=torch.int,
                       device=ut.get_device())

    for c in range(bsz):
        size = structs[c].size()
        flatten_mask_left(structs[c], 0, masks[c, :size, :size])
        masks[c, size:, :] = HEAD_EXTRA_ID

    if num_heads == 1: return masks  #.unsqueeze(1)
    else: return masks.unsqueeze(1).expand(-1, num_heads, -1, -1).clone()
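Note on the final line: expand() returns a broadcast view that shares memory across heads, so in-place edits such as the bitwise_and_ used in the first example would fail on it; the trailing clone() materialises an independent, writable copy per head.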
Example #8
def init_tensor(*size):
    device = ut.get_device()
    if len(size) == 0:
        t = torch.tensor([1.], device=device)
    elif len(size) == 1:
        t = torch.empty(*size, device=device)
        torch.nn.init.normal_(t, mean=0, std=size[0]**-0.5)
    elif len(size) == 2:
        t = torch.empty(*size, device=device)
        torch.nn.init.orthogonal_(t)
    else:
        assert False, "nmt.structs.tree_utils.init_tensor(*size) only implemented for len(size) == 0, 1, and 2, but got len({}) = {}".format(
            size, len(size))
    return t
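For context, the three branches above give a scalar, a normally initialised vector, and an orthogonally initialised matrix. A usage sketch, assuming the repository's tree_utils module is importable and an illustrative embed_dim of 512:

embed_dim = 512
c = tree_utils.init_tensor()                        # tensor([1.]) on the current device
lam = tree_utils.init_tensor(embed_dim)             # [512], Normal(0, embed_dim ** -0.5)
mu = tree_utils.init_tensor(embed_dim, embed_dim)   # [512, 512], orthogonal init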
Example #9
    def __init__(self, config, load_from=None):
        super(Model, self).__init__()
        self.config = config
        self.struct = self.config['struct']
        self.decoder_mask = None
        self.data_manager = DataManager(config, init_vocab=(not load_from))

        if load_from:
            self.load_state_dict(torch.load(load_from,
                                            map_location=ut.get_device()),
                                 do_init=True)
        else:
            self.init_embeddings()
            self.init_model()
            self.add_struct_params()

        # dict where keys are data_ptrs to dicts of parameter options
        # see https://pytorch.org/docs/stable/optim.html#per-parameter-options
        self.parameter_attrs = {}
Example #10
    def __init__(self, args):
        super(Translator, self).__init__()
        self.config = configurations.get_config(
            args.proto, getattr(configurations, args.proto),
            args.config_overrides)
        self.logger = ut.get_logger(self.config['log_file'])
        self.num_preload = args.num_preload

        self.model_file = args.model_file
        if self.model_file is None:
            self.model_file = os.path.join(self.config['save_to'],
                                           self.config['model_name'] + '.pth')

        self.input_file = args.input_file
        if self.input_file is not None and not os.path.exists(self.input_file):
            raise FileNotFoundError(
                f'Input file does not exist: {self.input_file}')
        if not os.path.exists(self.model_file):
            raise FileNotFoundError(
                f'Model file does not exist: {self.model_file}')

        self.logger.info(f'Restore model from {self.model_file}')
        self.model = Model(self.config,
                           load_from=self.model_file).to(ut.get_device())

        if self.input_file:
            save_fp = os.path.join(self.config['save_to'],
                                   os.path.basename(self.input_file))
            # str.rstrip strips a set of characters, not a suffix; remove the
            # source-language extension explicitly instead
            src_lang = self.model.data_manager.src_lang
            if save_fp.endswith(src_lang):
                save_fp = save_fp[:-len(src_lang)]
            save_fp = save_fp + self.model.data_manager.trg_lang
            self.best_output_fp = save_fp + '.best_trans'
            self.beam_output_fp = save_fp + '.beam_trans'
            open(self.best_output_fp, 'w').close()
            open(self.beam_output_fp, 'w').close()
        else:
            self.best_output_fp = self.beam_output_fp = None

        self.translate()
Example #11
def get_params(config):
    return dict(self_attn_weights=torch.zeros(len(tree_utils.HEAD_IDS[1:]),
                                              config['num_enc_heads'],
                                              device=ut.get_device()), )
Example #12
    def __init__(self, args):
        super(Trainer, self).__init__()
        self.config = configurations.get_config(
            args.proto, getattr(configurations, args.proto),
            args.config_overrides)
        self.num_preload = args.num_preload
        self.lr = self.config['lr']

        ut.remove_files_in_dir(self.config['save_to'])

        self.logger = ut.get_logger(self.config['log_file'])

        self.train_smooth_perps = []
        self.train_true_perps = []

        # For logging
        self.log_freq = self.config[
            'log_freq']  # log train stat every this-many batches
        self.log_train_loss = []
        self.log_nll_loss = []
        self.log_train_weights = []
        self.log_grad_norms = []
        self.total_batches = 0  # number of batches done for the whole training
        self.epoch_loss = 0.  # total train loss for whole epoch
        self.epoch_nll_loss = 0.  # total train NLL loss for whole epoch
        self.epoch_weights = 0.  # total train weights (# target words) for whole epoch
        self.epoch_time = 0.  # total exec time for whole epoch, sounds like that tabloid

        # get model
        device = ut.get_device()
        self.model = Model(self.config).to(device)
        self.validator = Validator(self.config, self.model)

        self.validate_freq = self.config['validate_freq']
        if self.validate_freq == 1:
            self.logger.info('Evaluate every ' + (
                'epoch' if self.config['val_per_epoch'] else 'batch'))
        else:
            self.logger.info(f'Evaluate every {self.validate_freq:,} ' + (
                'epochs' if self.config['val_per_epoch'] else 'batches'))

        # Estimated number of batches per epoch
        self.est_batches = max(self.model.data_manager.training_tok_counts
                               ) // self.config['batch_size']
        self.logger.info(
            f'Guessing around {self.est_batches:,} batches per epoch')

        param_count = sum(
            [numpy.prod(p.size()) for p in self.model.parameters()])
        self.logger.info(f'Model has {int(param_count):,} parameters')

        # Set up parameter-specific options
        params = []
        for p in self.model.parameters():
            ptr = p.data_ptr()
            d = {'params': [p]}
            if ptr in self.model.parameter_attrs:
                attrs = self.model.parameter_attrs[ptr]
                for k in attrs:
                    d[k] = attrs[k]
            params.append(d)

        self.optimizer = torch.optim.Adam(params,
                                          lr=self.lr,
                                          betas=(self.config['beta1'],
                                                 self.config['beta2']),
                                          eps=self.config['epsilon'])
Example #13
    def init_embeddings(self):
        embed_dim = self.config['embed_dim']
        tie_mode = self.config['tie_mode']
        fix_norm = self.config['fix_norm']
        max_src_len = self.config['max_src_length']
        max_trg_len = self.config['max_trg_length']

        device = ut.get_device()

        # get trg positional embedding
        if not self.config['learned_pos_trg']:
            self.pos_embedding_trg = ut.get_position_embedding(
                embed_dim, max_trg_len)
        else:
            self.pos_embedding_trg = Parameter(
                torch.empty(max_trg_len,
                            embed_dim,
                            dtype=torch.float,
                            device=device))
            nn.init.normal_(self.pos_embedding_trg,
                            mean=0,
                            std=embed_dim**-0.5)

        # get word embeddings
        # TODO: src_vocab_mask is assigned but never used (?)
        #src_vocab_size, trg_vocab_size = ut.get_vocab_sizes(self.config)
        #self.src_vocab_mask, self.trg_vocab_mask = ut.get_vocab_masks(self.config, src_vocab_size, trg_vocab_size)
        #if tie_mode == ac.ALL_TIED:
        #    src_vocab_size = trg_vocab_size = self.trg_vocab_mask.shape[0]
        self.src_vocab_mask = self.data_manager.vocab_masks[
            self.data_manager.src_lang]
        self.trg_vocab_mask = self.data_manager.vocab_masks[
            self.data_manager.trg_lang]
        src_vocab_size = self.src_vocab_mask.shape[0]
        trg_vocab_size = self.trg_vocab_mask.shape[0]

        self.out_bias = Parameter(
            torch.empty(trg_vocab_size, dtype=torch.float, device=device))
        nn.init.constant_(self.out_bias, 0.)

        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.trg_embedding = nn.Embedding(trg_vocab_size, embed_dim)
        self.out_embedding = self.trg_embedding.weight
        if self.config['separate_embed_scales']:
            self.src_embed_scale = Parameter(
                torch.tensor([embed_dim**0.5], device=device))
            self.trg_embed_scale = Parameter(
                torch.tensor([embed_dim**0.5], device=device))
        else:
            self.src_embed_scale = self.trg_embed_scale = torch.tensor(
                [embed_dim**0.5], device=device)

        self.src_pos_embed_scale = torch.tensor([(embed_dim / 2)**0.5],
                                                device=device)
        self.trg_pos_embed_scale = torch.tensor(
            [1.], device=device
        )  # trg pos embedding already returns vector of norm sqrt(embed_dim/2)
        if self.config['learn_pos_scale']:
            self.src_pos_embed_scale = Parameter(self.src_pos_embed_scale)
            self.trg_pos_embed_scale = Parameter(self.trg_pos_embed_scale)

        if tie_mode == ac.ALL_TIED:
            self.src_embedding.weight = self.trg_embedding.weight

        if not fix_norm:
            nn.init.normal_(self.src_embedding.weight,
                            mean=0,
                            std=embed_dim**-0.5)
            nn.init.normal_(self.trg_embedding.weight,
                            mean=0,
                            std=embed_dim**-0.5)
        else:
            d = 0.01  # pure magic
            nn.init.uniform_(self.src_embedding.weight, a=-d, b=d)
            nn.init.uniform_(self.trg_embedding.weight, a=-d, b=d)
Example #14
    def get_pos_embedding(self, embed_dim):
        return SequenceStruct(
            torch.zeros(len(self.data), embed_dim, device=get_device()))