def validate(self, valid_iter):
    """Run the model over a validation iterator and accumulate loss stats.

    The model is put into evaluation mode for the pass and is switched
    back to training mode afterwards, even when a batch raises.

    Args:
        valid_iter: iterable of validation batches. Each batch is passed
            to ``make_features`` for the ``'src'`` and ``'tgt'`` sides and
            must expose ``batch.src`` as a two-item pair whose second item
            is the source lengths.

    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    # Set model in validating mode.
    self.model.eval()
    try:
        stats = Statistics()

        for batch in valid_iter:
            src = make_features(batch, 'src')
            _, src_lengths = batch.src
            tgt = make_features(batch, 'tgt')

            # F-prop through the model.
            outputs, attns = self.model(src, tgt, src_lengths)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)
    finally:
        # Set model back to training mode, also on failure, so a caller
        # that handles the exception does not keep training in eval mode.
        self.model.train()

    return stats
def validate(self, valid_iter):
    """Run the model over a validation iterator and accumulate loss stats.

    The pass runs under ``torch.no_grad()`` with the model in evaluation
    mode; the model is switched back to training mode afterwards, even
    when a batch raises.

    Args:
        valid_iter: iterable of validation batches. Each batch must expose
            ``batch_size`` and is passed to ``make_features`` for the
            ``'src'`` and ``'tgt'`` sides.

    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    # Set model in validating mode.
    self.model.eval()
    try:
        stats = Statistics()

        with torch.no_grad():
            for batch in valid_iter:
                src = make_features(batch, 'src')
                src = src.transpose(0, 1).contiguous()
                # The lengths stored in batch.src are ignored: every
                # instance is given the same length, src.size(1) of the
                # transposed source tensor.
                src_lengths = (
                    torch.ones(batch.batch_size) * src.size(1)).long()
                tgt = make_features(batch, 'tgt')

                # F-prop through the model.
                outputs, attns = self.model(src, tgt, src_lengths)

                # Compute loss.
                batch_stats = self.valid_loss.monolithic_compute_loss(
                    batch, outputs, attns)

                # Update statistics.
                stats.update(batch_stats)
    finally:
        # Set model back to training mode, also on failure, so a caller
        # that handles the exception does not keep training in eval mode.
        self.model.train()

    return stats
