def __init__(self, d_in, d_outs, gen_func):
    super().__init__()
    aeq(len(d_outs), 2)
    self.generators = nn.ModuleDict()
    self.generators['base'] = Generator(d_in, d_outs[0], gen_func)
    self.generators['crosslingual'] = Generator(d_in, d_outs[1], gen_func)
    self.add_switch('generator', self.generators, 'crosslingual', 'base')
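# Construction sketch (hedged): `DualGenerator` is a hypothetical name for the class
# this __init__ belongs to; Generator, aeq and add_switch are assumed to be defined
# elsewhere in this repo. `d_outs` carries the two output vocabulary sizes, one per
# task, and add_switch registers the 'crosslingual' and 'base' heads as switchable.
#
#   gen_func = nn.LogSoftmax(dim=-1)
#   dual = DualGenerator(d_in=512, d_outs=[10000, 8000], gen_func=gen_func)
#   log_probs = dual.generators['base'](dec_out)  # (n, 10000) log-probabilities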
def __init__(self, d_in, d_hid, d_out, mode='residual'):
    super().__init__()
    self.event_hidden_layer = nn.Linear(d_in, d_hid)
    self.agent_hidden_layer = nn.Linear(d_in, d_hid)
    self.theme_hidden_layer = nn.Linear(d_in, d_hid)
    self.hidden = nn.Linear(d_hid * 3, d_out)
    assert mode in ['nonlinear', 'sum', 'residual']
    self.mode = mode
    if self.mode == 'residual':
        aeq(d_in, d_out)
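# Construction sketch (hedged): `EatLayer` is a hypothetical name for this module.
# Three parallel projections handle the event/agent/theme inputs and `self.hidden`
# fuses their concatenation; in 'residual' mode the aeq check forces d_in == d_out
# so a skip connection can be added in the (unshown) forward pass.
#
#   layer = EatLayer(d_in=256, d_hid=128, d_out=256, mode='residual')  # passes aeq
#   layer = EatLayer(d_in=256, d_hid=128, d_out=512, mode='sum')       # no constraint
#   layer = EatLayer(d_in=256, d_hid=128, d_out=512, mode='residual')  # aeq fails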
def _example_dict_iter(self, line, index):
    line = line.split()
    if self.line_truncate:
        line = line[:self.line_truncate]
    words, feats, n_feats = TextDataset.extract_text_features(line)
    example_dict = {self.side: words, "indices": index}
    if feats:
        # All examples must have same number of features.
        aeq(self.n_feats, n_feats)
        prefix = self.side + "_feat_"
        example_dict.update((prefix + str(j), f) for j, f in enumerate(feats))
    return example_dict
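# Illustration (hedged): for a plain line without token features, only the tokens
# and the example index are returned. Assuming self.side == 'src' and
# self.line_truncate is unset, something like:
#
#   _example_dict_iter("the cat sat", index=7)
#   # -> {'src': ('the', 'cat', 'sat'), 'indices': 7}
#
# The exact container type of `words` depends on TextDataset.extract_text_features;
# feature-annotated tokens would additionally yield 'src_feat_0', 'src_feat_1', ...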
def score(self, h_t, h_s):
    """
    Args:
      h_t (FloatTensor): sequence of queries ``(batch, tgt_len, dim)``
      h_s (FloatTensor): sequence of sources ``(batch, src_len, dim)``

    Returns:
      FloatTensor: raw attention scores (unnormalized) for each src index
        ``(batch, tgt_len, src_len)``
    """

    # Check input sizes
    src_batch, src_len, src_dim = h_s.size()
    tgt_batch, tgt_len, tgt_dim = h_t.size()
    aeq(src_batch, tgt_batch)
    aeq(src_dim, tgt_dim)
    aeq(self.dim, src_dim)

    if self.attn_type in ["general", "dot"]:
        if self.attn_type == "general":
            h_t_ = h_t.view(tgt_batch * tgt_len, tgt_dim)
            h_t_ = self.linear_in(h_t_)
            h_t = h_t_.view(tgt_batch, tgt_len, tgt_dim)
        h_s_ = h_s.transpose(1, 2)
        # (batch, t_len, d) x (batch, d, s_len) --> (batch, t_len, s_len)
        return torch.bmm(h_t, h_s_)
    else:
        dim = self.dim
        wq = self.linear_query(h_t.view(-1, dim))
        wq = wq.view(tgt_batch, tgt_len, 1, dim)
        wq = wq.expand(tgt_batch, tgt_len, src_len, dim)

        uh = self.linear_context(h_s.contiguous().view(-1, dim))
        uh = uh.view(src_batch, 1, src_len, dim)
        uh = uh.expand(src_batch, tgt_len, src_len, dim)

        # (batch, t_len, s_len, d)
        wquh = torch.tanh(wq + uh)

        return self.v(wquh.view(-1, dim)).view(tgt_batch, tgt_len, src_len)
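# Usage sketch (hedged): assuming this is the (possibly modified) OpenNMT-py
# GlobalAttention module, score() can be exercised directly to check the
# documented shapes:
#
#   attn = GlobalAttention(64, attn_type="general")
#   h_t = torch.randn(8, 5, 64)    # queries (batch, tgt_len, dim)
#   h_s = torch.randn(8, 11, 64)   # sources (batch, src_len, dim)
#   scores = attn.score(h_t, h_s)  # unnormalized scores, shape (8, 5, 11)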
def forward(self, source, memory_bank, memory_lengths=None, coverage=None):
    """
    Args:
      source (FloatTensor): query vectors ``(batch, tgt_len, dim)``
      memory_bank (FloatTensor): source vectors ``(batch, src_len, dim)``
      memory_lengths (LongTensor): the source context lengths ``(batch,)``
      coverage (FloatTensor): None (not supported yet)

    Returns:
      (FloatTensor, FloatTensor, FloatTensor):

      * Computed vector ``(tgt_len, batch, dim)``
      * Attention distributions for each query ``(tgt_len, batch, src_len)``
      * Updated coverage in the one-step case, otherwise the input ``coverage``
        (``None`` when coverage is not used)
    """

    # one step input
    if source.dim() == 2:
        one_step = True
        source = source.unsqueeze(1)
    else:
        one_step = False

    batch, source_l, dim = memory_bank.size()
    batch_, target_l, dim_ = source.size()
    aeq(batch, batch_)
    aeq(dim, dim_)
    aeq(self.dim, dim)
    if coverage is not None:
        batch_, source_l_ = coverage.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)

        cover = coverage.view(-1).unsqueeze(1)
        memory_bank = memory_bank + self.linear_cover(cover).view_as(memory_bank)
        memory_bank = torch.tanh(memory_bank)

    # compute attention scores, as in Luong et al.
    align = self.score(source, memory_bank)

    # if memory_lengths is not None:
    #     mask = sequence_mask(memory_lengths, max_len=align.size(-1))
    #     mask = mask.unsqueeze(1)  # Make it broadcastable.
    #     align.masked_fill_(1 - mask, -float('inf'))
    if self.mask is not None:
        align.data.masked_fill_(self.mask.view(-1, self.mask.size(-1)),
                                -float('inf'))

    # Softmax or sparsemax to normalize attention weights
    align_vectors = F.softmax(align.view(batch * target_l, source_l), -1)
    align_vectors = align_vectors.view(batch, target_l, source_l)

    # each context vector c_t is the weighted average
    # over all the source hidden states
    c = torch.bmm(align_vectors, memory_bank)

    # concatenate
    concat_c = torch.cat([c, source], 2).view(batch * target_l, dim * 2)
    attn_h = self.linear_out(concat_c).view(batch, target_l, dim)
    if self.attn_type in ["general", "dot"]:
        attn_h = torch.tanh(attn_h)

    if one_step:
        attn_h = attn_h.squeeze(1)
        align_vectors = align_vectors.squeeze(1)

        if coverage is not None:
            covered = F.softmax(
                align.view(batch * target_l, source_l) / self.temp, -1)
            covered = covered.view(batch, target_l, source_l).squeeze(1)
            coverage = coverage - covered

        # Check output sizes
        batch_, dim_ = attn_h.size()
        aeq(batch, batch_)
        aeq(dim, dim_)
        batch_, source_l_ = align_vectors.size()
        aeq(batch, batch_)
        aeq(source_l, source_l_)
    else:
        attn_h = attn_h.transpose(0, 1).contiguous()
        align_vectors = align_vectors.transpose(0, 1).contiguous()
        # Check output sizes
        target_l_, batch_, dim_ = attn_h.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(dim, dim_)
        target_l_, batch_, source_l_ = align_vectors.size()
        aeq(target_l, target_l_)
        aeq(batch, batch_)
        aeq(source_l, source_l_)

    return attn_h, align_vectors, coverage
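# Usage sketch (hedged): one-step decoding with a 2-D query; values are only
# illustrative and self.mask is assumed to be unset.
#
#   attn = GlobalAttention(64, attn_type="general")
#   query = torch.randn(8, 64)            # (batch, dim), a single decoder step
#   memory_bank = torch.randn(8, 11, 64)  # (batch, src_len, dim)
#   attn_h, align, cov = attn(query, memory_bank)
#   # attn_h: (8, 64), align: (8, 11), cov: None unless coverage was passed in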
def _gradient_accumulation(self, true_batches, normalization, total_stats, report_stats):
    if self.accum_count > 1:
        self.optim.zero_grad()

    for k, batch in enumerate(true_batches):
        task = self._get_task(batch)
        if task.category == 'lm':
            target_size = batch.tgt_agent.size(0)
        else:
            target_size = batch.tgt.size(0)
        # Truncated BPTT: reminder not compatible with accum > 1
        if self.trunc_size:
            trunc_size = self.trunc_size
        else:
            trunc_size = target_size

        if task.name == 'lm' or (task.name == 'crosslingual' and self.crosslingual == 'lm'):
            src_attr = 'src_event'
        elif task.name == 'crosslingual':
            src_attr = 'src_old' if batch.eat_format == 'combined' else 'src'
        elif batch.eat_format in ['old', 'new']:
            src_attr = 'src'
        else:
            src_attr = 'src_old'
        src, src_lengths = getattr(batch, src_attr) \
            if isinstance(getattr(batch, src_attr), tuple) \
            else (getattr(batch, src_attr), None)
        r_stats = report_stats[task.name]
        t_stats = total_stats[task.name]
        if src_lengths is not None:
            r_stats.n_src_words += src_lengths.sum().item()

        tgt_outer = batch.tgt

        bptt = False
        for j in range(0, target_size - 1, trunc_size):
            # 1. Create truncated target.
            tgt = tgt_outer[j:j + trunc_size]

            # 2. F-prop all but generator.
            if self.accum_count == 1:
                self.optim.zero_grad()
            # 2.5 Prepare for crosslingual mode if needed.
            outputs, attns = self.model(src, tgt, src_lengths, bptt=bptt, task=task)
            bptt = True

            # 2.9 Get appropriate loss function.
            if task.name == 'crosslingual':
                loss_func = self.crosslingual_train_loss
            elif task.name == 'lm':
                loss_func = self.lm_train_loss
            elif task.name == 'aux':
                loss_func = self.aux_train_loss
            else:
                loss_func = self.train_loss

            # 3. Compute loss.
            try:
                def loss_call(batch, outputs):
                    loss, batch_stats = loss_func(
                        batch,
                        outputs,
                        attns,
                        normalization=normalization,
                        shard_size=self.shard_size,
                        trunc_start=j,
                        trunc_size=trunc_size)
                    return loss, batch_stats

                if task.category == 'lm':
                    agent_preds = outputs['agent']
                    theme_preds = outputs['theme']
                    batch.agent_preds = agent_preds
                    batch.theme_preds = theme_preds
                    batch.tgt_backup = batch.tgt

                    def update_loss_and_stats(tgt_attr, preds, loss, batch_stats):
                        batch.tgt = getattr(batch, tgt_attr)
                        this_loss, this_batch_stats = loss_call(batch, preds)
                        if this_loss is not None:
                            raise RuntimeError(
                                'loss is not properly updated from within this function.')
                        batch_stats.update(this_batch_stats)

                    loss = None
                    batch_stats = onmt.utils.Statistics()
                    update_loss_and_stats('tgt_agent', agent_preds, loss, batch_stats)
                    # update_loss_and_stats('tgt_agent_mod', agent_preds, loss, batch_stats)
                    # update_loss_and_stats('tgt_theme', theme_preds, loss, batch_stats)
                    # update_loss_and_stats('tgt_theme_mod', theme_preds, loss, batch_stats)
                else:
                    loss, batch_stats = loss_call(batch, outputs)

                if task.name == 'crosslingual' and self.almt_reg_hyper > 0.0:
                    # Orthogonality regularizer on the alignment mapping.
                    weight = self.model.encoder.embeddings.almt_layers['mapping'].weight
                    d1, d2 = weight.shape
                    aeq(d1, d2)
                    eye = torch.eye(d1).to(weight.device)
                    reg_loss = ((weight @ weight.T - eye) ** 2).sum()
                    reg_loss.backward()

                if loss is not None:
                    self.optim.backward(loss)

                t_stats.update(batch_stats)
                r_stats.update(batch_stats)
            except Exception:
                traceback.print_exc()
                logger.info("At step %d, we removed a batch - accum %d",
                            self.optim.training_step, k)
                raise

            # 4. Update the parameters and statistics.
            if self.accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [p.grad.data for p in self.model.parameters()
                             if p.requires_grad and p.grad is not None]
                    onmt.utils.distributed.all_reduce_and_rescale_tensors(
                        grads, float(1))
                self.optim.step()

            # If truncated, don't backprop fully.
            # TO CHECK
            # if dec_state is not None:
            #     dec_state.detach()
            if self.model.decoder.state is not None:
                self.model.decoder.detach_state()

    # in case of multi step gradient accumulation,
    # update only after accum batches
    if self.accum_count > 1:
        if self.n_gpu > 1:
            grads = [p.grad.data for p in self.model.parameters()
                     if p.requires_grad and p.grad is not None]
            onmt.utils.distributed.all_reduce_and_rescale_tensors(
                grads, float(1))
        self.optim.step()
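# Schematic (hedged): the accumulation logic above follows the usual PyTorch
# pattern, sketched here with a hypothetical model/optimizer for clarity:
#
#   optimizer.zero_grad()                  # once per group when accum_count > 1
#   for batch in true_batches:
#       loss = compute_loss(model(batch))
#       loss.backward()                    # gradients accumulate across batches
#   optimizer.step()                       # single update after the whole group
#
# With accum_count == 1, zero_grad() and step() instead run inside the loop,
# once per truncated-BPTT chunk.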
def main(opt):
    ArgumentParser.validate_train_opts(opt)
    ArgumentParser.update_model_opts(opt)
    ArgumentParser.validate_model_opts(opt)

    # Load checkpoint if we resume from a previous training.
    aux_vocab = None
    if opt.train_from:
        logger.info('Loading checkpoint from %s' % opt.train_from)
        checkpoint = torch.load(opt.train_from,
                                map_location=lambda storage, loc: storage)
        logger.info('Loading vocab from checkpoint at %s.' % opt.train_from)
        vocab = checkpoint['vocab']
        if opt.crosslingual:
            aux_vocab = checkpoint['aux_vocab']
    elif opt.crosslingual:
        assert opt.crosslingual in ['old', 'lm']
        vocab = torch.load(opt.data + '.vocab.pt')
        aux_vocab = torch.load(opt.aux_train_data + '.vocab.pt')
    else:
        vocab = torch.load(opt.data + '.vocab.pt')

    # check for code where vocab is saved instead of fields
    # (in the future this will be done in a smarter way)
    def get_fields(vocab):
        if old_style_vocab(vocab):
            return load_old_vocab(vocab, opt.model_type, dynamic_dict=opt.copy_attn)
        else:
            return vocab

    fields = get_fields(vocab)
    aux_fields = None
    if opt.crosslingual:
        aux_fields = get_fields(aux_vocab)

    if opt.crosslingual:
        if opt.crosslingual == 'old':
            aeq(len(opt.eat_formats), 3)
            fields_info = [
                ('train', fields, 'data', Eat2PlainMonoTask, 'base', opt.eat_formats[0]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainAuxMonoTask, 'aux', opt.eat_formats[1]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainCrosslingualTask, 'crosslingual', opt.eat_formats[2])
            ]
        else:
            aeq(len(opt.eat_formats), 4)
            fields_info = [
                ('train', fields, 'data', Eat2PlainMonoTask, 'base', opt.eat_formats[0]),
                ('train', fields, 'data', EatLMMonoTask, 'lm', opt.eat_formats[1]),
                ('train', aux_fields, 'aux_train_data', Eat2PlainAuxMonoTask, 'aux', opt.eat_formats[2]),
                ('train', aux_fields, 'aux_train_data', EatLMCrosslingualTask, 'crosslingual', opt.eat_formats[3])
            ]
        train_iter = build_crosslingual_dataset_iter(fields_info, opt)
    elif len(opt.data_ids) > 1:
        train_shards = []
        for train_id in opt.data_ids:
            shard_base = "train_" + train_id
            train_shards.append(shard_base)
        train_iter = build_dataset_iter_multiple(train_shards, fields, opt)
    else:
        if opt.data_ids[0] is not None:
            shard_base = "train_" + opt.data_ids[0]
        else:
            shard_base = "train"
        train_iter = build_dataset_iter(shard_base, fields, opt)

    nb_gpu = len(opt.gpu_ranks)

    if opt.world_size > 1:
        queues = []
        mp = torch.multiprocessing.get_context('spawn')
        semaphore = mp.Semaphore(opt.world_size * opt.queue_size)
        # Create a thread to listen for errors in the child processes.
        error_queue = mp.SimpleQueue()
        error_handler = ErrorHandler(error_queue)
        # Train with multiprocessing.
        procs = []
        for device_id in range(nb_gpu):
            q = mp.Queue(opt.queue_size)
            queues += [q]
            procs.append(
                mp.Process(target=run,
                           args=(opt, device_id, error_queue, q, semaphore),
                           daemon=True))
            procs[device_id].start()
            logger.info(" Starting process pid: %d " % procs[device_id].pid)
            error_handler.add_child(procs[device_id].pid)
        producer = mp.Process(target=batch_producer,
                              args=(train_iter, queues, semaphore, opt,),
                              daemon=True)
        producer.start()
        error_handler.add_child(producer.pid)

        for p in procs:
            p.join()
        producer.terminate()
    else:
        device_id = 0 if nb_gpu == 1 else -1
        # NOTE Only pass train_iter in my crosslingual mode.
        train_iter = train_iter if opt.crosslingual else None
        passed_fields = {'main': fields, 'crosslingual': aux_fields} \
            if opt.crosslingual else None
        single_main(opt, device_id, train_iter=train_iter, passed_fields=passed_fields)