def validate(self, valid_iter):
        """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
        # Set model in validating mode.
        self.model.eval()

        stats = Statistics()

        for batch in valid_iter:
            src = make_features(batch, 'src')
            _, src_lengths = batch.src

            tgt = make_features(batch, 'tgt')

            # F-prop through the model.
            outputs, attns = self.model(src, tgt, src_lengths)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats
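A minimal call-site sketch for the validate method above, assuming an OpenNMT-style trainer object that owns it and a torchtext-style valid_iter (the trainer name and the reporting calls are illustrative, not from the source; Statistics objects in OpenNMT-style code typically expose ppl() and accuracy()):

valid_stats = trainer.validate(valid_iter)
print('Validation perplexity: %g' % valid_stats.ppl())
print('Validation accuracy:  %g' % valid_stats.accuracy())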
Example #2
  def validate(self, valid_iter):
    """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
    # Set model in validating mode.
    self.model.eval()

    stats = Statistics()

    with torch.no_grad():
      for batch in valid_iter:
        src = make_features(batch, 'src')
        src = src.transpose(0, 1).contiguous()
        # _, src_lengths = batch.src
        src_lengths = (torch.ones(batch.batch_size) * src.size(1)).long()
        tgt = make_features(batch, 'tgt')

        # F-prop through the model.
        outputs, attns = self.model(src, tgt, src_lengths)

        # Compute loss.
        batch_stats = self.valid_loss.monolithic_compute_loss(
          batch, outputs, attns)

        # Update statistics.
        stats.update(batch_stats)

    # Set model back to training mode.
    self.model.train()

    return stats
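The fixed-length trick above replaces the true source lengths with the padded width src.size(1). A standalone shape sketch (the dimensions are made up for illustration):

import torch

src = torch.zeros(4, 12, dtype=torch.long)           # (batch_size, padded_src_len)
src_lengths = (torch.ones(4) * src.size(1)).long()   # every length set to the padded width
print(src_lengths)                                   # tensor([12, 12, 12, 12])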
    def validate(self, valid_iter, task_type='task'):
        """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
        # Set model in validating mode.
        self.model.eval()

        stats = Statistics(task_type=task_type)
        with torch.no_grad():
            for batch in valid_iter:
                src = make_features(batch, 'src')
                _, src_lengths = batch.src

                if task_type == 'task':
                    tgt = make_features(batch, 'tgt')
                else:
                    tgt = make_features(batch, 'tgt2')

                # F-prop through the model.
                outputs, attns = self.model(src,
                                            tgt,
                                            src_lengths,
                                            task_type=task_type)

                # Compute loss.
                if task_type == 'task':
                    batch_stats = self.valid_loss.monolithic_compute_loss(
                        batch, outputs, attns)
                else:
                    batch_stats = self.valid_loss2.monolithic_compute_loss(
                        batch, outputs, attns)

                # Update statistics.
                stats.update(batch_stats)

            # Set model back to training mode.
        self.model.train()

        return stats
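A hedged call-site sketch for the two-task variant above (the trainer object, the second iterator, and the 'task2' label are illustrative assumptions; any label other than 'task' routes to tgt2 and valid_loss2):

main_stats = trainer.validate(valid_iter, task_type='task')    # uses tgt and valid_loss
aux_stats = trainer.validate(valid_iter2, task_type='task2')   # uses tgt2 and valid_loss2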
Example #4
    def _gradient_accumulation(self, batch, normalization, total_stats,
                               report_stats):
        # 1. src = batch.src[0], shape xx * batch_size; do all sequences end with <s>?
        src = make_features(batch, 'src')
        # 2. src_lengths = batch.src[1], batch_size
        _, src_lengths = batch.src
        # 3. tgt_outer = batch.tgt, shape yy * batch_size; includes the leading <s> (id 2), the trailing </s> (id 3), and any <blank> padding (id 1)
        tgt = make_features(batch, 'tgt')

        # Target sentence length
        target_size = tgt.size(0)

        # 2. F-prop all but generator.
        # Gradients need not be accumulated across batches
        self.model.zero_grad()
        # outputs: (len, batch, dim)
        # attns: (len_tgt, batch, len_src)
        # logits: (len, batch_size, 2048)
        logits = self.model(src, tgt, src_lengths)

        # 3. Compute loss in shards for memory efficiency.
        # self.shard_size defaults to 2; attns appears to be unused here?
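        # Note (assumption based on OpenNMT-style loss classes): sharded_compute_loss
        # typically calls backward() internally on each shard, which is why no explicit
        # loss.backward() appears before self.optim.step() below.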
        batch_stats = self.train_loss.sharded_compute_loss(
            batch, logits, None, 0, target_size, self.shard_size,
            normalization)

        total_stats.update(batch_stats)
        report_stats.update(batch_stats)

        # 4. Update the parameters and statistics.
        self.optim.step()

        # If truncated, don't backprop fully.
        # TO CHECK
        # if dec_state is not None:
        #    dec_state.detach()
        if self.model.decoder.state is not None:
            self.model.decoder.detach_state()
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            target_size = batch.tgt.size(0)
            # Truncated BPTT: reminder, this is not compatible with accum > 1
            if self.trunc_size:
                trunc_size = self.trunc_size
            else:
                trunc_size = target_size

            # dec_state = None
            src = make_features(batch, 'src')  # src shape: 12 * 146
            _, src_lengths = batch.src

            tgt_outer = make_features(batch, 'tgt')

            structure1 = make_features(batch, 'structure1')
            structure1 = structure1.transpose(0, 1)
            structure1 = structure1.transpose(1, 2)

            structure2 = make_features(batch, 'structure2')
            structure2 = structure2.transpose(0, 1)
            structure2 = structure2.transpose(1, 2)

            structure3 = make_features(batch, 'structure3')
            structure3 = structure3.transpose(0, 1)
            structure3 = structure3.transpose(1, 2)

            structure4 = make_features(batch, 'structure4')
            structure4 = structure4.transpose(0, 1)
            structure4 = structure4.transpose(1, 2)

            structure5 = make_features(batch, 'structure5')
            structure5 = structure5.transpose(0, 1)
            structure5 = structure5.transpose(1, 2)

            # structure6 = make_features(batch, 'structure6')
            # structure6 = structure6.transpose(0, 1)
            # structure6 = structure6.transpose(1, 2)
            #
            # structure7 = make_features(batch, 'structure7')
            # structure7 = structure7.transpose(0, 1)
            # structure7 = structure7.transpose(1, 2)
            #
            # structure8 = make_features(batch, 'structure8')
            # structure8 = structure8.transpose(0, 1)
            # structure8 = structure8.transpose(1, 2)

            for j in range(0, target_size - 1, trunc_size):
                # 1. Create truncated target.
                tgt = tgt_outer[j:j + trunc_size]

                # 2. F-prop all but generator.
                if self.grad_accum_count == 1:
                    self.model.zero_grad()
                outputs, attns = self.model(src, tgt, structure1, structure2,
                                            structure3, structure4, structure5,
                                            src_lengths)

                # 3. Compute loss in shards for memory efficiency.
                batch_stats = self.train_loss.sharded_compute_loss(
                    batch, outputs, attns, j, trunc_size, self.shard_size,
                    normalization)
                total_stats.update(batch_stats)
                report_stats.update(batch_stats)

                # 4. Update the parameters and statistics.
                if self.grad_accum_count == 1:
                    # Multi GPU gradient gather
                    if self.n_gpu > 1:
                        grads = [
                            p.grad.data for p in self.model.parameters()
                            if p.requires_grad and p.grad is not None
                        ]
                        all_reduce_and_rescale_tensors(grads, float(1))
                    self.optim.step()

                # If truncated, don't backprop fully.
                # TO CHECK
                # if dec_state is not None:
                #    dec_state.detach()
                if self.model.decoder.state is not None:
                    self.model.decoder.detach_state()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [
                    p.grad.data for p in self.model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                all_reduce_and_rescale_tensors(grads, float(1))
            self.optim.step()
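A plain-PyTorch sketch of the accumulate-then-step pattern used above (self-contained toy model; in the real trainer the backward pass happens inside sharded_compute_loss):

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
batches = [(torch.randn(16, 8), torch.randn(16, 1)) for _ in range(8)]

accum = 4                                 # analogous to grad_accum_count
model.zero_grad()
for i, (x, y) in enumerate(batches):
    loss = loss_fn(model(x), y) / accum   # rescale so the effective step size matches
    loss.backward()                       # gradients sum over the accumulated batches
    if (i + 1) % accum == 0:
        optimizer.step()
        model.zero_grad()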
    def validate(self, valid_iter):
        """ Validate model.
        valid_iter: validate data iterator
    Returns:
        :obj:`nmt.Statistics`: validation loss statistics
    """
        # Set model in validating mode.
        self.model.eval()

        stats = Statistics()

        for batch in valid_iter:
            src = make_features(batch, 'src')
            _, src_lengths = batch.src

            tgt = make_features(batch, 'tgt')

            structure1 = make_features(batch, 'structure1')
            structure1 = structure1.transpose(0, 1)
            structure1 = structure1.transpose(1, 2)

            structure2 = make_features(batch, 'structure2')
            structure2 = structure2.transpose(0, 1)
            structure2 = structure2.transpose(1, 2)

            structure3 = make_features(batch, 'structure3')
            structure3 = structure3.transpose(0, 1)
            structure3 = structure3.transpose(1, 2)

            structure4 = make_features(batch, 'structure4')
            structure4 = structure4.transpose(0, 1)
            structure4 = structure4.transpose(1, 2)

            structure5 = make_features(batch, 'structure5')
            structure5 = structure5.transpose(0, 1)
            structure5 = structure5.transpose(1, 2)

            # structure6 = make_features(batch, 'structure6')
            # structure6 = structure6.transpose(0, 1)
            # structure6 = structure6.transpose(1, 2)
            #
            # structure7 = make_features(batch, 'structure7')
            # structure7 = structure7.transpose(0, 1)
            # structure7 = structure7.transpose(1, 2)
            #
            # structure8 = make_features(batch, 'structure8')
            # structure8 = structure8.transpose(0, 1)
            # structure8 = structure8.transpose(1, 2)

            # F-prop through the model.
            outputs, attns = self.model(src, tgt, structure1, structure2,
                                        structure3, structure4, structure5,
                                        src_lengths)

            # Compute loss.
            batch_stats = self.valid_loss.monolithic_compute_loss(
                batch, outputs, attns)

            # Update statistics.
            stats.update(batch_stats)

        # Set model back to training mode.
        self.model.train()

        return stats
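A standalone shape sketch of the two transposes applied to each structureN tensor above (the initial layout produced by make_features for these fields is an assumption, and the dimensions are made up):

import torch

t = torch.zeros(5, 2, 5, dtype=torch.long)   # assumed (src_len, batch_size, src_len)
t = t.transpose(0, 1)                        # -> (batch_size, src_len, src_len)
t = t.transpose(1, 2)                        # -> same shape here, last two axes swapped
print(t.shape)                               # torch.Size([2, 5, 5])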
Example #7
  def translate_batch(self, batch):
    def get_inst_idx_to_tensor_position_map(inst_idx_list):
      ''' Indicate the position of an instance in a tensor. '''
      return {inst_idx: tensor_position for tensor_position, inst_idx in enumerate(inst_idx_list)}
    
    def collect_active_part(beamed_tensor, curr_active_inst_idx, n_prev_active_inst, n_bm):
      ''' Collect tensor parts associated to active instances. '''

      _, *d_hs = beamed_tensor.size()
      n_curr_active_inst = len(curr_active_inst_idx)
      new_shape = (n_curr_active_inst * n_bm, *d_hs)

      beamed_tensor = beamed_tensor.view(n_prev_active_inst, -1)
      beamed_tensor = beamed_tensor.index_select(0, curr_active_inst_idx)
      beamed_tensor = beamed_tensor.view(*new_shape)

      return beamed_tensor
    
    def beam_decode_step(
      inst_dec_beams, len_dec_seq, inst_idx_to_position_map, n_bm):
      ''' Decode and update beam status, and then return active beam idx '''
      # len_dec_seq: i (starting from 0)

      def prepare_beam_dec_seq(inst_dec_beams):
        dec_seq = [b.get_last_target_word() for b in inst_dec_beams if not b.done]
        # dec_seq: [(beam_size)] * batch_size
        dec_seq = torch.stack(dec_seq).to(self.device)
        # dec_seq: (batch_size, beam_size)
        dec_seq = dec_seq.view(1, -1)
        # dec_seq: (1, batch_size * beam_size)
        return dec_seq

      def predict_word(dec_seq, n_active_inst, n_bm, len_dec_seq):
        # dec_seq: (1, batch_size * beam_size)
        dec_output, *_ = self.model.decoder(dec_seq, step=len_dec_seq)
        # dec_output: (1, batch_size * beam_size, hid_size)
        word_prob = self.model.generator(dec_output.squeeze(0))
        # word_prob: (batch_size * beam_size, vocab_size)
        
        word_prob = word_prob.view(n_active_inst, n_bm, -1)
        # word_prob: (batch_size, beam_size, vocab_size)

        return word_prob

      def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
        active_inst_idx_list = []
        select_indices_array = []
        for inst_idx, inst_position in inst_idx_to_position_map.items():
          is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
          if not is_inst_complete:
            active_inst_idx_list += [inst_idx]
            select_indices_array.append(inst_beams[inst_idx].get_current_origin() + inst_position * n_bm)
        if len(select_indices_array) > 0:
          select_indices = torch.cat(select_indices_array)
        else:
          select_indices = None
        return active_inst_idx_list, select_indices

      n_active_inst = len(inst_idx_to_position_map)

      dec_seq = prepare_beam_dec_seq(inst_dec_beams)
      # dec_seq: (1, batch_size * beam_size)
      word_prob = predict_word(dec_seq, n_active_inst, n_bm, len_dec_seq)

      # Update the beam with predicted word prob information and collect incomplete instances
      active_inst_idx_list, select_indices = collect_active_inst_idx_list(
        inst_dec_beams, word_prob, inst_idx_to_position_map)
      
      if select_indices is not None:
        assert len(active_inst_idx_list) > 0
        self.model.decoder.map_state(
            lambda state, dim: state.index_select(dim, select_indices))
      return active_inst_idx_list
    
    def collate_active_info(
        src_seq, src_enc, inst_idx_to_position_map, active_inst_idx_list):
      # Sentences which are still active are collected,
      # so the decoder will not run on completed sentences.
      n_prev_active_inst = len(inst_idx_to_position_map)
      active_inst_idx = [inst_idx_to_position_map[k] for k in active_inst_idx_list]
      active_inst_idx = torch.LongTensor(active_inst_idx).to(self.device)

      active_src_seq = collect_active_part(src_seq, active_inst_idx, n_prev_active_inst, n_bm)
      active_src_enc = collect_active_part(src_enc, active_inst_idx, n_prev_active_inst, n_bm)
      active_inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)

      return active_src_seq, active_src_enc, active_inst_idx_to_position_map

    def collect_best_hypothesis_and_score(inst_dec_beams):
      hyps, scores = [], []
      for inst_idx in range(len(inst_dec_beams)):
        hyp, score = inst_dec_beams[inst_idx].get_best_hypothesis()
        hyps.append(hyp)
        scores.append(score)
        
      return hyps, scores
    
    with torch.no_grad():
      #-- Encode
      src_seq = make_features(batch, 'src')
      # src: (seq_len_src, batch_size)
      src_emb, src_enc, _ = self.model.encoder(src_seq)
      # src_emb: (seq_len_src, batch_size, emb_size)
      # src_enc: (seq_len_src, batch_size, hid_size)

      self.model.decoder.init_state(src_seq, src_enc)
      src_len = src_seq.size(0)
      
      #-- Repeat data for beam search
      n_bm = self.beam_size
      n_inst = src_seq.size(1)
      self.model.decoder.map_state(lambda state, dim: tile(state, n_bm, dim=dim))
      # src_enc: (seq_len_src, batch_size * beam_size, hid_size)
      
      #-- Prepare beams
      decode_length = src_len + self.decode_extra_length
      decode_min_length = 0
      if self.decode_min_length >= 0:
        decode_min_length = src_len - self.decode_min_length
      if self.task_type == 'task':
        inst_dec_beams = [
          Beam(n_bm, decode_length=decode_length, minimal_length=decode_min_length,
               minimal_relative_prob=self.minimal_relative_prob, bos_id=self.tgt_bos_id,
               eos_id=self.tgt_eos_id, device=self.device)
          for _ in range(n_inst)]
      else:
        inst_dec_beams = [
          Beam(n_bm, decode_length=decode_length, minimal_length=decode_min_length,
               minimal_relative_prob=self.minimal_relative_prob, bos_id=self.tgt2_bos_id,
               eos_id=self.tgt2_eos_id, device=self.device)
          for _ in range(n_inst)]

      #-- Bookkeeping for active or not
      active_inst_idx_list = list(range(n_inst))
      inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
      
      #-- Decode
      for len_dec_seq in range(0, decode_length):
        active_inst_idx_list = beam_decode_step(
          inst_dec_beams, len_dec_seq, inst_idx_to_position_map, n_bm)
        
        if not active_inst_idx_list:
          break  # all instances have finished their path to <EOS>

        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
        
    batch_hyps, batch_scores = collect_best_hypothesis_and_score(inst_dec_beams)
    return batch_hyps, batch_scores
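A hedged call-site sketch for translate_batch (the translator object, test_iter, and tgt_vocab names are illustrative; hypotheses are assumed to be lists of target-vocabulary ids, as produced by get_best_hypothesis above):

for batch in test_iter:
    hyps, scores = translator.translate_batch(batch)
    for hyp, score in zip(hyps, scores):
        tokens = [tgt_vocab.itos[i] for i in hyp]
        print(' '.join(tokens), float(score))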
      
Example #8
    def reinforce_batch(self, batch):
        def get_inst_idx_to_tensor_position_map(inst_idx_list):
            ''' Indicate the position of an instance in a tensor. '''
            return {
                inst_idx: tensor_position
                for tensor_position, inst_idx in enumerate(inst_idx_list)
            }

        def reinforce_decode_step(len_dec_seq, inst_idx_to_position_map,
                                  dec_seq):
            ''' Run one decode step and return the word probabilities. '''
            def predict_word(dec_seq, n_active_inst, len_dec_seq):
                """

				:param dec_seq: 1*150
				:param n_active_inst:30
				:param n_bm: 5
				:param len_dec_seq:
				:return:
				"""
                # dec_seq: (1, batch_size * beam_size)
                dec_output, *_ = self.model.decoder(dec_seq, step=len_dec_seq)
                # dec_output: (1, batch_size * beam_size, hid_size)
                word_prob = self.model.generator(dec_output.squeeze(0))
                # word_prob: (batch_size * beam_size, vocab_size)
                # word_prob = word_prob.view(n_active_inst, -1)
                # word_prob: (batch_size, beam_size, vocab_size)

                return word_prob

            n_active_inst = len(inst_idx_to_position_map)  # 30

            # dec_seq = prepare_beam_dec_seq(inst_dec_beams)
            # dec_seq: (1, batch_size)
            # Predict the next word here.
            # word_prob: batch_size * 10
            word_prob = predict_word(dec_seq, n_active_inst, len_dec_seq)

            # Update the beam with predicted word prob information and collect incomplete instances
            # active_inst_idx_list, select_indices = collect_active_inst_idx_list(
            # 	inst_dec_beams, word_prob, inst_idx_to_position_map)

            # if select_indices is not None:
            # 	assert len(active_inst_idx_list) > 0
            # 	self.model.decoder.map_state(
            # 		lambda state, dim: state.index_select(dim, select_indices))

            return word_prob

        # with torch.no_grad():
        # -- Encode
        # src_seq:(batch_size,seq_len,dim)
        src_seq = make_features(batch, 'src')
        tgt = make_features(batch, 'tgt')
        src_seq = src_seq.transpose(0, 1).contiguous()
        # src: (seq_len_src, batch_size)
        src_emb, src_enc, _ = self.model.encoder(src_seq)
        # src_emb: (seq_len_src, batch_size, emb_size)
        # src_enc: (seq_len_src, batch_size, hid_size)
        self.model.decoder.init_state(src_seq, src_enc)
        src_len = src_seq.size(0)

        # -- Repeat data for beam search
        # n_bm = self.beam_size
        batch_size = src_seq.size(1)
        # (would tile src and src_enc to five times the batch size, i.e. 150)
        # self.model.decoder.map_state(lambda state, dim: tile(state, n_bm, dim=dim))
        # src_enc: (seq_len_src, batch_size * beam_size, hid_size)

        # -- Prepare beams
        decode_length = self.decode_length
        # decode_min_length = 0
        # if self.decode_min_length >= 0:
        # decode_min_length = src_len - self.decode_min_length
        # inst_dec_beams = [Beam(n_bm, decode_length=decode_length, minimal_length=decode_min_length,
        # 					   minimal_relative_prob=self.minimal_relative_prob, bos_id=self.tgt_bos_id,
        # 					   eos_id=self.tgt_eos_id, device=self.device) for _ in range(n_inst)]

        # -- Bookkeeping for active or not
        active_inst_idx_list = list(range(batch_size))  # [0, ..., batch_size - 1]
        # mapped into {0: 0, ..., idx: idx}
        inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(
            active_inst_idx_list)

        dec_seq_greedy = Variable(
            torch.LongTensor(1, batch_size).fill_(self.tgt_bos_id)).cuda()
        dec_seq_mul = Variable(
            torch.LongTensor(1, batch_size).fill_(self.tgt_bos_id)).cuda()
        # -- Decode
        # Unlike beam search, there is no early stop at <EOS>; decoding always runs for decode_length steps.
        # First pass: greedy decoding, used as the baseline.
        outputs, sample_ids = [], []
        for len_dec_seq in range(0, decode_length):
            output_prob = reinforce_decode_step(len_dec_seq,
                                                inst_idx_to_position_map,
                                                dec_seq_greedy)
            pred_id = output_prob.max(1)[1]
            sample_ids += [pred_id]
            outputs += [output_prob]
            dec_seq_greedy = pred_id.unsqueeze(0)
        # Second pass: multinomial sampling.
        sample_ids = torch.stack(sample_ids).squeeze()
        outputs_mul, probs_mul = [], []
        for len_dec_seq in range(0, decode_length):
            output_prob = reinforce_decode_step(len_dec_seq,
                                                inst_idx_to_position_map,
                                                dec_seq_mul)
            predicted = F.softmax(output_prob, 1).multinomial(1)
            one_hot = Variable(torch.zeros(output_prob.size())).cuda()
            one_hot.scatter_(1, predicted.long(), 1)
            prob = torch.masked_select(F.log_softmax(output_prob, 1),
                                       one_hot.type(torch.ByteTensor).cuda())
            probs_mul += [prob]
            outputs_mul += [predicted]
            dec_seq_mul = predicted.transpose(0, 1)

        probs_mul = torch.stack(probs_mul).squeeze()
        outputs_mul = torch.stack(
            outputs_mul).squeeze()  # [max_tgt_len, batch]

        return sample_ids, outputs_mul, probs_mul, tgt
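A hedged sketch of how the returned tensors could feed a self-critical REINFORCE loss. token_accuracy is a hypothetical placeholder reward (a real system would use BLEU/ROUGE), trainer and batch are assumed to exist, and only the shapes noted in the code above are relied on:

import torch

def token_accuracy(pred_ids, tgt_ids):
    # Placeholder reward: fraction of matching positions per sentence.
    # pred_ids, tgt_ids: (seq_len, batch_size)
    length = min(pred_ids.size(0), tgt_ids.size(0))
    return (pred_ids[:length] == tgt_ids[:length]).float().mean(dim=0)   # (batch_size,)

greedy_ids, sampled_ids, sampled_logp, tgt = trainer.reinforce_batch(batch)
tgt_ids = tgt.squeeze(-1) if tgt.dim() == 3 else tgt    # make_features may add a feature axis
reward_sampled = token_accuracy(sampled_ids, tgt_ids)
reward_greedy = token_accuracy(greedy_ids, tgt_ids)     # greedy decode serves as the baseline
advantage = (reward_sampled - reward_greedy).detach()
loss = -(advantage * sampled_logp.sum(dim=0)).mean()    # REINFORCE: maximize expected reward
loss.backward()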
Example #9
    def _gradient_accumulation(self,
                               true_batchs,
                               normalization,
                               total_stats,
                               report_stats,
                               ratio=1.):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:

            # dec_state = None
            src = make_features(batch, 'src')
            _, src_lengths = batch.src

            tgt = make_features(batch, 'tgt')

            # reconstructor input
            stgt = make_features(batch, 'stgt')
            stgt = stgt.transpose(0, 1)
            stgt = stgt.transpose(1, 2)
            # If a child is chosen at random, the generated sequence differs from the depth-first traversal order.
            choice = random.randint(0, stgt.size(0) - 1)
            stgt = stgt[choice][:-1]

            structure = make_features(batch, 'structure')
            structure = structure.transpose(0, 1)
            structure = structure.transpose(1, 2)

            # 2. F-prop all but generator.
            if self.grad_accum_count == 1:
                self.model.zero_grad()
            outputs, attns, s_outputs, s_attns = \
                self.model(src, tgt, stgt, structure, src_lengths)

            # 3. Compute loss in shards for memory efficiency.
            batch_stats = self.train_loss.sharded_compute_loss(
                batch, (outputs, s_outputs),
                stgt,
                self.shard_size,
                normalization,
                ratio=ratio)
            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [
                        p.grad.data for p in self.model.parameters()
                        if p.requires_grad and p.grad is not None
                    ]
                    all_reduce_and_rescale_tensors(grads, float(1))
                self.optim.step()

            # If truncated, don't backprop fully.
            # TO CHECK
            # if dec_state is not None:
            #    dec_state.detach()
            if self.model.decoder.state is not None:
                self.model.decoder.detach_state()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [
                    p.grad.data for p in self.model.parameters()
                    if p.requires_grad and p.grad is not None
                ]
                all_reduce_and_rescale_tensors(grads, float(1))
            self.optim.step()
Example #10
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats, ratio=0.15, ratio2=0.05):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            target_size = batch.tgt.size(0)

            # dec_state = None
            src = make_features(batch, 'src')
            _, src_lengths = batch.src

            tgt = make_features(batch, 'tgt')
            # reconstructor input
            stgt = make_features(batch, 'stgt')
            stgt = stgt.transpose(0, 1)
            stgt = stgt.transpose(1, 2)
            # randomly chosen traversal sequence
            choice = random.randint(0, stgt.size(0) - 1)
            stgt = stgt[choice][:-1]

            structure = make_features(batch, 'structure')
            structure = structure.transpose(0, 1)
            structure = structure.transpose(1, 2)
            # bad code
            mask = make_features(batch, 'mask')
            mask = mask - 2
            mask[mask <= 0] = 0
            mask = mask.byte()
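            # (Arithmetic effect, for the record: original values <= 2 become 0, i.e. False
            # in the byte mask, and values > 2 become nonzero, i.e. True; which token ids
            # those correspond to depends on the 'mask' field's vocabulary.)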

            relation = make_features(batch, 'relation')
            relation = relation.transpose(0, 1)
            relation = relation[relation != 1]
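            # (Entries equal to 1 are dropped; 1 is presumably the <blank>/padding id,
            # per the vocabulary note in the earlier example.)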

            # 2. F-prop all but generator.
            if self.grad_accum_count == 1:
                self.model.zero_grad()
            outputs, attns, s_outputs, s_attns, p, rels = \
                self.model(src, tgt, stgt, structure, mask, src_lengths)

            # 3. Compute loss in shards for memory efficiency.
            batch_stats = self.train_loss.sharded_compute_loss(
                batch, (outputs, s_outputs), stgt, self.shard_size, normalization, ratio1=1-ratio,
                ratio2=ratio)
            if relation.size(0)>0:
                relation_loss = self.train_relation_loss(rels, relation)
                loss = (-p + relation_loss) / relation.size(0)
                loss = loss * ratio2
                loss.backward()
            total_stats.update(batch_stats)
            report_stats.update(batch_stats)

            # 4. Update the parameters and statistics.
            if self.grad_accum_count == 1:
                # Multi GPU gradient gather
                if self.n_gpu > 1:
                    grads = [p.grad.data for p in self.model.parameters()
                             if p.requires_grad
                             and p.grad is not None]
                    all_reduce_and_rescale_tensors(
                        grads, float(1))
                self.optim.step()

            # If truncated, don't backprop fully.
            # TO CHECK
            # if dec_state is not None:
            #    dec_state.detach()
            if self.model.decoder.state is not None:
                self.model.decoder.detach_state()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [p.grad.data for p in self.model.parameters()
                         if p.requires_grad
                         and p.grad is not None]
                all_reduce_and_rescale_tensors(
                    grads, float(1))
            self.optim.step()
Example #11
  def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                             report_stats):

      if self.grad_accum_count > 1:
          self.model.zero_grad()

      for batch in true_batchs:
          target_size = batch.tgt.size(0)
          # Truncated BPTT: reminder, this is not compatible with accum > 1
          if self.trunc_size:
              trunc_size = self.trunc_size
          else:
              trunc_size = target_size

          # dec_state = None
          src = make_features(batch, 'src')  # shape: 32 * 113
          src = src.transpose(0, 1).contiguous()
          # _, src_lengths = batch.src
          # Here the true length information is ignored and a fixed length is used instead.
          src_lengths = (torch.ones(batch.batch_size) * src.size(1)).long()
          tgt_outer = make_features(batch, 'tgt')
          for j in range(0, target_size-1, trunc_size):
              # 1. Create truncated target.
              tgt = tgt_outer[j: j + trunc_size]

              # 2. F-prop all but generator.
              if self.grad_accum_count == 1:
                  self.model.zero_grad()
              outputs, attns = \
                  self.model(src, tgt, src_lengths)
              # 3. Compute loss in shards for memory efficiency.
              batch_stats = self.train_loss.sharded_compute_loss(
                  batch, outputs, attns, j,
                  trunc_size, self.shard_size, normalization)
              total_stats.update(batch_stats)
              report_stats.update(batch_stats)

              # 4. Update the parameters and statistics.
              if self.grad_accum_count == 1:
                  # Multi GPU gradient gather
                  if self.n_gpu > 1:
                      grads = [p.grad.data for p in self.model.parameters()
                               if p.requires_grad
                               and p.grad is not None]
                      all_reduce_and_rescale_tensors(
                          grads, float(1))
                  self.optim.step()

              # If truncated, don't backprop fully.
              # TO CHECK
              # if dec_state is not None:
                #  dec_state.detach()
              if self.model.decoder.state is not None:
                  self.model.decoder.detach_state()

      # in case of multi step gradient accumulation,
      # update only after accum batches
      if self.grad_accum_count > 1:
          if self.n_gpu > 1:
              grads = [p.grad.data for p in self.model.parameters()
                       if p.requires_grad
                       and p.grad is not None]
              all_reduce_and_rescale_tensors(
                  grads, float(1))
          self.optim.step()
    def _gradient_accumulation(self, true_batchs, normalization, total_stats,
                               report_stats, ratio):
        if self.grad_accum_count > 1:
            self.model.zero_grad()

        for batch in true_batchs:
            target_size = batch.tgt.size(0)
            # Truncated BPTT: reminder, this is not compatible with accum > 1
            if self.trunc_size:
                trunc_size = self.trunc_size
            else:
                trunc_size = target_size

            # dec_state = None
            src = make_features(batch, 'src')
            _, src_lengths = batch.src

            tgt_outer = make_features(batch, 'tgt')

            structure = make_features(batch, 'structure')
            structure = structure.transpose(0, 1)
            structure = structure.transpose(1, 2)

            # bad code
            mask = make_features(batch, 'mask')
            mask = mask - 2
            mask[mask <= 0] = 0
            mask = mask.byte()

            # ground truth label of biaffine relation
            relation = make_features(batch, 'relation')
            relation = relation.transpose(0, 1)
            relation = relation[relation != 1]

            for j in range(0, target_size - 1, trunc_size):
                # 1. Create truncated target.
                tgt = tgt_outer[j: j + trunc_size]

                # 2. F-prop all but generator.
                if self.grad_accum_count == 1:
                    self.model.zero_grad()

                outputs, attns, p, rels = \
                    self.model(src, tgt, structure, mask, src_lengths)

                # 3. Compute loss in shards for memory efficiency.
                batch_stats = self.train_loss.sharded_compute_loss(
                    batch, outputs, attns, j,
                    trunc_size, self.shard_size, normalization, 1.)

                if relation.size(0)>0:
                    # compute loss for label prediction
                    relation_loss = self.train_relation_loss(rels, relation)
                    # total loss of biaffine module
                    loss = (-p + relation_loss) / relation.size(0)
                    loss = loss * ratio
                    loss.backward()

                total_stats.update(batch_stats)
                report_stats.update(batch_stats)

                # 4. Update the parameters and statistics.
                if self.grad_accum_count == 1:
                    # Multi GPU gradient gather
                    if self.n_gpu > 1:
                        grads = [p.grad.data for p in self.model.parameters()
                                 if p.requires_grad
                                 and p.grad is not None]
                        all_reduce_and_rescale_tensors(
                            grads, float(1))
                    self.optim.step()

                # If truncated, don't backprop fully.
                # TO CHECK
                # if dec_state is not None:
                #    dec_state.detach()
                if self.model.decoder.state is not None:
                    self.model.decoder.detach_state()

        # in case of multi step gradient accumulation,
        # update only after accum batches
        if self.grad_accum_count > 1:
            if self.n_gpu > 1:
                grads = [p.grad.data for p in self.model.parameters()
                         if p.requires_grad
                         and p.grad is not None]
                all_reduce_and_rescale_tensors(
                    grads, float(1))
            self.optim.step()
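For reference, a hedged sketch of what the multi-GPU gradient gather above amounts to when the rescale denominator is 1, written directly against torch.distributed (process-group initialization is assumed to have happened elsewhere; the real all_reduce_and_rescale_tensors helper also buckets tensors for efficiency):

import torch
import torch.distributed as dist

def all_reduce_grads(model: torch.nn.Module) -> None:
    # Sum each parameter's gradient across workers.
    for p in model.parameters():
        if p.requires_grad and p.grad is not None:
            dist.all_reduce(p.grad.data, op=dist.ReduceOp.SUM)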