def validate(self, valid_iter, task_type='task'):
    """Run the model over a validation iterator for one task.

    The pass runs under ``torch.no_grad()`` with the model in evaluation
    mode; the model is switched back to training mode afterwards, even
    when a batch raises.

    Args:
        valid_iter: iterable of validation batches. Each batch must expose
            ``batch.src`` as a two-item pair whose second item is the
            source lengths.
        task_type: ``'task'`` validates against the ``'tgt'`` field with
            ``self.valid_loss``; any other value validates against the
            ``'tgt2'`` field with ``self.valid_loss2``. The value is also
            forwarded to ``Statistics`` and to the model.

    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    is_main_task = task_type == 'task'
    tgt_field = 'tgt' if is_main_task else 'tgt2'

    # Set model in validating mode.
    self.model.eval()
    try:
        stats = Statistics(task_type=task_type)

        with torch.no_grad():
            for batch in valid_iter:
                src = make_features(batch, 'src')
                _, src_lengths = batch.src
                tgt = make_features(batch, tgt_field)

                # F-prop through the model.
                outputs, attns = self.model(
                    src, tgt, src_lengths, task_type=task_type)

                # Compute loss with the criterion matching the task.
                criterion = (
                    self.valid_loss if is_main_task else self.valid_loss2)
                batch_stats = criterion.monolithic_compute_loss(
                    batch, outputs, attns)

                # Update statistics.
                stats.update(batch_stats)
    finally:
        # Set model back to training mode, also on failure, so a caller
        # that handles the exception does not keep training in eval mode.
        self.model.train()

    return stats
def _gradient_accumulation(self, batch, normalization, total_stats,
                           report_stats):
    """Run one forward/backward/optimizer step on a single batch.

    Args:
        batch: training batch; ``batch.src`` must be a two-item pair whose
            second item is the source lengths.
        normalization: forwarded unchanged to
            ``self.train_loss.sharded_compute_loss``.
        total_stats: running statistics object, updated in place.
        report_stats: reporting statistics object, updated in place.
    """
    # Original author's notes (unverified here): src is batch.src[0],
    # laid out as (src_len, batch_size); "does everything end with <s>?"
    source = make_features(batch, 'src')
    # Original author's note: src_lengths is batch.src[1], one per instance.
    _, source_lengths = batch.src
    # Original author's note: the target includes the leading <s> (id 2),
    # the trailing </s> (id 3) and possibly <blank> padding (id 1).
    target = make_features(batch, 'tgt')

    # Target sentence length.
    target_len = target.size(0)

    # F-prop all but generator.
    # Gradients are not accumulated across batches, so clear them first.
    self.model.zero_grad()
    # Original author's note: logits is (len, batch_size, 2048).
    logits = self.model(source, target, source_lengths)

    # Compute loss in shards for memory efficiency.
    # Original author's note: shard_size defaults to 2; attns is not used
    # (None is passed in its place).
    step_stats = self.train_loss.sharded_compute_loss(
        batch, logits, None, 0, target_len, self.shard_size, normalization)
    for accumulator in (total_stats, report_stats):
        accumulator.update(step_stats)

    # Update the parameters.
    self.optim.step()

    # If truncated, don't backprop fully.
    decoder = self.model.decoder
    if decoder.state is not None:
        decoder.detach_state()
def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                           report_stats):
    """Forward/backward over a group of batches with truncated BPTT.

    When ``self.grad_accum_count == 1`` the gradients are cleared and the
    optimizer stepped for every truncated target window; when it is
    greater than 1 the gradients are cleared once up front and a single
    optimizer step is taken after all batches.

    Args:
        true_batchs: iterable of batches providing ``src``, ``tgt`` and
            ``structure1`` .. ``structure5`` features.
        normalization: forwarded to ``sharded_compute_loss``.
        total_stats: running statistics object, updated in place.
        report_stats: reporting statistics object, updated in place.
    """
    structure_fields = ('structure1', 'structure2', 'structure3',
                        'structure4', 'structure5')

    def _reduce_and_step():
        # Multi GPU gradient gather, then parameter update.
        if self.n_gpu > 1:
            grads = [
                p.grad.data for p in self.model.parameters()
                if p.requires_grad and p.grad is not None
            ]
            all_reduce_and_rescale_tensors(grads, float(1))
        self.optim.step()

    if self.grad_accum_count > 1:
        self.model.zero_grad()

    for batch in true_batchs:
        target_size = batch.tgt.size(0)
        # Truncated BPTT: reminder not compatible with accum > 1
        window = self.trunc_size or target_size

        # Original author's note: src observed as 12 * 146.
        src = make_features(batch, 'src')
        _, src_lengths = batch.src
        tgt_outer = make_features(batch, 'tgt')

        # Each structure feature has its first three axes rotated:
        # axis 0 -> axis 2, axis 1 -> axis 0, axis 2 -> axis 1.
        structure1, structure2, structure3, structure4, structure5 = (
            make_features(batch, field).transpose(0, 1).transpose(1, 2)
            for field in structure_fields
        )

        for start in range(0, target_size - 1, window):
            # 1. Create truncated target.
            tgt = tgt_outer[start:start + window]

            # 2. F-prop all but generator.
            if self.grad_accum_count == 1:
                self.model.zero_grad()
            outputs, attns = self.model(
                src, tgt, structure1, structure2, structure3, structure4,
                structure5, src_lengths)

            # 3. Compute loss in shards for memory efficiency.
            batch_stats = self.train_loss.sharded_compute_loss(
                batch, outputs, attns, start, window, self.shard_size,
                normalization)
            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                _reduce_and_step()

            # If truncated, don't backprop fully.
            if self.model.decoder.state is not None:
                self.model.decoder.detach_state()

    # in case of multi step gradient accumulation,
    # update only after accum batches
    if self.grad_accum_count > 1:
        _reduce_and_step()
def validate(self, valid_iter):
    """Run the structure-aware model over a validation iterator.

    The model is put into evaluation mode for the pass and is switched
    back to training mode afterwards, even when a batch raises.

    Args:
        valid_iter: iterable of validation batches providing ``src``,
            ``tgt`` and ``structure1`` .. ``structure5`` features;
            ``batch.src`` must be a two-item pair whose second item is
            the source lengths.

    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    structure_fields = ('structure1', 'structure2', 'structure3',
                        'structure4', 'structure5')

    # Set model in validating mode.
    self.model.eval()
    try:
        stats = Statistics()

        for batch in valid_iter:
            src = make_features(batch, 'src')
            _, src_lengths = batch.src
            tgt = make_features(batch, 'tgt')

            # Each structure feature has its first three axes rotated:
            # axis 0 -> axis 2, axis 1 -> axis 0, axis 2 -> axis 1.
            structure1, structure2, structure3, structure4, structure5 = (
                make_features(batch, field).transpose(0, 1).transpose(1, 2)
                for field in structure_fields
            )

            # F-prop through the model.
            outputs, attns = self.model(
                src, tgt, structure1, structure2, structure3, structure4,
                structure5, src_lengths)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)
    finally:
        # Set model back to training mode, also on failure, so a caller
        # that handles the exception does not keep training in eval mode.
        self.model.train()

    return stats
