def get_enc_mask(toks, structs, num_heads, mu_l, mu_r, lam, c_l, c_r):
    heads = torch.zeros(num_heads, dtype=torch.int, device=ut.get_device())  # [num_heads]
    # Each block of heads is restricted to one structural relation.
    head_segs = [
        tree_utils.HEAD_CHILD_ID,
        tree_utils.HEAD_PARENT_ID,
        tree_utils.HEAD_SIB_ID,
        tree_utils.HEAD_OTHER_ID,
        tree_utils.HEAD_DESC_ID,
        tree_utils.HEAD_ANCE_ID,
        # alternative groupings left over from earlier experiments:
        # tree_utils.HEAD_OTHER_ID,
        # tree_utils.HEAD_ANCE_ID | tree_utils.HEAD_DESC_ID | tree_utils.HEAD_CHILD_ID | tree_utils.HEAD_PARENT_ID,
        # tree_utils.HEAD_ANCE_ID | tree_utils.HEAD_PARENT_ID,
        # tree_utils.HEAD_DESC_ID | tree_utils.HEAD_CHILD_ID,
    ]

    def seg(n):
        return (n * num_heads) // len(head_segs)

    for i, head_seg in enumerate(head_segs):
        heads[seg(i):seg(i + 1)] = head_seg
    heads = heads | tree_utils.HEAD_BASE_IDS

    masks = tree_utils.get_enc_mask(toks, structs, num_heads)  # [bsz, num_heads, src_len, src_len]
    masks.bitwise_and_(heads.unsqueeze(0).unsqueeze(2).unsqueeze(3))
    return torch.logical_not(masks)
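# Hedged sketch (not part of the module): how seg() above splits num_heads encoder heads
# across the six structural relations; with 8 heads the boundaries are 0,1,2,4,5,6,8,
# so some relations get one head and others two.
num_heads, num_segs = 8, 6
seg = lambda n: (n * num_heads) // num_segs
print([(seg(i), seg(i + 1)) for i in range(num_segs)])  # [(0, 1), (1, 2), (2, 4), (4, 5), (5, 6), (6, 8)]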
def read_batches(self, read_handler, is_training=True,
                 num_preload=ac.DEFAULT_NUM_PRELOAD, to_ids=False, with_trg=True):
    device = ut.get_device()
    while True:
        next_n_lines = list(itertools.islice(read_handler, num_preload))
        if not next_n_lines:
            break
        src_inputs, src_seq_lengths, src_structs, trg_inputs, trg_seq_lengths = \
            self.process_n_batches(next_n_lines, to_ids=to_ids, with_trg=with_trg)
        batches = self.prepare_batches(src_inputs, src_seq_lengths, src_structs,
                                       trg_inputs, trg_seq_lengths,
                                       is_training=is_training, with_trg=with_trg)
        for original_idxs, src_inputs, src_structs, trg_inputs, trg_target in zip(*batches):
            yield (original_idxs,
                   torch.from_numpy(src_inputs).type(torch.long).to(device),
                   src_structs,
                   torch.from_numpy(trg_inputs).type(torch.long).to(device),
                   torch.from_numpy(trg_target).type(torch.long).to(device))
def get_params(config):
    device = get_device()
    num_heads = config['num_enc_heads']
    return dict(
        attL=torch.zeros(num_heads, device=device),
        attR=torch.zeros(num_heads, device=device),
    )
def get_decoder_mask(self, size):
    # Cache a causal (upper-triangular) mask and regrow it only when a longer one is needed.
    if self.decoder_mask is None or self.decoder_mask.size()[-1] < size:
        self.decoder_mask = torch.triu(
            torch.ones((1, 1, size, size), dtype=torch.bool, device=ut.get_device()),
            diagonal=1)
        return self.decoder_mask
    else:
        return self.decoder_mask[:, :, :size, :size]
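# Hedged sketch (plain torch, not repo code): the cached decoder mask above is the standard
# causal mask; True marks future positions that must not be attended to.
import torch

size = 4
causal = torch.triu(torch.ones((1, 1, size, size), dtype=torch.bool), diagonal=1)
# causal[0, 0]:
# [[False,  True,  True,  True],
#  [False, False,  True,  True],
#  [False, False, False,  True],
#  [False, False, False, False]]
# Slicing [:, :, :n, :n] reuses the cached tensor for any shorter target length n.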
def get_enc_mask(toks, structs, num_heads, attL, attR):
    bsz, src_len = toks.size()
    device = get_device()
    # Off-diagonal indicators for the immediate left/right neighbor of each position.
    diagL = torch.diag(torch.ones(src_len, dtype=torch.float, device=device), diagonal=-1)[:src_len, :src_len]
    diagR = torch.diag(torch.ones(src_len, dtype=torch.float, device=device), diagonal=1)[:src_len, :src_len]
    # Scale each indicator by the per-head learned biases attL/attR.
    maskL = (diagL.unsqueeze(0) * attL.reshape(-1, 1, 1)).unsqueeze(0)
    maskR = (diagR.unsqueeze(0) * attR.reshape(-1, 1, 1)).unsqueeze(0)
    # Returns the additive attention bias and the boolean padding mask.
    return maskL + maskR, toks.type(torch.bool).unsqueeze(1).unsqueeze(2)
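# Hedged sketch (values are made up, only torch assumed): each head learns two scalars, attL
# and attR, which are added to the attention logits of the immediate left/right neighbors.
import torch

src_len, num_heads = 4, 2
attL = torch.tensor([0.5, -0.5])  # hypothetical per-head left-neighbor biases
attR = torch.tensor([1.0, 0.0])   # hypothetical per-head right-neighbor biases
diagL = torch.diag(torch.ones(src_len), diagonal=-1)[:src_len, :src_len]
diagR = torch.diag(torch.ones(src_len), diagonal=1)[:src_len, :src_len]
bias = (diagL.unsqueeze(0) * attL.reshape(-1, 1, 1)
        + diagR.unsqueeze(0) * attR.reshape(-1, 1, 1)).unsqueeze(0)
print(bias.shape)  # torch.Size([1, 2, 4, 4]), broadcast over the batch when added to logits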
def get_params(config):
    embed_dim = config['embed_dim']
    return dict(
        mu_l=tree_utils.init_tensor(embed_dim, embed_dim),
        mu_r=tree_utils.init_tensor(embed_dim, embed_dim),
        lam=tree_utils.init_tensor(embed_dim),
        c_l=tree_utils.init_tensor(),
        c_r=tree_utils.init_tensor(),
        self_attn_weights=torch.zeros(len(tree_utils.HEAD_IDS[1:]),
                                      config['num_enc_heads'],
                                      dtype=torch.float,
                                      device=ut.get_device()),
    )
def get_enc_mask(toks, structs, num_heads=1):
    bsz, src_len = toks.size()
    masks = torch.full((bsz, src_len, src_len), HEAD_PAD_ID,
                       dtype=torch.int, device=ut.get_device())
    for c in range(bsz):
        # Fill the relation IDs for this sentence's tree; remaining rows are padding.
        size = structs[c].size()
        flatten_mask_left(structs[c], 0, masks[c, :size, :size])
        masks[c, size:, :] = HEAD_EXTRA_ID
    if num_heads == 1:
        return masks
    else:
        return masks.unsqueeze(1).expand(-1, num_heads, -1, -1).clone()
def init_tensor(*size):
    device = ut.get_device()
    if len(size) == 0:
        t = torch.tensor([1.], device=device)
    elif len(size) == 1:
        t = torch.empty(*size, device=device)
        torch.nn.init.normal_(t, mean=0, std=size[0]**-0.5)
    elif len(size) == 2:
        t = torch.empty(*size, device=device)
        torch.nn.init.orthogonal_(t)
    else:
        assert False, ("nmt.structs.tree_utils.init_tensor(*size) only implemented for "
                       "len(size) == 0, 1, and 2, but got len({}) = {}".format(size, len(size)))
    return t
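# Hedged sketch: the three supported arities of init_tensor above. Assumes it is imported
# from nmt.structs.tree_utils as in the original package; the sizes are arbitrary.
c_l = init_tensor()           # no size -> tensor([1.])
lam = init_tensor(512)        # vector -> normal(0, 512 ** -0.5)
mu_l = init_tensor(512, 512)  # square matrix -> orthogonal init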
def __init__(self, config, load_from=None):
    super(Model, self).__init__()
    self.config = config
    self.struct = self.config['struct']
    self.decoder_mask = None
    self.data_manager = DataManager(config, init_vocab=(not load_from))
    if load_from:
        self.load_state_dict(torch.load(load_from, map_location=ut.get_device()),
                             do_init=True)
    else:
        self.init_embeddings()
        self.init_model()
        self.add_struct_params()
    # Maps each parameter's data_ptr() to a dict of per-parameter optimizer options;
    # see https://pytorch.org/docs/stable/optim.html#per-parameter-options
    self.parameter_attrs = {}
def __init__(self, args):
    super(Translator, self).__init__()
    self.config = configurations.get_config(args.proto,
                                            getattr(configurations, args.proto),
                                            args.config_overrides)
    self.logger = ut.get_logger(self.config['log_file'])
    self.num_preload = args.num_preload

    self.model_file = args.model_file
    if self.model_file is None:
        self.model_file = os.path.join(self.config['save_to'],
                                       self.config['model_name'] + '.pth')
    self.input_file = args.input_file

    if self.input_file is not None and not os.path.exists(self.input_file):
        raise FileNotFoundError(f'Input file does not exist: {self.input_file}')
    if not os.path.exists(self.model_file):
        raise FileNotFoundError(f'Model file does not exist: {self.model_file}')

    self.logger.info(f'Restore model from {self.model_file}')
    self.model = Model(self.config, load_from=self.model_file).to(ut.get_device())

    if self.input_file:
        save_fp = os.path.join(self.config['save_to'], os.path.basename(self.input_file))
        # Replace the source-language extension with the target-language one.
        # (str.rstrip strips a character set, not a suffix, so slice instead.)
        src_lang = self.model.data_manager.src_lang
        if save_fp.endswith(src_lang):
            save_fp = save_fp[:-len(src_lang)]
        save_fp = save_fp + self.model.data_manager.trg_lang
        self.best_output_fp = save_fp + '.best_trans'
        self.beam_output_fp = save_fp + '.beam_trans'
        open(self.best_output_fp, 'w').close()
        open(self.beam_output_fp, 'w').close()
    else:
        self.best_output_fp = self.beam_output_fp = None

    self.translate()
def get_params(config):
    return dict(
        self_attn_weights=torch.zeros(len(tree_utils.HEAD_IDS[1:]),
                                      config['num_enc_heads'],
                                      device=ut.get_device()),
    )
def __init__(self, args):
    super(Trainer, self).__init__()
    self.config = configurations.get_config(args.proto,
                                            getattr(configurations, args.proto),
                                            args.config_overrides)
    self.num_preload = args.num_preload
    self.lr = self.config['lr']

    ut.remove_files_in_dir(self.config['save_to'])
    self.logger = ut.get_logger(self.config['log_file'])

    self.train_smooth_perps = []
    self.train_true_perps = []

    # For logging
    self.log_freq = self.config['log_freq']  # log train stats every this many batches
    self.log_train_loss = []
    self.log_nll_loss = []
    self.log_train_weights = []
    self.log_grad_norms = []
    self.total_batches = 0  # number of batches done over the whole training run
    self.epoch_loss = 0.  # total train loss for the whole epoch
    self.epoch_nll_loss = 0.  # total NLL loss for the whole epoch
    self.epoch_weights = 0.  # total train weights (# target words) for the whole epoch
    self.epoch_time = 0.  # total exec time for the whole epoch

    # get model
    device = ut.get_device()
    self.model = Model(self.config).to(device)
    self.validator = Validator(self.config, self.model)

    self.validate_freq = self.config['validate_freq']
    if self.validate_freq == 1:
        self.logger.info('Evaluate every ' +
                         ('epoch' if self.config['val_per_epoch'] else 'batch'))
    else:
        self.logger.info(f'Evaluate every {self.validate_freq:,} ' +
                         ('epochs' if self.config['val_per_epoch'] else 'batches'))

    # Estimated number of batches per epoch
    self.est_batches = max(self.model.data_manager.training_tok_counts) // self.config['batch_size']
    self.logger.info(f'Guessing around {self.est_batches:,} batches per epoch')

    param_count = sum([numpy.prod(p.size()) for p in self.model.parameters()])
    self.logger.info(f'Model has {int(param_count):,} parameters')

    # Set up parameter-specific optimizer options
    params = []
    for p in self.model.parameters():
        ptr = p.data_ptr()
        d = {'params': [p]}
        if ptr in self.model.parameter_attrs:
            attrs = self.model.parameter_attrs[ptr]
            for k in attrs:
                d[k] = attrs[k]
        params.append(d)
    self.optimizer = torch.optim.Adam(params,
                                      lr=self.lr,
                                      betas=(self.config['beta1'], self.config['beta2']),
                                      eps=self.config['epsilon'])
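# Hedged sketch (standard PyTorch per-parameter options, not repo code): how the
# parameter_attrs lookup above translates into optimizer behavior. The 'lr' override
# below is a hypothetical example value.
import torch

w = torch.nn.Parameter(torch.zeros(4))
b = torch.nn.Parameter(torch.zeros(4))
parameter_attrs = {w.data_ptr(): {'lr': 1e-4}}
groups = []
for p in (w, b):
    d = {'params': [p]}
    d.update(parameter_attrs.get(p.data_ptr(), {}))
    groups.append(d)
opt = torch.optim.Adam(groups, lr=1e-3)  # w uses lr=1e-4, b falls back to the default 1e-3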
def init_embeddings(self):
    embed_dim = self.config['embed_dim']
    tie_mode = self.config['tie_mode']
    fix_norm = self.config['fix_norm']
    max_src_len = self.config['max_src_length']
    max_trg_len = self.config['max_trg_length']
    device = ut.get_device()

    # get trg positional embedding
    if not self.config['learned_pos_trg']:
        self.pos_embedding_trg = ut.get_position_embedding(embed_dim, max_trg_len)
    else:
        self.pos_embedding_trg = Parameter(
            torch.empty(max_trg_len, embed_dim, dtype=torch.float, device=device))
        nn.init.normal_(self.pos_embedding_trg, mean=0, std=embed_dim**-0.5)

    # get word embeddings
    # TODO: src_vocab_mask is assigned but never used (?)
    self.src_vocab_mask = self.data_manager.vocab_masks[self.data_manager.src_lang]
    self.trg_vocab_mask = self.data_manager.vocab_masks[self.data_manager.trg_lang]
    src_vocab_size = self.src_vocab_mask.shape[0]
    trg_vocab_size = self.trg_vocab_mask.shape[0]

    self.out_bias = Parameter(torch.empty(trg_vocab_size, dtype=torch.float, device=device))
    nn.init.constant_(self.out_bias, 0.)

    self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
    self.trg_embedding = nn.Embedding(trg_vocab_size, embed_dim)
    self.out_embedding = self.trg_embedding.weight

    if self.config['separate_embed_scales']:
        self.src_embed_scale = Parameter(torch.tensor([embed_dim**0.5], device=device))
        self.trg_embed_scale = Parameter(torch.tensor([embed_dim**0.5], device=device))
    else:
        self.src_embed_scale = self.trg_embed_scale = torch.tensor([embed_dim**0.5], device=device)

    self.src_pos_embed_scale = torch.tensor([(embed_dim / 2)**0.5], device=device)
    # trg pos embedding already returns a vector of norm sqrt(embed_dim / 2)
    self.trg_pos_embed_scale = torch.tensor([1.], device=device)
    if self.config['learn_pos_scale']:
        self.src_pos_embed_scale = Parameter(self.src_pos_embed_scale)
        self.trg_pos_embed_scale = Parameter(self.trg_pos_embed_scale)

    if tie_mode == ac.ALL_TIED:
        self.src_embedding.weight = self.trg_embedding.weight

    if not fix_norm:
        nn.init.normal_(self.src_embedding.weight, mean=0, std=embed_dim**-0.5)
        nn.init.normal_(self.trg_embedding.weight, mean=0, std=embed_dim**-0.5)
    else:
        d = 0.01  # pure magic
        nn.init.uniform_(self.src_embedding.weight, a=-d, b=d)
        nn.init.uniform_(self.trg_embedding.weight, a=-d, b=d)
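# Hedged sketch (plain torch.nn, not repo code): the ALL_TIED branch above simply aliases
# the two embedding tables, so a single gradient update moves both.
import torch.nn as nn

src_embedding = nn.Embedding(10, 8)
trg_embedding = nn.Embedding(10, 8)
src_embedding.weight = trg_embedding.weight
assert src_embedding.weight is trg_embedding.weight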
def get_pos_embedding(self, embed_dim):
    return SequenceStruct(
        torch.zeros(len(self.data), embed_dim, device=get_device()))