def translate_batch(self, batch):
    """Beam-search decode one batch.

    The source is encoded once, the decoder state is tiled
    ``self.beam_size`` times, and one ``Beam`` per instance is advanced
    step by step. Instances whose beam reports completion are dropped
    from the decoder state so later steps only run on unfinished ones.

    Args:
        batch: batch passed to ``make_features(batch, 'src')``.

    Returns:
        tuple: ``(batch_hyps, batch_scores)``, two lists with one entry per
        instance, holding the pair returned by each beam's
        ``get_best_hypothesis()``.
    """

    def get_inst_idx_to_tensor_position_map(inst_idx_list):
        ''' Indicate the position of an instance in a tensor. '''
        return {inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)}

    def beam_decode_step(inst_dec_beams, len_dec_seq,
                         inst_idx_to_position_map, n_bm):
        ''' Decode and update beam status, and then return active beam idx '''
        n_active_inst = len(inst_idx_to_position_map)

        # Last emitted word of every unfinished beam, flattened to
        # (1, n_active_inst * n_bm).
        dec_seq = torch.stack(
            [b.get_last_target_word() for b in inst_dec_beams if not b.done]
        ).to(self.device).view(1, -1)

        # One decoder step followed by the generator; scores are regrouped
        # as (n_active_inst, n_bm, vocab).
        dec_output, *_ = self.model.decoder(dec_seq, step=len_dec_seq)
        word_prob = self.model.generator(dec_output.squeeze(0))
        word_prob = word_prob.view(n_active_inst, n_bm, -1)

        # Update the beams with the predicted scores and collect the
        # instances that are still incomplete, together with the flat
        # state rows their surviving hypotheses came from.
        active_inst_idx_list = []
        select_indices_array = []
        for inst_idx, inst_position in inst_idx_to_position_map.items():
            beam = inst_dec_beams[inst_idx]
            is_inst_complete = beam.advance(word_prob[inst_position])
            if not is_inst_complete:
                active_inst_idx_list.append(inst_idx)
                select_indices_array.append(
                    beam.get_current_origin() + inst_position * n_bm)

        if select_indices_array:
            # Reorder the decoder state to follow the surviving hypotheses
            # and drop the rows of completed instances.
            select_indices = torch.cat(select_indices_array)
            self.model.decoder.map_state(
                lambda state, dim: state.index_select(dim, select_indices))

        return active_inst_idx_list

    with torch.no_grad():
        # -- Encode
        src_seq = make_features(batch, 'src')
        # Author's note: src is (seq_len_src, batch_size).
        _, src_enc, _ = self.model.encoder(src_seq)
        self.model.decoder.init_state(src_seq, src_enc)
        src_len = src_seq.size(0)

        # -- Repeat data for beam search
        n_bm = self.beam_size
        n_inst = src_seq.size(1)
        self.model.decoder.map_state(
            lambda state, dim: tile(state, n_bm, dim=dim))

        # -- Prepare beams
        decode_length = src_len + self.decode_extra_length
        decode_min_length = 0
        if self.decode_min_length >= 0:
            decode_min_length = src_len - self.decode_min_length

        # The two tasks differ only in their BOS/EOS ids.
        if self.task_type == 'task':
            bos_id, eos_id = self.tgt_bos_id, self.tgt_eos_id
        else:
            bos_id, eos_id = self.tgt2_bos_id, self.tgt2_eos_id

        inst_dec_beams = [
            Beam(n_bm,
                 decode_length=decode_length,
                 minimal_length=decode_min_length,
                 minimal_relative_prob=self.minimal_relative_prob,
                 bos_id=bos_id,
                 eos_id=eos_id,
                 device=self.device)
            for _ in range(n_inst)
        ]

        # -- Bookkeeping for active or not
        active_inst_idx_list = list(range(n_inst))
        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        # -- Decode
        for len_dec_seq in range(0, decode_length):
            active_inst_idx_list = beam_decode_step(
                inst_dec_beams, len_dec_seq, inst_idx_to_position_map, n_bm)

            if not active_inst_idx_list:
                break  # all instances have finished their path to <EOS>

            inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
                active_inst_idx_list)

    # Collect the best hypothesis and its score from every beam.
    batch_hyps, batch_scores = [], []
    for beam in inst_dec_beams:
        hyp, score = beam.get_best_hypothesis()
        batch_hyps.append(hyp)
        batch_scores.append(score)

    return batch_hyps, batch_scores
def reinforce_batch(self, batch):
    """Decode a batch twice: once greedily and once by sampling.

    The source is encoded once; the decoder is then unrolled for
    ``self.decode_length`` steps taking the arg-max word at each step,
    and unrolled again for the same number of steps drawing each word
    from the softmax distribution. ``torch.no_grad()`` is not used, so
    the returned log-probabilities stay attached to the autograd graph.

    NOTE(review): both roll-outs use the same ``self.model.decoder``
    state, which is initialised only once before the greedy pass —
    confirm that restarting at ``step=0`` is what the decoder expects.

    Args:
        batch: batch passed to ``make_features`` for ``'src'`` and
            ``'tgt'``.

    Returns:
        tuple: ``(sample_ids, outputs_mul, probs_mul, tgt)``, where
        ``sample_ids`` are the stacked greedy word ids, ``outputs_mul`` the
        stacked sampled word ids, ``probs_mul`` the stacked log-probability
        of each sampled word, and ``tgt`` the target features of the batch.
    """

    def decode_step(step, dec_seq):
        ''' Run one decoder step and return the generator output. '''
        dec_output, *_ = self.model.decoder(dec_seq, step=step)
        return self.model.generator(dec_output.squeeze(0))

    # -- Encode
    src_seq = make_features(batch, 'src')
    tgt = make_features(batch, 'tgt')
    src_seq = src_seq.transpose(0, 1).contiguous()
    _, src_enc, _ = self.model.encoder(src_seq)
    self.model.decoder.init_state(src_seq, src_enc)

    batch_size = src_seq.size(1)
    decode_length = self.decode_length

    # Both roll-outs start from a row of BOS ids, shape (1, batch_size).
    dec_seq_greedy = Variable(
        torch.LongTensor(1, batch_size).fill_(self.tgt_bos_id)).cuda()
    dec_seq_mul = Variable(
        torch.LongTensor(1, batch_size).fill_(self.tgt_bos_id)).cuda()

    # -- Greedy roll-out (baseline): feed back the arg-max word.
    sample_ids = []
    for len_dec_seq in range(0, decode_length):
        output_prob = decode_step(len_dec_seq, dec_seq_greedy)
        greedy_ids = output_prob.max(1)[1]
        sample_ids.append(greedy_ids)
        dec_seq_greedy = greedy_ids.unsqueeze(0)
    sample_ids = torch.stack(sample_ids).squeeze()

    # -- Sampled roll-out: draw each word from the softmax distribution.
    outputs_mul, probs_mul = [], []
    for len_dec_seq in range(0, decode_length):
        output_prob = decode_step(len_dec_seq, dec_seq_mul)
        predicted = F.softmax(output_prob, 1).multinomial(1)
        # Log-probability of the sampled word in each row; gather picks
        # the same entries a one-hot mask would, without building one.
        prob = F.log_softmax(output_prob, 1).gather(
            1, predicted.long()).squeeze(1)
        probs_mul.append(prob)
        outputs_mul.append(predicted)
        dec_seq_mul = predicted.transpose(0, 1)

    probs_mul = torch.stack(probs_mul).squeeze()
    # Author's note: [max_tgt_len, batch]
    outputs_mul = torch.stack(outputs_mul).squeeze()
    return sample_ids, outputs_mul, probs_mul, tgt
def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                           report_stats, ratio=1.):
    """Forward/backward over a group of batches with a reconstructor target.

    When ``self.grad_accum_count == 1`` the gradients are cleared and the
    optimizer stepped for every batch; when it is greater than 1 the
    gradients are cleared once up front and a single optimizer step is
    taken after all batches.

    Args:
        true_batchs: iterable of batches providing ``src``, ``tgt``,
            ``stgt`` and ``structure`` features.
        normalization: forwarded to ``sharded_compute_loss``.
        total_stats: running statistics object, updated in place.
        report_stats: reporting statistics object, updated in place.
        ratio: forwarded to ``sharded_compute_loss`` as ``ratio``.
    """

    def _reduce_and_step():
        # Multi GPU gradient gather, then parameter update.
        if self.n_gpu > 1:
            grads = [
                p.grad.data for p in self.model.parameters()
                if p.requires_grad and p.grad is not None
            ]
            all_reduce_and_rescale_tensors(grads, float(1))
        self.optim.step()

    if self.grad_accum_count > 1:
        self.model.zero_grad()

    for batch in true_batchs:
        src = make_features(batch, 'src')
        _, src_lengths = batch.src
        tgt = make_features(batch, 'tgt')

        # Reconstructor input: rotate the axes, pick one entry along the
        # new leading axis at random, and drop its last element. Author's
        # note: choosing a child randomly makes the sequence differ from
        # the depth-first traversal.
        stgt = make_features(batch, 'stgt').transpose(0, 1).transpose(1, 2)
        choice = random.randint(0, stgt.size(0) - 1)
        stgt = stgt[choice][:-1]

        structure = make_features(batch, 'structure')
        structure = structure.transpose(0, 1).transpose(1, 2)

        # 2. F-prop all but generator.
        if self.grad_accum_count == 1:
            self.model.zero_grad()
        outputs, attns, s_outputs, s_attns = self.model(
            src, tgt, stgt, structure, src_lengths)

        # 3. Compute loss in shards for memory efficiency.
        batch_stats = self.train_loss.sharded_compute_loss(
            batch, (outputs, s_outputs), stgt, self.shard_size,
            normalization, ratio=ratio)
        total_stats.update(batch_stats)
        report_stats.update(batch_stats)

        # 4. Update the parameters and statistics.
        if self.grad_accum_count == 1:
            _reduce_and_step()

        # If truncated, don't backprop fully.
        if self.model.decoder.state is not None:
            self.model.decoder.detach_state()

    # in case of multi step gradient accumulation,
    # update only after accum batches
    if self.grad_accum_count > 1:
        _reduce_and_step()
def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                           report_stats, ratio=0.15, ratio2=0.05):
    """Forward/backward with reconstructor and relation losses.

    When ``self.grad_accum_count == 1`` the gradients are cleared and the
    optimizer stepped for every batch; when it is greater than 1 the
    gradients are cleared once up front and a single optimizer step is
    taken after all batches.

    Args:
        true_batchs: iterable of batches providing ``src``, ``tgt``,
            ``stgt``, ``structure``, ``mask`` and ``relation`` features.
        normalization: forwarded to ``sharded_compute_loss``.
        total_stats: running statistics object, updated in place.
        report_stats: reporting statistics object, updated in place.
        ratio: passed to ``sharded_compute_loss`` as ``ratio2``, with
            ``1 - ratio`` passed as ``ratio1``.
        ratio2: scale applied to the relation loss before its backward
            pass.
    """

    def _reduce_and_step():
        # Multi GPU gradient gather, then parameter update.
        if self.n_gpu > 1:
            grads = [p.grad.data for p in self.model.parameters()
                     if p.requires_grad and p.grad is not None]
            all_reduce_and_rescale_tensors(grads, float(1))
        self.optim.step()

    if self.grad_accum_count > 1:
        self.model.zero_grad()

    for batch in true_batchs:
        # NOTE(review): target_size is not used below.
        target_size = batch.tgt.size(0)

        src = make_features(batch, 'src')
        _, src_lengths = batch.src
        tgt = make_features(batch, 'tgt')

        # Reconstructor input: rotate the axes, pick one traversal
        # sequence at random, and drop its last element.
        stgt = make_features(batch, 'stgt').transpose(0, 1).transpose(1, 2)
        choice = random.randint(0, stgt.size(0) - 1)
        stgt = stgt[choice][:-1]

        structure = make_features(batch, 'structure')
        structure = structure.transpose(0, 1).transpose(1, 2)

        # Flagged "bad code" by the original author: shift the raw mask
        # feature down by 2, zero out non-positive entries, cast to byte.
        mask = make_features(batch, 'mask') - 2
        mask[mask <= 0] = 0
        mask = mask.byte()

        # Keep only the relation entries that are not equal to 1.
        relation = make_features(batch, 'relation').transpose(0, 1)
        relation = relation[relation != 1]

        # 2. F-prop all but generator.
        if self.grad_accum_count == 1:
            self.model.zero_grad()
        outputs, attns, s_outputs, s_attns, p, rels = self.model(
            src, tgt, stgt, structure, mask, src_lengths)

        # 3. Compute loss in shards for memory efficiency.
        batch_stats = self.train_loss.sharded_compute_loss(
            batch, (outputs, s_outputs), stgt, self.shard_size,
            normalization, ratio1=1-ratio, ratio2=ratio)

        n_relations = relation.size(0)
        if n_relations > 0:
            relation_loss = self.train_relation_loss(rels, relation)
            loss = (-p + relation_loss) / n_relations
            loss = loss * ratio2
            loss.backward()

        total_stats.update(batch_stats)
        report_stats.update(batch_stats)

        # 4. Update the parameters and statistics.
        if self.grad_accum_count == 1:
            _reduce_and_step()

        # If truncated, don't backprop fully.
        if self.model.decoder.state is not None:
            self.model.decoder.detach_state()

    # in case of multi step gradient accumulation,
    # update only after accum batches
    if self.grad_accum_count > 1:
        _reduce_and_step()
def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                           report_stats):
    """Forward/backward over a group of batches with truncated BPTT.

    When ``self.grad_accum_count == 1`` the gradients are cleared and the
    optimizer stepped for every truncated target window; when it is
    greater than 1 the gradients are cleared once up front and a single
    optimizer step is taken after all batches.

    Args:
        true_batchs: iterable of batches providing ``src`` and ``tgt``
            features and a ``batch_size`` attribute.
        normalization: forwarded to ``sharded_compute_loss``.
        total_stats: running statistics object, updated in place.
        report_stats: reporting statistics object, updated in place.
    """

    def _reduce_and_step():
        # Multi GPU gradient gather, then parameter update.
        if self.n_gpu > 1:
            grads = [p.grad.data for p in self.model.parameters()
                     if p.requires_grad and p.grad is not None]
            all_reduce_and_rescale_tensors(grads, float(1))
        self.optim.step()

    if self.grad_accum_count > 1:
        self.model.zero_grad()

    for batch in true_batchs:
        target_size = batch.tgt.size(0)
        # Truncated BPTT: reminder not compatible with accum > 1
        window = self.trunc_size or target_size

        # Original author's note: src observed as 32 * 113.
        src = make_features(batch, 'src').transpose(0, 1).contiguous()
        # The length information in batch.src is ignored here: every
        # instance is given the same fixed length, src.size(1).
        src_lengths = (torch.ones(batch.batch_size) * src.size(1)).long()

        tgt_outer = make_features(batch, 'tgt')

        for start in range(0, target_size - 1, window):
            # 1. Create truncated target.
            tgt = tgt_outer[start:start + window]

            # 2. F-prop all but generator.
            if self.grad_accum_count == 1:
                self.model.zero_grad()
            outputs, attns = self.model(src, tgt, src_lengths)

            # 3. Compute loss in shards for memory efficiency.
            batch_stats = self.train_loss.sharded_compute_loss(
                batch, outputs, attns, start, window, self.shard_size,
                normalization)
            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                _reduce_and_step()

            # If truncated, don't backprop fully.
            if self.model.decoder.state is not None:
                self.model.decoder.detach_state()

    # in case of multi step gradient accumulation,
    # update only after accum batches
    if self.grad_accum_count > 1:
        _reduce_and_step()
def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                           report_stats, ratio):
    """Forward/backward with a biaffine relation loss and truncated BPTT.

    When ``self.grad_accum_count == 1`` the gradients are cleared and the
    optimizer stepped for every truncated target window; when it is
    greater than 1 the gradients are cleared once up front and a single
    optimizer step is taken after all batches.

    Args:
        true_batchs: iterable of batches providing ``src``, ``tgt``,
            ``structure``, ``mask`` and ``relation`` features.
        normalization: forwarded to ``sharded_compute_loss``.
        total_stats: running statistics object, updated in place.
        report_stats: reporting statistics object, updated in place.
        ratio: scale applied to the relation loss before its backward
            pass.
    """

    def _reduce_and_step():
        # Multi GPU gradient gather, then parameter update.
        if self.n_gpu > 1:
            grads = [p.grad.data for p in self.model.parameters()
                     if p.requires_grad and p.grad is not None]
            all_reduce_and_rescale_tensors(grads, float(1))
        self.optim.step()

    if self.grad_accum_count > 1:
        self.model.zero_grad()

    for batch in true_batchs:
        target_size = batch.tgt.size(0)
        # Truncated BPTT: reminder not compatible with accum > 1
        window = self.trunc_size or target_size

        src = make_features(batch, 'src')
        _, src_lengths = batch.src
        tgt_outer = make_features(batch, 'tgt')

        structure = make_features(batch, 'structure')
        structure = structure.transpose(0, 1).transpose(1, 2)

        # Flagged "bad code" by the original author: shift the raw mask
        # feature down by 2, zero out non-positive entries, cast to byte.
        mask = make_features(batch, 'mask') - 2
        mask[mask <= 0] = 0
        mask = mask.byte()

        # Ground truth label of biaffine relation; entries equal to 1
        # are filtered out.
        relation = make_features(batch, 'relation').transpose(0, 1)
        relation = relation[relation != 1]
        n_relations = relation.size(0)

        for start in range(0, target_size - 1, window):
            # 1. Create truncated target.
            tgt = tgt_outer[start:start + window]

            # 2. F-prop all but generator.
            if self.grad_accum_count == 1:
                self.model.zero_grad()
            outputs, attns, p, rels = self.model(
                src, tgt, structure, mask, src_lengths)

            # 3. Compute loss in shards for memory efficiency.
            batch_stats = self.train_loss.sharded_compute_loss(
                batch, outputs, attns, start, window, self.shard_size,
                normalization, 1.)

            if n_relations > 0:
                # compute loss for label prediction
                relation_loss = self.train_relation_loss(rels, relation)
                # total loss of biaffine module
                loss = (-p + relation_loss) / n_relations
                loss = loss * ratio
                loss.backward()

            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                _reduce_and_step()

            # If truncated, don't backprop fully.
            if self.model.decoder.state is not None:
                self.model.decoder.detach_state()

    # in case of multi step gradient accumulation,
    # update only after accum batches
    if self.grad_accum_count > 1:
        _reduce_and_step()