Example #1
def decode_sequence(model,
                    char_dictionary,
                    max_len,
                    src_seq,
                    src_mask,
                    src_len,
                    pad=PAD_ID_CHAR,
                    target_seq_gold=None,
                    input_word=None,
                    use_gpu=False,
                    showing_attention=False,
                    single_sequence=False,
                    eval_time=True,
                    verbose=2,
                    timing=False):

    # eval_time is always True for now
    printing("EVAL TIME is {}",
             var=eval_time,
             verbose=verbose,
             verbose_level=2)
    output_seq = pad * np.ones(src_seq.size(), dtype=np.int64)
    # we start with the _START symbol

    output_seq[:, :, 0] = src_seq[:, :, 0]  #CHAR_START_ID
    src_text_ls = []
    target_seq_gold_ls = [] if target_seq_gold is not None else None
    output_mask = np.ones(src_mask.size(), dtype=np.int64)
    output_mask[:, :, 1:] = 0
    output_len = Variable(torch.from_numpy(
        np.ones((src_seq.size(0), src_seq.size(1), 1), dtype=np.int64)),
                          requires_grad=False)
    output_mask = Variable(torch.from_numpy(output_mask), requires_grad=False)
    output_seq = Variable(torch.from_numpy(output_seq), requires_grad=False)
    printing("Data Start source {} {} ",
             var=(src_seq, src_seq.size()),
             verbose=verbose,
             verbose_level=5)
    output_str = True
    printing("WARNING : output_str = True hardcoded (decode_sequence)",
             verbose=verbose,
             verbose_level=2)
    printing("Data output sizes ",
             var=(output_seq.size(), output_mask.size()),
             verbose=verbose,
             verbose_level=6)
    start_decode_sequence = time.time() if timing else None

    for step, char_decode in enumerate(range(2, max_len)):
        if use_gpu:
            src_seq = src_seq.cuda()
            output_seq = output_seq.cuda()
            src_len = src_len.cuda()
            output_len = output_len.cuda()
        start = time.time() if timing else None
        output_len = (src_len[:, :, 0] != 0).unsqueeze(dim=2) * (char_decode -
                                                                 1)
        printing("DECODER step {} output len {} ",
                 var=(step, output_len),
                 verbose=verbose,
                 verbose_level=4)

        decoding_states, word_pred, pos_pred, norm_not_norm, edit_pred, attention, _ \
            = model.forward(input_seq=src_seq,
                            output_seq=output_seq,
                            input_word_len=src_len,
                            output_word_len=output_len,
                            word_embed_input=input_word)

        time_forward, start = get_timing(start)
        # [batch, seq_len, V]

        pred_norm_not_norm = norm_not_norm.argmax(
            dim=-1) if norm_not_norm is not None else None
        scores = model.generator.forward(x=decoding_states)

        time_generate, start = get_timing(start)
        # at each time step we predict the most likely next character
        # len : output_len is defined based on src_len to remove empty words
        # output_len[:] = char_decode  # before debugging
        # mask
        output_mask = np.ones(src_seq.size(), dtype=np.int64)
        output_mask[:, char_decode:] = 0
        # new seq
        predictions = scores.argmax(dim=-1)

        time_argmax_printing, start = get_timing(start)
        if verbose >= 4:
            # .size() takes some time
            printing("Prediction size {} ",
                     var=(predictions.size()),
                     verbose=verbose,
                     verbose_level=0)
            printing("SCORES {} ",
                     var=[str(scores)],
                     verbose=verbose,
                     verbose_level=0)
            printing("Prediction {} ",
                     var=[predictions],
                     verbose=verbose,
                     verbose_level=0)

            printing(
                "scores: (1st batch)  {} scores sized  {} \n predicion size {} prediction {} ",
                var=[
                    scores[0, :, :, :],
                    scores.size(),
                    predictions.size(),
                    predictions[0, :, -1],
                ],
                verbose=verbose,
                verbose_level=0)
        time_printing, start = get_timing(start)

        output_seq = output_seq[:, :scores.size(1), :]
        time_output_seq, start = get_timing(start)
        if pred_norm_not_norm is not None:
            pred_norm_not_norm = pred_norm_not_norm[:, :scores.size(1)]  # following what's done above

        output_seq[:, :, char_decode - 1] = predictions[:, :, -1]

        if verbose >= 4:
            sequence = [
                " ".join([
                    char_dictionary.get_instance(output_seq[sent, word_ind,
                                                            char_i])
                    for char_i in range(max_len)
                ]) + "|sent-{}|".format(sent)
                for sent in range(output_seq.size(0))
                for word_ind in range(output_seq.size(1))
            ]
        else:
            sequence = []

        printing("Decoding step {} decoded target {} ",
                 var=(step, sequence),
                 verbose=verbose,
                 verbose_level=4)
        time_sequence_text, start = get_timing(start)
        printing("DECODING scores {}",
                 var=[scores[0]],
                 verbose=verbose,
                 verbose_level=4)
        printing("DECODING decoding_states {}",
                 var=[decoding_states[0]],
                 verbose=verbose,
                 verbose_level=4)

        if eval_time:
            # at test time : if all predictions in the batch are either the PAD symbol or the END symbol, we break
            if ((predictions[:, :, -1] == PAD_ID_CHAR) +
                (predictions[:, :, -1] == CHAR_END_ID)).all():
                printing(
                    "PREDICTION IS ONLY PAD or END SYMBOL SO BREAKING DECODING",
                    verbose=verbose,
                    verbose_level=1)
                break
    # no need to do that in the loop
    print(
        "WARNING : shfited output sequence of one character not to output START token"
    )
    pred_word_count, text_decoded, decoded_ls = output_text_(
        output_seq[:, :, 1:],
        char_dictionary,
        single_sequence=single_sequence,
        output_str=output_str,
        output_len=output_len,
        last=(char_decode == (max_len - 1)),
        showing_attention=showing_attention,
        debug=False)

    time_output_text, start = get_timing(start)

    time_decoding_all_seq, start = get_timing(start_decode_sequence)
    printing("PREDICTION : array text {} ",
             var=[text_decoded],
             verbose=verbose,
             verbose_level=0)
    src_word_count, src_text, src_all_ls = output_text_(
        src_seq,
        char_dictionary,
        single_sequence=single_sequence,
        showing_attention=showing_attention,
        output_str=output_str)
    printing("SOURCE  : array text {} ",
             var=[src_text],
             verbose=verbose,
             verbose_level=0)
    src_text_ls.extend(src_text)
    if target_seq_gold is not None:
        target_word_count, target_text, _ = output_text_(
            target_seq_gold,
            char_dictionary,
            showing_attention=showing_attention,
            single_sequence=single_sequence,
            output_str=output_str)
        target_seq_gold_ls.extend(target_text)
        printing("GOLD : array text {} ",
                 var=[target_text],
                 verbose=verbose,
                 verbose_level=0)
    else:
        target_word_count = None
    if single_sequence:
        if model.decoder.attn_layer is not None:
            attention = attention[0]
        if pred_norm_not_norm is not None:
            pred_norm_not_norm = pred_norm_not_norm[0]
    if timing:
        print("DECODING TIME : {}".format(
            OrderedDict([("time_decoding_all_seq", time_decoding_all_seq),
                         ("time_forward", time_forward),
                         ("time_generate", time_generate),
                         ("time_argmax_printing", time_argmax_printing),
                         ("time_printing", time_printing),
                         ("time_output_seq", time_output_seq),
                         ("time_sequence_text", time_sequence_text),
                         ("time_output_text", time_output_text),
                         ("time_decoding_all_seq", time_decoding_all_seq)])))

    return (text_decoded, src_text_ls, target_seq_gold_ls, None), \
           {
           "src_word_count": src_word_count,
           "target_word_count": target_word_count,
           "pred_word_count": pred_word_count},\
           (attention, src_all_ls,), \
           (pred_norm_not_norm, output_seq, src_seq, target_seq_gold)
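
The decoding loop above follows a standard greedy pattern: write the argmax of the generator scores into the output buffer one character at a time, and break as soon as every sequence in the batch has produced only PAD or END. Below is a minimal, self-contained sketch of that pattern; score_fn, PAD_ID, END_ID, START_ID and the dimensions are illustrative placeholders, not the project's real model or constants.

import torch

PAD_ID, END_ID, START_ID = 0, 1, 2
VOCAB = 10

def score_fn(prefix):
    # stand-in for model.forward + generator: returns [batch, prefix_len, vocab] scores
    return torch.randn(prefix.size(0), prefix.size(1), VOCAB)

def greedy_decode(batch_size=3, max_len=8):
    output = torch.full((batch_size, max_len), PAD_ID, dtype=torch.long)
    output[:, 0] = START_ID                      # start symbol, as in decode_sequence
    for step in range(1, max_len):
        scores = score_fn(output[:, :step])
        pred = scores[:, -1, :].argmax(dim=-1)   # most likely next character
        output[:, step] = pred
        if ((pred == PAD_ID) | (pred == END_ID)).all():
            break                                # every sequence is finished: stop early
    return output

print(greedy_decode())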
Example #2
    def forward(self, output, conditioning, output_word_len,
                char_seq_hidden_encoder=None,
                word_src_sizes=None, proportion_pred_train=None,
                sent_len_max_source=None, verbose=0):

        start = time.time() if self.timing else None
        _output_word_len = output_word_len.clone()
        clone_len, start = get_timing(start)
        # handle sentences that take the whole sequence
        printing("TARGET SIZE : output_word_len length (before 0 last) : size {} data {} ", var=(_output_word_len.size(),_output_word_len), verbose=verbose,
                 verbose_level=4)
        printing("TARGET : output  (before 0 last) : size {}", var=[output.size()], verbose=verbose, verbose_level=3)
        printing("TARGET : output  (before 0 last) :  data {} ", var=[output], verbose=verbose, verbose_level=5)
        _output_word_len[:, -1, :] = 0
        # when input_word_len is 0 it means we reached the end of the sentence
        # TODO : WARNING : is +1 required (as for sentences of length 1)? WHY IS IT NOT ALWAYS WORKING
        sent_len = torch.Tensor(np.argmin(np.array(_output_word_len), axis=1)).long()  ## PYTORCH 1.0 (or 0.4)
        if _output_word_len.is_cuda:
            sent_len = sent_len.cuda()
        #sent_len = torch.argmin(_output_word_len, dim=1) ## PYTORCH WARNING : THERE MIGHT BE A PROBLEM HERE
        # WARNING : forcing sent_len to be one
        if (sent_len == 0).any() and False:
            printing("WARNING : WE ARE FORCING SENT_LEN in the SOURCE SIDE", verbose=verbose, verbose_level=3)
            sent_len[sent_len == 0] += 1
        # as on the encoder side : we handle words that take the whole sequence
        sent_len += (output_word_len[:, -1, :] != 0).long()
        # sort the batch by sentence length
        sent_len, perm_idx_input_sent = sent_len.squeeze().sort(0, descending=True)
        argmin_squeeze, start = get_timing(start)
        inverse_perm_idx_input_sent = torch.from_numpy(np.argsort(perm_idx_input_sent.cpu().numpy()))
        sorting, start = get_timing(start)
        # [batch x sent_len, dim hidden word level] # this removes empty words
        # reorder so that it aligns with the input

        try:
            packed_char_vecs_output = pack_padded_sequence(output[perm_idx_input_sent, :, :],
                                                           sent_len.cpu().numpy(), batch_first=True)
        except Exception:
            print("EXCEPT DECODER PACKING", [perm_idx_input_sent])
            if len(perm_idx_input_sent.size()) == 0:
                perm_idx_input_sent = [perm_idx_input_sent]
                inverse_perm_idx_input_sent = [inverse_perm_idx_input_sent]
                sent_len = sent_len.view(-1)

            packed_char_vecs_output = pack_padded_sequence(output[perm_idx_input_sent, :, :],
                                                           sent_len.cpu().numpy(), batch_first=True)
        # unpacked for computing the word level representation
        #packed_char_vecs_output = pack_padded_sequence(output[perm_idx_input_sent, :, :],
        #                                               sent_len.squeeze().cpu().numpy(), batch_first=True)
        conditioning = conditioning[perm_idx_input_sent, :, :]
        packed_sent, start = get_timing(start)
        # unpacked for the word level representation
        # packed_char_vecs_output .data : [batch x shorted sent_lenS , word lens ] + .batch_sizes

        output_char_vecs, output_sizes = pad_packed_sequence(packed_char_vecs_output, batch_first=True,
                                                             padding_value=PAD_ID_WORD) # padding_value
        padd_sent, start = get_timing(start)

        # output_char_vecs : [batch ,  shorted sent_len, word len ] + .batch_sizes
        # output_char_vecs : [batch, sent_len max, dim encoder] reorder the sequence
        #output_char_vecs = output_char_vecs[inverse_perm_idx_input_sent, :, :]
        # reorder sent_len also
        #sent_len = sent_len[inverse_perm_idx_input_sent]
        # cut output_word_len so that it fits the packed padded sequence (based on the output sequence)
        output_word_len = output_word_len[:, :output_char_vecs.size(1), :]
        # cut again (should be done in one step I guess) to fit source sequence (important at test time)
        output_word_len = output_word_len[:, :sent_len_max_source, :]
        # we cut output_char_vecs to the source max sentence length (sent_len_max_source)
        output_char_vecs = output_char_vecs[:, :sent_len_max_source, :]
        output_seq = output_char_vecs.contiguous().view(output_char_vecs.size(0) * output_char_vecs.size(1), output_char_vecs.size(2))
        reshape_sent, start = get_timing(start)
        # output_seq : [ batch x max sent len, max word len  ]
        output_word_len = output_word_len.contiguous()
        output_word_len = output_word_len.view(output_word_len.size(0) * output_word_len.size(1))
        reshape_len, start = get_timing(start)
        printing("TARGET output before word encoder {}", var=[output_seq.size()], verbose=verbose, verbose_level=3)
        output_w_decoder, attention_weight_all = self.word_encoder_target(output_seq, conditioning, output_word_len,
                                                                          word_src_sizes=word_src_sizes,
                                                                          proportion_pred_train=proportion_pred_train,
                                                                          char_seq_hidden_encoder=char_seq_hidden_encoder)

        # output_w_decoder
        word_encoders, start = get_timing(start)
        # we update sent len based on how it was cut (specifically useful at test time)
        sent_len = torch.min(torch.ones_like(sent_len) * sent_len_max_source, sent_len)
        sent_len_cumulated = get_cumulated_list(sent_len)
        output_w_decoder_ls = [output_w_decoder[sent_len_cumulated[i]:sent_len_cumulated[i + 1]] for i in range(len(sent_len_cumulated) - 1)]
        output_w_decoder = pack_sequence(output_w_decoder_ls)
        output_w_decoder, _ = pad_packed_sequence(output_w_decoder, batch_first=True)
        output_w_decoder = output_w_decoder[inverse_perm_idx_input_sent,:,:]
        # output_w_decoder : [ n_sents  x max sent len, max word len , hidden_size_decoder ]
        #max_word = output_w_decoder.size(0)/output_char_vecs.size(0)
        #output_w_decoder = output_w_decoder.view(output_char_vecs.size(0), max_word, -1, output_w_decoder.size(2))
        if self.attn_layer is not None:
            attention_weight_all = torch.cat(attention_weight_all, dim=1)
            attention_weight_all_ls = [attention_weight_all[sent_len_cumulated[i]:sent_len_cumulated[i + 1]] for i in range(len(sent_len_cumulated) - 1)]
            attention_weight_all = pack_sequence(attention_weight_all_ls)
            attention_weight_all, _ = pad_packed_sequence(attention_weight_all, batch_first=True)
            attention_weight_all = attention_weight_all[inverse_perm_idx_input_sent]
            #attention_weight_all = torch.cat(attention_weight_all, dim=1)
            #attention_weight_all = attention_weight_all.view(output_char_vecs.size(0), max_word, attention_weight_all.size(1),attention_weight_all.size(2))
        else:
            attention_weight_all = None
        reshape_attention, start = get_timing(start)
        # output_w_decoder : [ batch , max sent len, max word len , hidden_size_decoder ]
        if self.timing:
            print("SENT TARGET : {}".format(OrderedDict([("clone_len", clone_len), ("argmin_squeeze", argmin_squeeze),("sorting", sorting),
                                                         ("packed_sent", packed_sent), ("padd_sent",padd_sent), ("reshape_sent",reshape_sent),
                                                         ("reshape_len",reshape_len),("word_encoders", word_encoders), ("reshape_attention",reshape_attention)])))
        return output_w_decoder, attention_weight_all
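
The method above relies on the usual pack_padded_sequence workflow: sort the batch by decreasing length, keep the permutation, pack, run the recurrent cell, unpad, and restore the original order with the inverse permutation. A minimal sketch of that workflow with toy tensors (the GRU and all dimensions are illustrative, not the project's modules):

import numpy as np
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# toy batch: 3 padded sequences of feature dim 4, true lengths 2, 5 and 3
lengths = torch.tensor([2, 5, 3])
batch = torch.randn(3, 5, 4)

# sort by decreasing length (required by pack_padded_sequence in older PyTorch)
sorted_len, perm_idx = lengths.sort(0, descending=True)
inverse_perm_idx = torch.from_numpy(np.argsort(perm_idx.numpy()))

packed = pack_padded_sequence(batch[perm_idx], sorted_len.cpu().numpy(), batch_first=True)
rnn = torch.nn.GRU(input_size=4, hidden_size=6, batch_first=True)
packed_out, _ = rnn(packed)

# unpack and restore the original batch order, as the decoder does with inverse_perm_idx_input_sent
out, out_sizes = pad_packed_sequence(packed_out, batch_first=True)
out = out[inverse_perm_idx]
print(out.shape, out_sizes)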
Example #3
    def word_encoder_target_step(self, char_vec_current_batch,  state_decoder_current,
                                 char_vecs_sizes, step_char,
                                 word_stable_context,
                                 char_seq_hidden_encoder=None):
        # char_vec_current_batch is the new input character read, state_decoder_current
        # is the state of the cell (h and possibly the cell state)
        # we torch.cat() char_vec_current_batch with an attention-based context computed over char_seq_hidden_encoder

        # parentheses needed so that the conditional applies to the whole (h, c) tuple
        state_hiden, state_cell = (state_decoder_current[0], state_decoder_current[1]) if isinstance(self.seq_decoder, nn.LSTM) else (state_decoder_current, None)
        printing("DECODER STEP : target char_vec_current_batch {} size and state_decoder_current {} and {} size",
                 var=[char_vec_current_batch.size(), state_hiden.size(), state_cell.size() if state_cell is not None else None],
                 verbose_level=3, verbose=self.verbose)

        printing("DECODER STEP : context  char_vec_current_batch {}, state_decoder_current {} ",
                 var=[char_vec_current_batch.size(), state_hiden.size()], verbose_level=3,
                 verbose=self.verbose)
        # attention weights computed based on character embedding + state of the decoder recurrent state
        #current_state = torch.cat((char_vec_current_batch, state_hiden.squeeze(0)), dim=1)
        char_vec_current_batch = char_vec_current_batch.unsqueeze(1)
        # current_state : for each word (in sentence batches) 1 character local target context
        # (char embedding + previous recurrent state of the decoder)
        # current_state : dim [batch x sentence max len, char embedding + hidden_dim decoder]
        start_atten = time.time()

        if self.attn_layer is not None:
            # TODO : make it generic (is there no problem also if no attention ?) (bug fix)
            # we align what we decode
            state_hiden = state_hiden[:, :char_seq_hidden_encoder.size(0), :]
            attention_weights = self.attn_layer(char_state_decoder=state_hiden.squeeze(0),
                                                char_embedding_current=char_vec_current_batch,
                                                word_src_sizes=char_vecs_sizes, encoder_outputs=char_seq_hidden_encoder)
            printing("DECODER STEP : attention context {} char_seq_hidden_encoder {} ", var=[attention_weights.size(), char_seq_hidden_encoder.size()],
                     verbose_level=3, verbose=self.verbose)

            # we multiply for each batch attention matrix by our source sequence
            if char_seq_hidden_encoder.is_cuda:
                # don't know why we need to do that 
                attention_weights = attention_weights.cuda()
            # TODO HOW IS MASKING TAKEN CARE OF IN THE TARGET ? WE PACKED AND PADDED SO WE SHORTENED THE SEQUENCE
            attention_context = attention_weights.bmm(char_seq_hidden_encoder)
            # was context
        else:
            attention_context = None
            attention_weights = None
        if self.stable_decoding_state:
            word_stable_context = word_stable_context.transpose(1, 0)
        else:
            word_stable_context = None
        if attention_context is not None or word_stable_context is not None:
            if attention_context is None:
                context = word_stable_context
            elif word_stable_context is None:
                context = attention_context
            else:
                context = torch.cat((word_stable_context, attention_context), dim=2)
            context = self.context_proj(context)

            char_vec_current_batch = torch.cat((context, char_vec_current_batch), dim=2)

        else:
            # no word level context passed, so char_vec_current_batch is only the current character vector
            pass

            # compute product attention_weights with  char_seq_hidden_encoder (updated for each character)
            # this provides a new character context that we concatenate
            #  with char_vec_current_batch + possibly conditioning_other,
            #  as done in
            ##https://github.com/spro/practical-pytorch/blob/master/seq2seq-translation/seq2seq-translation-batched.ipynb
            # the context goes as input along with the character embedding : we add the traditional conditioning_other
        time_atten, start = get_timing(start_atten)

        output, state = self.seq_decoder(char_vec_current_batch, state_decoder_current)

        time_step_decoder, _ = get_timing(start)
        if self.timing:
            print("Attention time {} ".format(OrderedDict([('time_attention', time_atten),
                                                           ("time_step_decoder", time_step_decoder)])))
        # output and state are equal because we call the GRU step by step (not on the entire sequence)
        return output, state, attention_weights
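
word_encoder_target_step combines the current character embedding with an attention context obtained by a batched matrix product (attention_weights.bmm(char_seq_hidden_encoder)). The sketch below reproduces that shape logic with a simple dot-product scorer standing in for the project's attn_layer; all names and dimensions are illustrative assumptions.

import torch
import torch.nn.functional as F

batch, src_len, enc_dim, emb_dim = 4, 7, 16, 8

encoder_outputs = torch.randn(batch, src_len, enc_dim)   # plays the role of char_seq_hidden_encoder
decoder_state = torch.randn(batch, enc_dim)              # plays the role of state_hiden.squeeze(0)
char_embedding = torch.randn(batch, 1, emb_dim)          # plays the role of char_vec_current_batch

# dot-product scores between the decoder state and each encoder position
scores = torch.bmm(encoder_outputs, decoder_state.unsqueeze(2)).squeeze(2)
attention_weights = F.softmax(scores, dim=-1).unsqueeze(1)    # [batch, 1, src_len]

# same batched product as attention_weights.bmm(char_seq_hidden_encoder) above
attention_context = attention_weights.bmm(encoder_outputs)    # [batch, 1, enc_dim]

# the context is concatenated with the current character embedding before the recurrent step
rnn_input = torch.cat((attention_context, char_embedding), dim=2)
print(rnn_input.shape)                                        # [batch, 1, enc_dim + emb_dim]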
Example #4
    def word_encoder_target(self, output, conditioning, output_word_len,
                            word_src_sizes=None,
                            proportion_pred_train=None,
                            char_seq_hidden_encoder=None):
        # TODO DEAL WITH MASKING (padding and prediction oriented ?)

        printing("TARGET size {} ", var=output.size(), verbose=self.verbose, verbose_level=3)
        printing("TARGET data {} ", var=output, verbose=self.verbose, verbose_level=5)
        printing("TARGET  : Word  length  {}  ".format(output_word_len), self.verbose, verbose_level=5)
        start = time.time() if self.timing else None
        output_word_len, perm_idx_output = output_word_len.squeeze().sort(0, descending=True)
        output = output[perm_idx_output,:]
        # we made the choice to mask the empty tokens again (see below)
        conditioning = conditioning.view(1, conditioning.size(0) * conditioning.size(1), -1)
        conditioning = conditioning[:, perm_idx_output, :]
        reorder_conditioning, start = get_timing(start)
        perm_idx_output = perm_idx_output[output_word_len != 0]
        inverse_perm_idx_output = torch.from_numpy(np.argsort(perm_idx_output.cpu().numpy()))
        # output : [  ]
        # we remove empty tokens from the output sequence and the input conditioning vector (as we did for the input)
        output = output[output_word_len != 0]
        conditioning = conditioning[:, output_word_len !=0, :]
        output_word_len = output_word_len[output_word_len != 0]
        char_vecs = self.char_embedding_decoder(output)
        char_vecs = self.drop_out_char_embedding_decoder(char_vecs)
        char_embedding, start = get_timing(start)
        printing("TARGET EMBEDDING size {} ", var=[char_vecs.size()], verbose=self.verbose, verbose_level=3) #if False else None
        printing("TARGET EMBEDDING data {} ", var=char_vecs, verbose=self.verbose, verbose_level=5)
        not_printing, start = get_timing(start)
        printing("TARGET  word lengths after  {} dim",
                 var=[output_word_len.size()], verbose=self.verbose, verbose_level=3)
        # same as target sequence and source ..
        #output_word_len[output_word_len == 0] = 1
        zero_last, start = get_timing(start)
        packed_char_vecs_output = pack_padded_sequence(char_vecs, output_word_len.squeeze().cpu().numpy(), batch_first=True)
        pack_time, start = get_timing(start)
        _start_recurrence = start
        printing("TARGET packed_char_vecs {}  dim", var=[packed_char_vecs_output.data.shape], verbose=self.verbose,
                 verbose_level=3) # .size(), packed_char_vecs)
        # conditioning is the output of the encoder (work as the first initial state of the decoder)
        if isinstance(self.seq_decoder, nn.LSTM):
            stable_decoding_word_state = conditioning.clone() if self.stable_decoding_state else None
            # TODO : ugly because we have done a projection and reshaping for nothing on conditioning
            conditioning = (torch.zeros_like(conditioning), conditioning) if self.init_context_decoder \
                else (torch.zeros_like(conditioning), torch.zeros_like(conditioning))
        # attention
        if self.attn_layer is not None:
            assert char_seq_hidden_encoder is not None, 'ERROR char_seq_hidden_encoder is None'
            assert word_src_sizes is not None
        # start new unrolling by hand
        # we initialize with the same original context conditioning
        if self.unrolling_word:
            # we start with context sent + word as before
            state_i = conditioning
            _output = []
            attention_weight_all = []
            # we re-pad it straight away because we unroll by hand
            char_vecs, char_vecs_sizes_target = pad_packed_sequence(packed_char_vecs_output, batch_first=True)
            printing("DECODER char_vecs re-paded {} ", var=[char_vecs.data.size()], verbose=self.verbose,
                     verbose_level=3)
            max_word_len = char_vecs.size(1)
            for char_i in range(max_word_len):
                if proportion_pred_train is not None:
                    teacher_force = True if np.random.randint(0, 100) > proportion_pred_train else False
                else:
                    teacher_force = True
                if teacher_force or char_i == 0:
                    emb_char = char_vecs[:, char_i, :]
                    printing("DECODER state_decoder_current {} ", var=[state_i[0].size()], verbose=self.verbose,
                             verbose_level=3)
                    printing("DECODER emb_char {} ", var=[emb_char.size()], verbose=self.verbose, verbose_level=3)
                    if char_seq_hidden_encoder.size(0) != emb_char.size(0):
                        pass
                    all_states, state_i, attention_weights = self.word_encoder_target_step(
                                                                    char_vec_current_batch=emb_char,
                                                                    state_decoder_current=state_i,
                                                                    char_vecs_sizes=word_src_sizes,
                                                                    step_char=char_i,
                                                                    word_stable_context=stable_decoding_word_state,
                                                                    char_seq_hidden_encoder=char_seq_hidden_encoder)
                else:
                    assert self.generator is not None, "Generator must be passed to the decoder for decoding if not teacher_force"
                    # TODO based on state_i compute as generator : get id : lookup character embedding and that's it
                    # TODO : not for the first one, which should be the STARTING_SYMBOL
                    # given the current emb_char, the states of the cell (initialized with the conditioning source)
                    #  we compute the next states
                    # [batch x sent_max_len, len_words] ??

                    decoding_states, state_i, attention_weights = self.word_encoder_target_step(char_vec_current_batch=emb_char,
                                                                                                word_stable_context=stable_decoding_word_state,
                                                                                                state_decoder_current=state_i,
                                                                                                char_vecs_sizes=word_src_sizes,
                                                                                                step_char=char_i,
                                                                                                char_seq_hidden_encoder=char_seq_hidden_encoder)
                    printing("DECODING in schedule sampling {} ", var=[state_i[0].size()], verbose=self.verbose,
                             verbose_level=3)
                    # we feed to generator to get the score and the prediction
                    # [batch x sent_max_len, len_words, hidden_dim] ??
                    scores = self.generator.forward(x=decoding_states)

                    predictions = scores.argmax(dim=-1)

                    # TODO : to confirm the shapes here
                    pred = predictions[:,  -1]

                    # given the prediction we get the next character embedding
                    emb_char = self.char_embedding_decoder(pred)

                # no more packed sequence
                # TODO : should shorten the sequence output and state by setting them to 0 using step_char and char_vecs_sizes_target (but it should be fine with the loss output)
                #c_i = state_i[1] if isinstance(self.seq_decoder, nn.LSTM) else None
                h_i = state_i[0] if isinstance(self.seq_decoder, nn.LSTM) else state_i  # for non-LSTM cells the state itself is the hidden
                attention_weight_all.append(attention_weights)
                _output.append(h_i.transpose(0, 1)) # for LSTM the hidden is the output not the cell
                printing("DECODER hidden out {} ", var=[h_i.transpose(0, 1).size()], verbose=self.verbose, verbose_level=3)
                printing("DECODER all_states {} ", var=[all_states.size()], verbose=self.verbose, verbose_level=3)
                #assert (all_states == h_i.transpose(0, 1)).all() == 1
            output = torch.cat(_output, dim=1)
            # we reordered char_vecs earlier so we need to restore the original order
            output = output[inverse_perm_idx_output, :, :]
            printing("DECODER : target unrolling : output {} size ", var=[output.size()], verbose=0, verbose_level=3)
            recurrent_cell_time, pack_time, padd_time = None, None, None
        else:

            output, h_n = self.seq_decoder(packed_char_vecs_output, conditioning)
            h_n = h_n[0] if isinstance(self.seq_decoder, nn.LSTM) else h_n
            recurrent_cell_time, start = get_timing(start)
            printing("TARGET ENCODED {} output {} h_n (output (includes all the hidden states of last layers), "
                     "last hidden hidden for each dir+layers)", var=(output, h_n), verbose=self.verbose, verbose_level=5)
            printing("TARGET ENCODED SIZE {} output {} h_n (output (includes all the hidden states of last layers), "
                     "last hidden hidden for each dir+layers)", var=(output.data.shape, h_n.size()), verbose=self.verbose,
                     verbose_level=3)
            output, output_sizes = pad_packed_sequence(output, batch_first=True)
            padd_time, start = get_timing(start)
            output = output[inverse_perm_idx_output, :, :]
            printing("TARGET ENCODED UNPACKED  {} output {} h_n (output (includes all the hidden states of last layers)"
                     "last hidden hidden for each dir+layers)", var=(output, h_n), verbose=self.verbose, verbose_level=5)

            printing("TARGET ENCODED UNPACKED SIZE {} output {} h_n (output (includes all "
                     "the hidden states of last layers),"
                     "last hidden hidden for each dir+layers)", var=(output.size(), h_n.size()),
                     verbose=self.verbose, verbose_level=3)
            attention_weight_all = None
        all_recurrent_time, _ = get_timing(_start_recurrence)
        if self.timing:
            print("WORD TARGET {} ".format(OrderedDict([('char_embedding', char_embedding),
                                                        ("reorder_all", reorder_conditioning),
                                                        ("zero_last", zero_last), ("not_printing", not_printing),
                                                        ("pack_time", pack_time), ("recurrent_cell_time", recurrent_cell_time),
                                                        ("all_recurrent_time", all_recurrent_time), ("pad_time", padd_time)])))
        return output, attention_weight_all
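
When unrolling_word is on, the decoder above either feeds the gold character (teacher forcing) or its own previous prediction (scheduled sampling, driven by proportion_pred_train). A minimal sketch of that loop with toy modules; the embedding, GRUCell, generator and all sizes are placeholders, not the project's real components.

import numpy as np
import torch
import torch.nn as nn

vocab, emb_dim, hidden = 12, 8, 16
embedding = nn.Embedding(vocab, emb_dim)
cell = nn.GRUCell(emb_dim, hidden)
generator = nn.Linear(hidden, vocab)

gold = torch.randint(0, vocab, (5, 6))           # [batch, word_len] gold characters
state = torch.zeros(5, hidden)                   # stands in for the encoder conditioning
proportion_pred_train = 30                       # % of steps trained on own predictions
outputs = []

for char_i in range(gold.size(1)):
    teacher_force = char_i == 0 or np.random.randint(0, 100) > proportion_pred_train
    if teacher_force:
        emb_char = embedding(gold[:, char_i])    # feed the gold character
    else:
        pred = generator(state).argmax(dim=-1)   # feed back the model's own prediction
        emb_char = embedding(pred)
    state = cell(emb_char, state)
    outputs.append(state.unsqueeze(1))

print(torch.cat(outputs, dim=1).shape)           # [batch, word_len, hidden]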
Example #5
    def forward(self, input_seq, input_word_len, word_embed_input=None,
                output_word_len=None, output_seq=None, word_level_predict=False,
                proportion_pred_train=None):
        # [batch, seq_len] : batch of sequences of indexes (that correspond to 1-hot encoded characters)
        # char_vecs_input = self.char_embedding(input_seq)
        # [batch, seq_len, input_dim] : batch of sequences of embedded characters
        timing = self.timing
        if self.decoder and not word_level_predict:
            assert output_seq is not None and output_word_len is not None, \
                "ERROR : output_seq is {} and output_word le, {}".format(output_seq, output_word_len)

        printing("TYPE  input_seq {} input_word_len ", var=(input_seq.is_cuda, input_word_len.is_cuda),
                 verbose=0, verbose_level=4)
        # input_seq : [batch, max sentence length, max word length] : batch of sentences
        start = time.time() if timing else None

        if self.word_embedding is not None:
            # remove padding by hand
            word_embed_input = word_embed_input.view(-1)
            word_embed_input = word_embed_input[word_embed_input != 1]
            #
            word_embed_input = self.word_embedding(word_embed_input)
            if self.word_embedding_project is not None:
                word_embed_input = self.word_embedding_project(word_embed_input)
                word_embed_input = self.dropout_word_encoder(word_embed_input)

        context, sent_len_max_source, char_seq_hidden_encoder, word_src_sizes, attention_weights_char_tag = \
            self.encoder.forward(input_seq,
                                 input_word_len,
                                 word_embed_input=word_embed_input)

        source_encoder, start = get_timing(start)
        # [batch, _, hidden_size_decoder]
        printing("DECODER hidden state before bridge size {}", var=[context.size() if context is not None else 0], verbose=0, verbose_level=3)
        context = self.bridge(context)
        context = self.dropout_bridge(context)
        #for_decoder = nn.Tanh()(context)
        #h = self.layer_norm(h) if self.layer_norm is not None else h

        bridge, start = get_timing(start)

        printing("TYPE  encoder {} is cuda ", var=context.is_cuda, verbose=0, verbose_level=4)
        printing("DECODER hidden state after bridge size {}", var=[context.size()], verbose=0, verbose_level=3)

        norm_not_norm_hidden = self.normalize_not_normalize(nn.ReLU()(context)) if self.auxilliary_task_norm_not_norm else None

        if self.auxilliary_task_norm_not_norm:
            printing("DECODER hidden state after norm_not_norm_hidden size {}", var=[norm_not_norm_hidden.size()],
                     verbose=0, verbose_level=4)
        if self.decoder is not None and not word_level_predict:
            output, attention_weight_all = self.decoder.forward(output=output_seq,
                                                                conditioning=nn.Tanh()(context),
                                                                output_word_len=output_word_len,
                                                                word_src_sizes=word_src_sizes,
                                                                char_seq_hidden_encoder=char_seq_hidden_encoder,
                                                                proportion_pred_train=proportion_pred_train,
                                                                sent_len_max_source=sent_len_max_source)

        else:
            output = None
            attention_weight_all = None

        word_pred_state = self.word_decoder.forward(nn.ReLU()(context)) if self.word_decoder is not None else None

        pos_pred_state = self.pos_predictor.forward(nn.ReLU()(context)) if self.pos_predictor is not None else None

        edit_state = self.edit_predictor.forward(torch.sigmoid(context)) if self.edit_predictor is not None else None

        target_encoder, start = get_timing(start)
        printing("TYPE  decoder {} is cuda ", var=output.is_cuda if output is not None else None,
                 verbose=0, verbose_level=4)
        # output_score = nn.ReLU()(self.output_predictor(h_out))
        # [batch, output_voc_size], one score per output character token
        printing("DECODER full  output sequence encoded of size {} ", var=[output.size()] if output is not None else None, verbose=self.verbose, verbose_level=3)
        printing("DECODER full  output sequence encoded of {}", var=[output] if output is not None else None,
                 verbose=self.verbose, verbose_level=5)
        if timing:
            time_report = OrderedDict(
                [("source_encoder", source_encoder), ("target_encoder", target_encoder), ("bridge", bridge)])
            print("time report {}".format(time_report))

        return output, word_pred_state, pos_pred_state, norm_not_norm_hidden, edit_state, attention_weight_all, attention_weights_char_tag
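
The forward pass above projects the encoder context through a bridge and then lets several heads read it: the character decoder (through a Tanh), the POS and norm/not-norm auxiliary classifiers (through a ReLU), and the edit predictor (through a sigmoid). A minimal sketch of that shared-context pattern; the layer sizes and head names are illustrative assumptions only.

import torch
import torch.nn as nn

enc_dim, dec_dim, n_pos, batch = 32, 24, 17, 4

bridge = nn.Linear(enc_dim, dec_dim)
dropout_bridge = nn.Dropout(0.2)
pos_predictor = nn.Linear(dec_dim, n_pos)
normalize_not_normalize = nn.Linear(dec_dim, 2)

encoder_context = torch.randn(batch, enc_dim)

context = dropout_bridge(bridge(encoder_context))
decoder_init = torch.tanh(context)                          # conditioning passed to the decoder
pos_scores = pos_predictor(torch.relu(context))             # auxiliary POS head
norm_scores = normalize_not_normalize(torch.relu(context))  # auxiliary binary head
print(decoder_init.shape, pos_scores.shape, norm_scores.shape)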
Example #6
def train(train_path,
          dev_path,
          n_epochs,
          normalization,
          dict_path=None,
          pos_specific_path=None,
          expand_vocab_dev_test=False,
          checkpointing_metric="loss-dev-all",
          batch_size=10,
          test_path=None,
          label_train="",
          label_dev="",
          use_gpu=None,
          lr=0.001,
          n_layers_word_encoder=1,
          n_layers_sent_cell=1,
          get_batch_mode_all=True,
          dropout_sent_encoder_cell=0,
          dropout_word_encoder_cell=0,
          dropout_word_decoder_cell=0,
          dropout_bridge=0,
          drop_out_word_encoder_out=0,
          drop_out_sent_encoder_out=0,
          dir_word_encoder=1,
          word_embed=False,
          word_embedding_dim=None,
          word_embedding_projected_dim=None,
          mode_word_encoding="cat",
          char_level_embedding_projection_dim=0,
          word_recurrent_cell_encoder=None,
          word_recurrent_cell_decoder=None,
          drop_out_char_embedding_decoder=0,
          hidden_size_encoder=None,
          output_dim=None,
          char_embedding_dim=None,
          hidden_size_decoder=None,
          hidden_size_sent_encoder=None,
          freq_scoring=5,
          compute_scoring_curve=False,
          score_to_compute_ls=None,
          mode_norm_ls=None,
          checkpointing=True,
          freq_checkpointing=None,
          freq_writer=None,
          model_dir=None,
          reload=False,
          model_full_name=None,
          model_id_pref="",
          print_raw=False,
          model_specific_dictionary=False,
          dir_sent_encoder=1,
          add_start_char=None,
          add_end_char=1,
          overall_label="DEFAULT",
          overall_report_dir=CHECKPOINT_DIR,
          compute_mean_score_per_sent=False,
          weight_binary_loss=1,
          dense_dim_auxilliary=None,
          dense_dim_auxilliary_2=None,
          unrolling_word=False,
          char_src_attention=False,
          debug=False,
          timing=False,
          dev_report_loss=True,
          bucketing=True,
          policy=None,
          teacher_force=True,
          proportion_pred_train=None,
          shared_context="all",
          clipping=None,
          extend_n_batch=1,
          stable_decoding_state=False,
          init_context_decoder=True,
          dense_dim_auxilliary_pos=None,
          dense_dim_auxilliary_pos_2=None,
          tasks=None,
          word_decoding=False,
          char_decoding=True,
          dense_dim_word_pred=None,
          dense_dim_word_pred_2=None,
          dense_dim_word_pred_3=None,
          symbolic_root=False,
          symbolic_end=False,
          extern_emb_dir=None,
          activation_word_decoder=None,
          activation_char_decoder=None,
          extra_arg_specific_label="",
          freezing_mode=False,
          freeze_ls_param_prefix=None,
          multi_task_loss_ponderation=None,
          max_char_len=None,
          attention_tagging=False,
          dropout_input=None,
          optimizer="adam",
          verbose=1):

    if multi_task_loss_ponderation is not None:
        sanity_check_loss_poneration(multi_task_loss_ponderation,
                                     verbose=verbose)
    if teacher_force:
        assert proportion_pred_train is None, "proportion_pred_train should be None in teacher_force mode"
    else:
        assert 100 > proportion_pred_train > 0, "proportion_pred_train should be between 0 and 100"
    auxilliary_task_norm_not_norm = "norm_not_norm" in tasks  # auxilliary_task_norm_not_norm
    auxilliary_task_pos = "pos" in tasks
    if "normalize" not in tasks:
        word_decoding = False
        char_decoding = False
    if not unrolling_word:
        assert not char_src_attention, "ERROR attention requires step by step unrolling  "
    printing("WARNING bucketing is {} ",
             var=bucketing,
             verbose=verbose,
             verbose_level=1)
    if freq_writer is None:
        freq_writer = freq_checkpointing
        printing("REPORTING freq_writer set to freq_checkpointing {}",
                 var=[freq_checkpointing],
                 verbose=verbose,
                 verbose_level=1)
    if auxilliary_task_norm_not_norm:
        printing(
            "MODEL : training model with auxillisary task (loss weighted with {})",
            var=[weight_binary_loss],
            verbose=verbose,
            verbose_level=1)
    #if compute_scoring_curve:
    #assert score_to_compute_ls is not None and mode_norm_ls is not None and freq_scoring is not None, \
    #    "ERROR score_to_compute_ls and mode_norm_ls should not be None"
    use_gpu = use_gpu_(use_gpu)
    hardware_choosen = "GPU" if use_gpu else "CPU"
    printing("{} hardware mode ",
             var=([hardware_choosen]),
             verbose_level=0,
             verbose=verbose)
    freq_checkpointing = int(
        n_epochs / 10
    ) if checkpointing and freq_checkpointing is None else freq_checkpointing
    assert add_start_char == 1, "ERROR : add_start_char must be activated due to the decoding behavior of output_text_"
    printing("WARNING : add_start_char is {} and add_end_char {}  ".format(
        add_start_char, add_end_char),
             verbose=verbose,
             verbose_level=0)
    printing("TRAINING : checkpointing every {} epoch",
             var=freq_checkpointing,
             verbose=verbose,
             verbose_level=1)
    if reload:
        assert model_full_name is not None and len(
            model_id_pref
        ) == 0 and model_dir is not None and dict_path is not None
    else:
        assert model_full_name is None and model_dir is None

    if not debug:
        pdb.set_trace = lambda: None

    loss_training = []
    loss_developing = []
    # could not use the template because the variable would then not be reinitialized
    loss_details_template = {
        'loss_seq_prediction': [],
        'other': {},
        'loss_binary': [],
        'loss_overall': []
    } if auxilliary_task_norm_not_norm else None

    # used to compute scores for early stopping if checkpointing_metric != loss and for curve plots
    evaluation_set_reporting = dev_path
    mode_norm_ls = ["all"]
    score_to_compute_ls = ["exact_match"]
    print(
        "WARNING : train.py overwriting the mode_norm_ls and score_to_compute_ls arguments"
    )
    curve_scores = {
        score + "-" + mode_norm + "-" + REPO_DATASET[data]: []
        for score in score_to_compute_ls for mode_norm in mode_norm_ls
        for data in evaluation_set_reporting
    } if compute_scoring_curve else None

    printing("WARNING :  lr {} ".format(lr, add_start_char, add_end_char),
             verbose=verbose,
             verbose_level=0)
    printing(
        "INFO : dictionary is computed (re)created from scratch on train_path {} and dev_path {}"
        .format(train_path, dev_path),
        verbose=verbose,
        verbose_level=1)

    if not model_specific_dictionary:
        word_dictionary, char_dictionary, pos_dictionary, \
        xpos_dictionary, type_dictionary = \
        conllu_data.load_dict(dict_path=dict_path,
                              train_path=train_path,
                              dev_path=dev_path,
                              test_path=test_path,
                              word_embed_dict={},
                              dry_run=False,
                              force_new_dic=True,
                              add_start_char=add_start_char, verbose=1)

        voc_size = len(char_dictionary.instance2index) + 1
        word_voc_input_size = len(word_dictionary.instance2index) + 1
        printing("DICTIONARY ; character vocabulary is len {} : {} ",
                 var=str(
                     len(char_dictionary.instance2index) + 1,
                     char_dictionary.instance2index),
                 verbose=verbose,
                 verbose_level=0)
        _train_path, _dev_path, _add_start_char = None, None, None
    else:
        voc_size = None
        word_voc_input_size = 0
        if not reload:
            # we need to feed the model the data so that it computes the model_specific_dictionary
            _train_path = train_path
            _dev_path = dev_path
            _test_path = test_path
            _add_start_char = add_start_char
        else:
            # as it reloads : we don't need the data
            _train_path, _dev_path, _test_path, _add_start_char = None, None, None, None

    model = LexNormalizer(
        generator=Generator,
        expand_vocab_dev_test=expand_vocab_dev_test,
        dense_dim_auxilliary=dense_dim_auxilliary,
        dense_dim_auxilliary_2=dense_dim_auxilliary_2,
        tasks=tasks,
        weight_binary_loss=weight_binary_loss,
        dense_dim_auxilliary_pos=dense_dim_auxilliary_pos,
        dense_dim_auxilliary_pos_2=dense_dim_auxilliary_pos_2,
        load=reload,
        char_embedding_dim=char_embedding_dim,
        voc_size=voc_size,
        dir_model=model_dir,
        use_gpu=use_gpu,
        dict_path=dict_path,
        word_recurrent_cell_decoder=word_recurrent_cell_decoder,
        word_recurrent_cell_encoder=word_recurrent_cell_encoder,
        train_path=_train_path,
        dev_path=_dev_path,
        pos_specific_path=pos_specific_path,
        add_start_char=_add_start_char,
        model_specific_dictionary=model_specific_dictionary,
        dir_word_encoder=dir_word_encoder,
        drop_out_sent_encoder_cell=dropout_sent_encoder_cell,
        drop_out_word_encoder_cell=dropout_word_encoder_cell,
        drop_out_word_decoder_cell=dropout_word_decoder_cell,
        drop_out_bridge=dropout_bridge,
        drop_out_char_embedding_decoder=drop_out_char_embedding_decoder,
        drop_out_word_encoder_out=drop_out_word_encoder_out,
        drop_out_sent_encoder_out=drop_out_sent_encoder_out,
        n_layers_word_encoder=n_layers_word_encoder,
        dir_sent_encoder=dir_sent_encoder,
        n_layers_sent_cell=n_layers_sent_cell,
        hidden_size_encoder=hidden_size_encoder,
        output_dim=output_dim,
        model_id_pref=model_id_pref,
        model_full_name=model_full_name,
        hidden_size_sent_encoder=hidden_size_sent_encoder,
        shared_context=shared_context,
        unrolling_word=unrolling_word,
        char_src_attention=char_src_attention,
        word_decoding=word_decoding,
        dense_dim_word_pred=dense_dim_word_pred,
        dense_dim_word_pred_2=dense_dim_word_pred_2,
        dense_dim_word_pred_3=dense_dim_word_pred_3,
        char_decoding=char_decoding,
        mode_word_encoding=mode_word_encoding,
        char_level_embedding_projection_dim=char_level_embedding_projection_dim,
        stable_decoding_state=stable_decoding_state,
        init_context_decoder=init_context_decoder,
        symbolic_root=symbolic_root,
        symbolic_end=symbolic_end,
        word_embed=word_embed,
        word_embedding_dim=word_embedding_dim,
        word_embedding_projected_dim=word_embedding_projected_dim,
        word_embed_dir=extern_emb_dir,
        word_voc_input_size=word_voc_input_size,
        teacher_force=teacher_force,
        activation_char_decoder=activation_char_decoder,
        activation_word_decoder=activation_word_decoder,
        test_path=_test_path,
        extend_vocab_with_test=_test_path is not None,
        attention_tagging=attention_tagging,
        multi_task_loss_ponderation=multi_task_loss_ponderation,  # needed for save/reloading purposes
        hidden_size_decoder=hidden_size_decoder,
        verbose=verbose,
        timing=timing)

    pos_batch = auxilliary_task_pos

    if use_gpu:
        model = model.cuda()
        printing("TYPE model is cuda : {} ",
                 var=(next(model.parameters()).is_cuda),
                 verbose=verbose,
                 verbose_level=4)
        #model.decoder.attn_layer = model.decoder.attn_layer.cuda()
    if not model_specific_dictionary:
        model.word_dictionary, model.char_dictionary, model.pos_dictionary, \
        model.xpos_dictionary, model.type_dictionary = word_dictionary, char_dictionary, pos_dictionary, xpos_dictionary, type_dictionary

    starting_epoch = model.arguments["info_checkpoint"][
        "n_epochs"] if reload else 1
    reloading = "" if not reload else " reloaded from " + str(starting_epoch)
    n_epochs += starting_epoch
    if freezing_mode:
        assert freeze_ls_param_prefix is not None, "freeze_ls_param_prefix should not be None"
        printing("TRAINING : freezing is on for layers {} ",
                 var=[freeze_ls_param_prefix],
                 verbose=verbose,
                 verbose_level=1)
        for name, param in model.named_parameters():
            for freeze_param in freeze_ls_param_prefix:
                if name.startswith(freeze_param):
                    param.requires_grad = False
                    printing("TRAINING : freezing {} parameter ",
                             var=[name],
                             verbose=verbose,
                             verbose_level=1)

    _loss_dev = 1000
    checkpoint_score_saved = 1000
    _loss_train = 1000
    counter_no_deacrease = 0
    saved_epoch = 1
    if reload:
        printing(
            "TRAINING : RELOADED MODE , starting from checkpointed epoch {} ",
            var=starting_epoch,
            verbose_level=0,
            verbose=verbose)
    printing(
        "TRAINING : Running from {} to {} epochs : training on {} evaluating on {}",
        var=(starting_epoch, n_epochs, train_path, dev_path),
        verbose=verbose,
        verbose_level=0)
    starting_time = time.time()
    total_time = 0
    x_axis_epochs = []
    epoch_ls_dev = []
    epoch_ls_train = []

    train_path = [train_path] if isinstance(train_path, str) else train_path
    dev_path = [dev_path] if isinstance(dev_path, str) else dev_path

    readers_train = readers_load(
        datasets=train_path,
        tasks=tasks,
        word_dictionary=model.word_dictionary,
        word_dictionary_norm=model.word_nom_dictionary,
        char_dictionary=model.char_dictionary,
        pos_dictionary=model.pos_dictionary,
        xpos_dictionary=model.xpos_dictionary,
        type_dictionary=model.type_dictionary,
        use_gpu=use_gpu,
        norm_not_norm=auxilliary_task_norm_not_norm,
        word_decoder=word_decoding,
        add_start_char=add_start_char,
        add_end_char=add_end_char,
        symbolic_end=symbolic_end,
        symbolic_root=symbolic_root,
        bucket=bucketing,
        max_char_len=max_char_len,
        verbose=verbose)

    readers_dev = readers_load(datasets=dev_path,
                               tasks=tasks,
                               word_dictionary=model.word_dictionary,
                               word_dictionary_norm=model.word_nom_dictionary,
                               char_dictionary=model.char_dictionary,
                               pos_dictionary=model.pos_dictionary,
                               xpos_dictionary=model.xpos_dictionary,
                               type_dictionary=model.type_dictionary,
                               use_gpu=use_gpu,
                               norm_not_norm=auxilliary_task_norm_not_norm,
                               word_decoder=word_decoding,
                               add_start_char=add_start_char,
                               add_end_char=add_end_char,
                               symbolic_end=symbolic_end,
                               symbolic_root=symbolic_root,
                               bucket=bucketing,
                               max_char_len=max_char_len,
                               verbose=verbose)

    dir_writer = os.path.join(overall_report_dir, "runs",
                              "{}-model".format(model.model_full_name))
    writer = SummaryWriter(log_dir=dir_writer)
    printing(
        "REPORT : run \ntensorboard --logdir={} --host=localhost --port=9101 "
        "(run tensorboard remotely : sh $EXPERIENCE/track/run_tensorboard_serveo.sh $log_dir $port )  ",
        var=[dir_writer],
        verbose=verbose,
        verbose_level=1)
    printing("REPORT : summary writer will be located {}",
             var=[dir_writer],
             verbose_level=1,
             verbose=verbose)
    step_train = 0
    step_dev = 0
    if ADAPTABLE_SCORING:
        printing("WARNING : scoring epochs not regular (more at the begining ",
                 verbose_level=1,
                 verbose=verbose)
        freq_scoring = 1
    checkpoint_dir_former = None

    for epoch in tqdm(range(starting_epoch, n_epochs),
                      disable_tqdm_level(verbose=verbose, verbose_level=0)):
        index_look = 25
        #parameters = filter(lambda p: p.requires_grad, model.parameters())
        decay_rate = 1
        opt = dptx.get_optimizer(model.parameters(),
                                 lr=lr * decay_rate**epoch,
                                 optimizer="adam")
        assert policy in AVAILABLE_SCHEDULING_POLICIES
        policy_dic = eval(policy)(epoch) if policy is not None else None
        # TODO : no need to re-output multi_task_mode : tasks should be harmonized
        multi_task_mode, ponderation_normalize_loss, weight_binary_loss, weight_pos_loss = scheduling_policy(
            epoch=epoch, phases_ls=policy_dic, tasks=tasks)
        printing(
            "TRAINING Tasks scheduling : ponderation_normalize_loss is {} weight_binary_loss is {}"
            " weight_pos_loss is {} mode is {} ",
            var=[
                ponderation_normalize_loss, weight_binary_loss,
                weight_pos_loss, multi_task_mode
            ],
            verbose=verbose,
            verbose_level=2)

        printing("TRAINING : Starting {} epoch out of {} ",
                 var=(epoch + 1, n_epochs),
                 verbose=verbose,
                 verbose_level=1)
        model.train()
        #batchIter = data_gen_conllu(data_read_train,model.word_dictionary, model.char_dictionary,normalization=normalization,get_batch_mode=get_batch_mode_all,batch_size=batch_size, extend_n_batch=extend_n_batch,print_raw=print_raw, timing=timing, pos_dictionary=model.pos_dictionary,verbose=verbose)
        batchIter = data_gen_multi_task_sampling_batch(
            tasks=tasks,
            readers=readers_train,
            batch_size=batch_size,
            word_dictionary=model.word_dictionary,
            char_dictionary=model.char_dictionary,
            pos_dictionary=model.pos_dictionary,
            word_dictionary_norm=model.word_nom_dictionary,
            get_batch_mode=get_batch_mode_all,
            extend_n_batch=extend_n_batch,
            dropout_input=dropout_input,
            verbose=verbose)
        start = time.time()
        printing(
            "TRAINING : TEACHER FORCING : scheduled sampling proportion of training on predictions is {} ",
            var=[proportion_pred_train],
            verbose=verbose,
            verbose_level=2)

        #rep_tl.checkout_layer_name("encoder.seq_encoder.weight_ih_l0", model.named_parameters(), info_epoch=epoch)

        loss_train, loss_details_train, step_train = run_epoch(
            batchIter,
            model,
            LossCompute(
                model.generator,
                opt=opt,
                multi_task_loss_ponderation=model.multi_task_loss_ponderation,
                auxilliary_task_norm_not_norm=auxilliary_task_norm_not_norm,
                model=model,
                writer=writer,
                use="train",
                use_gpu=use_gpu,
                verbose=verbose,
                tasks=tasks,
                char_decoding=char_decoding,
                word_decoding=word_decoding,
                pos_pred=auxilliary_task_pos,
                vocab_char_size=len(
                    list(model.char_dictionary.instance2index.keys())) + 1,
                timing=timing),
            verbose=verbose,
            i_epoch=epoch,
            multi_task_mode=multi_task_mode,
            n_epochs=n_epochs,
            timing=timing,
            weight_binary_loss=weight_binary_loss,
            weight_pos_loss=weight_pos_loss,
            ponderation_normalize_loss=ponderation_normalize_loss,
            step=step_train,
            clipping=clipping,
            pos_batch=pos_batch,
            proportion_pred_train=proportion_pred_train,
            log_every_x_batch=100)

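        # log model weights and gradients to TensorBoard (frequency controlled by freq_writer)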
        writer_weights_and_grad(model=model,
                                freq_writer=freq_writer,
                                epoch=epoch,
                                writer=writer,
                                verbose=verbose)

        _train_ep_time, start = get_timing(start)
        model.eval()
        # TODO : should be added in the freq_checkpointing otherwise useless
        #batchIter_eval = data_gen_conllu(data_read_dev,model.word_dictionary, model.char_dictionary,batch_size=batch_size, get_batch_mode=False,normalization=normalization, extend_n_batch=1,pos_dictionary=model.pos_dictionary, verbose=verbose)
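        # deterministic dev iterator : no input dropout, no batch extension, no batch sampling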
        batchIter_eval = data_gen_multi_task_sampling_batch(
            tasks=tasks,
            readers=readers_dev,
            batch_size=batch_size,
            word_dictionary=model.word_dictionary,
            char_dictionary=model.char_dictionary,
            word_dictionary_norm=model.word_nom_dictionary,
            pos_dictionary=model.pos_dictionary,
            dropout_input=0,
            extend_n_batch=1,
            get_batch_mode=False,
            verbose=verbose)
        _create_iter_time, start = get_timing(start)
        # TODO : should be able to factorize this to have a single run_epoch() for train and dev (the computation should be the same)
        # TODO : should not evaluate at each epoch : should evaluate every x epochs : check if it decreases and checkpoint
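        # the dev loss is computed every freq_checkpointing epochs (when dev_report_loss is set) and on the last epoch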
        if (dev_report_loss and
            (epoch % freq_checkpointing == 0)) or (epoch + 1 == n_epochs):
            printing("EVALUATION : computing loss on dev epoch {}  ",
                     var=epoch,
                     verbose=verbose,
                     verbose_level=1)
            loss_obj = LossCompute(
                model.generator,
                use_gpu=use_gpu,
                verbose=verbose,
                multi_task_loss_ponderation=model.multi_task_loss_ponderation,
                writer=writer,
                use="dev",
                vocab_char_size=len(
                    list(model.char_dictionary.instance2index.keys())) + 1,
                pos_pred=auxilliary_task_pos,
                tasks=tasks,
                char_decoding=char_decoding,
                word_decoding=word_decoding,
                auxilliary_task_norm_not_norm=auxilliary_task_norm_not_norm)
            loss_dev, loss_details_dev, step_dev = run_epoch(
                batchIter_eval,
                model,
                loss_compute=loss_obj,
                i_epoch=epoch,
                n_epochs=n_epochs,
                verbose=verbose,
                timing=timing,
                step=step_dev,
                weight_binary_loss=weight_binary_loss,
                ponderation_normalize_loss=ponderation_normalize_loss,
                weight_pos_loss=weight_pos_loss,
                pos_batch=pos_batch,
                log_every_x_batch=100)

            loss_developing.append(loss_dev)
            epoch_ls_dev.append(epoch)

            if auxilliary_task_norm_not_norm:
                # in this case we report loss detail
                for ind, loss_key in enumerate(loss_details_dev.keys()):
                    if loss_key != "other":
                        loss_details_template[loss_key].append(
                            loss_details_dev[loss_key])
            else:
                loss_details_template = None

        _eval_time, start = get_timing(start)
        loss_training.append(loss_train)
        epoch_ls_train.append(epoch)
        time_per_epoch = time.time() - starting_time
        total_time += time_per_epoch
        starting_time = time.time()

        # computing exact/edit score
        exact_only = False
        overall_report_ls = None
        # MODIFIED FREQ SCORING TO FREQ CHECKPOINTING

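        # exact/edit scoring on the reporting datasets, run at the checkpointing or scoring frequency and on
        # the final epoch ; with ADAPTABLE_SCORING, scoring becomes less frequent as training progresses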
        if compute_scoring_curve and (
            (epoch %
             (freq_checkpointing if checkpointing_metric != "loss-dev-all" else
              freq_scoring) == 0) or (epoch + 1 == n_epochs)):
            if epoch < 1 and ADAPTABLE_SCORING:
                freq_scoring *= 5
            if epoch > 4 and epoch < 6 and ADAPTABLE_SCORING:
                freq_scoring *= 3
            if epoch > 14 and epoch < 16 and ADAPTABLE_SCORING:
                freq_scoring *= 2
            if (epoch + 1 == n_epochs):
                printing("EVALUATION : final scoring ",
                         verbose=verbose,
                         verbose_level=0)
            x_axis_epochs.append(epoch)
            printing("EVALUATION : Computing score on {} and {}  ",
                     var=(score_to_compute_ls, mode_norm_ls),
                     verbose=verbose,
                     verbose_level=1)
            overall_report_ls = []
            for task, eval_data in zip(tasks, evaluation_set_reporting):
                eval_label = REPO_DATASET[eval_data]
                assert len(set(evaluation_set_reporting)) == len(evaluation_set_reporting),\
                    "ERROR : twice the same dataset has been provided for reporting which will mess up the loss"
                printing("EVALUATION on {} ",
                         var=[eval_data],
                         verbose=verbose,
                         verbose_level=1)
                scores = evaluate(
                    data_path=eval_data,
                    use_gpu=use_gpu,
                    overall_label=overall_label,
                    overall_report_dir=overall_report_dir,
                    score_to_compute_ls=score_to_compute_ls,
                    mode_norm_ls=mode_norm_ls,
                    label_report=eval_label,
                    model=model,
                    normalization=normalization,
                    print_raw=False,
                    model_specific_dictionary=True,
                    get_batch_mode_evaluate=False,
                    compute_mean_score_per_sent=compute_mean_score_per_sent,
                    batch_size=batch_size,
                    word_decoding=word_decoding,
                    dir_report=model.dir_model,
                    debug=debug,
                    evaluated_task=task,
                    tasks=tasks,
                    verbose=verbose)
                # we keep everything here in case we want to do some fancy early-stopping metric
                overall_report_ls.extend(scores)

                # dirty but does the job
                exact_only = True
                DEPRECIATED = False
                if DEPRECIATED:
                    curve_scores = update_curve_dic(
                        score_to_compute_ls=score_to_compute_ls,
                        mode_norm_ls=mode_norm_ls,
                        eval_data=eval_label,
                        former_curve_scores=curve_scores,
                        scores=scores,
                        exact_only=exact_only)
                    curve_ls_tuple = [
                        (loss_ls, label)
                        for label, loss_ls in curve_scores.items()
                        if isinstance(loss_ls, list)
                    ]
                    curves = [tupl[0] for tupl in curve_ls_tuple]
                    val_ls = [
                        tupl[1] + "({}tok)".format(info_token)
                        for tupl in curve_ls_tuple
                        for data, info_token in curve_scores.items()
                        if not isinstance(info_token, list)
                        if tupl[1].endswith(data)
                    ]
            score_to_compute_ls = ["exact"
                                   ] if exact_only else score_to_compute_ls
            if DEPRECIATED:
                for score_plot in score_to_compute_ls:
                    # dirty but does the job
                    print(val_ls)
                    if exact_only:
                        val_ls = [
                            val for val in val_ls
                            if val.startswith("exact-all")
                            or val.startswith("exact-NORMED")
                            or val.startswith("exact-NEED_NORM")
                        ]
                        #val_ls = ["{}-all-{}".format(metric,REPO_DATASET[eval]) for eval in evaluation_set_reporting for metric in ["exact", "edit"]]
                        curves = [curve for curve in curves if len(curve) > 0]

                    simple_plot_ls(losses_ls=curves,
                                   labels=val_ls,
                                   final_loss="",
                                   save=True,
                                   filter_by_label=score_plot,
                                   x_axis=x_axis_epochs,
                                   dir=model.dir_model,
                                   prefix=model.model_full_name,
                                   epochs=str(epoch) + reloading,
                                   verbose=verbose,
                                   lr=lr,
                                   label_color_0=REPO_DATASET[
                                       evaluation_set_reporting[0]],
                                   label_color_1=REPO_DATASET[
                                       evaluation_set_reporting[1]])

        # WARNING : only saving when the dev metric decreases ; the former model is not loaded back if we reload
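        # checkpointing block : runs every freq_checkpointing epochs and on the last epoch ; plots the loss
        # curves, picks the score to monitor and lets checkpoint() decide whether to keep the current model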
        if (checkpointing
                and epoch % freq_checkpointing == 0) or (epoch + 1
                                                         == n_epochs):
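            # before STARTING_CHECKPOINTING_WITH_SCORE epochs the dev loss ("loss-dev-all") is monitored,
            # afterwards the configured checkpointing metric takes over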
            if checkpointing_metric != "loss-dev-all" and epoch < STARTING_CHECKPOINTING_WITH_SCORE:
                _checkpointing_metric = "loss-dev-all"
            elif checkpointing_metric != "loss-dev-all":
                _checkpointing_metric = checkpointing_metric
                if epoch == STARTING_CHECKPOINTING_WITH_SCORE:
                    checkpoint_score_saved = -report["score"]
                    printing("Checkoint info : switching "
                             "checkpoint_score_saved to {} : {}".format(
                                 checkpointing_metric, checkpoint_score_saved),
                             verbose_level=1,
                             verbose=verbose)
            elif checkpointing_metric == "loss-dev-all":
                _checkpointing_metric = checkpointing_metric
            else:
                raise Exception("You missed a case")

            dir_plot_detailed = simple_plot(
                final_loss=0,
                epoch_ls_1=epoch_ls_dev,
                epoch_ls_2=epoch_ls_dev,
                loss_2=loss_details_template.get("loss_binary", None),
                loss_ls=loss_details_template["loss_seq_prediction"],
                epochs=str(epoch) + reloading,
                label="dev-seq_prediction",
                label_2="dev-binary",
                save=True,
                dir=model.dir_model,
                verbose=verbose,
                verbose_level=1,
                lr=lr,
                prefix=model.model_full_name + "-details",
                show=False) if loss_details_template is not None else None

            dir_plot = simple_plot(final_loss=loss_train,
                                   loss_2=loss_developing,
                                   loss_ls=loss_training,
                                   epochs=str(epoch) + reloading,
                                   epoch_ls_1=epoch_ls_train,
                                   epoch_ls_2=epoch_ls_dev,
                                   label=label_train + "-train",
                                   label_2=label_dev + "-dev",
                                   save=True,
                                   dir=model.dir_model,
                                   verbose=verbose,
                                   verbose_level=1,
                                   lr=lr,
                                   prefix=model.model_full_name,
                                   show=False)

            sanity_check_checkpointing_metric(
                tasks, checkpointing_metric=_checkpointing_metric)

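            # checkpoint score : negated dev score (accuracy-like, so lower is better once negated) when
            # checkpointing on a metric, otherwise the dev loss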
            if _checkpointing_metric != "loss-dev-all" or \
                    (epoch == (STARTING_CHECKPOINTING_WITH_SCORE-1) and checkpointing_metric != "loss-dev-all"):
                # for now only useful when different from loss --> compute the default metric on dev with info "all"
                # assuming a single task thanks to the sanity check
                assert overall_report_ls is not None, "ERROR : overall_report_ls was not defined"
                report = rep_tl.get_score(
                    overall_report_ls,
                    metric=TASKS_PARAMETER[tasks[0]].get("default_metric"),
                    data=REPO_DATASET[dev_path[0]],
                    info_score="all",
                    task=tasks[0])
                # Negative because it's an accuracy and checkpointing keeps the lowest score
                checkpoint_score = -report["score"]
            else:
                checkpoint_score = loss_dev

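            # checkpoint() compares checkpoint_score to the best score saved so far, updates the saved model
            # and best epoch, and counts how many checkpoints passed without improvement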
            model, checkpoint_score_saved, counter_no_deacrease, saved_epoch, checkpoint_dir_former = \
                    checkpoint(loss_saved=checkpoint_score_saved, loss=checkpoint_score,
                               checkpointing_metric=_checkpointing_metric,
                               model=model, counter_no_decrease=counter_no_deacrease,
                               checkpoint_dir_former=checkpoint_dir_former,
                               saved_epoch=saved_epoch, model_dir=model.dir_model,
                               extra_checkpoint_label="1st_train" if not reload else "start_{}_ep-{}".format(starting_epoch, extra_arg_specific_label),
                               extra_arg_specific_label=extra_arg_specific_label,
                               info_checkpoint={"n_epochs": epoch, "batch_size": batch_size, "optimizer": optimizer,
                                                "gradient_clipping": clipping,
                                                "tasks_schedule_policy": policy,
                                                "teacher_force": teacher_force,
                                                "proportion_pred_train": proportion_pred_train,
                                                "train_data_path": train_path, "dev_data_path": dev_path,
                                                "other": {"error_curves": dir_plot, "loss": loss_dev,
                                                          "sanity_test": {"loss": loss_dev,
                                                                          "data": [REPO_DATASET[_dev_path] for _dev_path in dev_path],
                                                                          "batch_size": batch_size},
                                                          "error_curves_details": dir_plot_detailed,
                                                          "dropout_input": dropout_input,
                                                          "checkpointing_metric": _checkpointing_metric,
                                                          "multi_task_loss_ponderation": multi_task_loss_ponderation,
                                                          "weight_binary_loss": weight_binary_loss*int(auxilliary_task_norm_not_norm),
                                                          "weight_pos_loss": weight_pos_loss*int(auxilliary_task_pos),
                                                          "ponderation_normalize_loss": ponderation_normalize_loss,
                                                          "data": "dev", "seed(np/torch)": (SEED_NP, SEED_TORCH),
                                                          "extend_n_batch": extend_n_batch,
                                                          "lr": lr, "optim_strategy": "lr_constant",
                                                          "time_training(min)": "{0:.2f}".format(total_time/60),
                                                          "average_per_epoch(min)": "{0:.2f}".format((total_time/n_epochs)/60)}},
                               epoch=epoch, epochs=n_epochs-1,
                               keep_all_checkpoint=False if epoch > starting_epoch else True,# we have nothing to remove after 1st epoch
                               verbose=verbose)
            if counter_no_deacrease * freq_checkpointing >= BREAKING_NO_DECREASE:
                printing(
                    "CHECKPOINTING : Breaking training : the dev metric did not improve for at least {} epochs "
                    "so keeping the model from epoch {} ".format(BREAKING_NO_DECREASE, saved_epoch),
                    verbose=verbose,
                    verbose_level=0)
                break
        printing(
            "LOSS train {:.3f}, dev {:.3f} for epoch {} out of {} epochs ",
            var=(loss_train, loss_dev, epoch, n_epochs),
            verbose=verbose,
            verbose_level=1)

        if timing:
            print("Summary : {}".format(
                OrderedDict([("_train_ep_time", _train_ep_time),
                             ("_create_iter_time", _create_iter_time),
                             ("_eval_time", _eval_time)])))

    writer.close()
    printing(
        "REPORT : run : \n tensorboard --logdir={} --host=localhost --port=9101  ",
        var=[dir_writer],
        verbose=verbose,
        verbose_level=1)

    #rep_tl.checkout_layer_name("encoder.seq_encoder.weight_ih_l0", model.named_parameters(), info_epoch="LAST")

    simple_plot(final_loss=loss_dev,
                loss_ls=loss_training,
                loss_2=loss_developing,
                epoch_ls_1=epoch_ls_train,
                epoch_ls_2=epoch_ls_dev,
                epochs=n_epochs,
                save=True,
                dir=model.dir_model,
                label=label_train,
                label_2=label_dev,
                lr=lr,
                prefix=model.model_full_name + "-LAST",
                verbose=verbose)

    return model.model_full_name
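
# run_epoch : one full pass over the batch iterator ; computes the multi-task loss on every batch,
# lets loss_compute backpropagate and step when an optimizer is attached, and returns the per-token
# loss, its per-task breakdown and the updated global step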
def run_epoch(
        data_iter,
        model,
        loss_compute,
        weight_binary_loss,
        weight_pos_loss,
        ponderation_normalize_loss,
        verbose=0,
        i_epoch=None,
        n_epochs=None,
        n_batches=None,
        empty_run=False,
        timing=False,
        multi_task_mode="all",
        clipping=None,
        step=0,
        proportion_pred_train=None,
        pos_batch=False,
        # should be added in policy
        log_every_x_batch=VERBOSE_1_LOG_EVERY_x_BATCH):
    "Standard Training and Logging Function"

    assert multi_task_mode in AVAILABLE_TASKS
    _start = time.time()
    total_tokens = 0
    total_loss = 0
    total_loss_details = loss_compute.loss_details_template.copy()
    tokens = 0
    i_epoch = -1 if i_epoch is None else i_epoch
    n_epochs = -1 if n_epochs is None else n_epochs
    batch_time_start = time.time()
    i = 0
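    # consume the batch generator until it raises StopIteration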
    while True:
        try:
            batch = next(data_iter)
            i += 1
        except StopIteration:
            break
        batch_time_, batch_time_start = get_timing(batch_time_start)
        printing("Starting {} batch out of {} batches",
                 var=(i + 1, n_batches),
                 verbose=verbose,
                 verbose_level=1)
        if not empty_run:
            start = time.time() if timing else None
            out, out_word, pos_pred_state, norm_not_norm_hidden, edit_state, attention, attention_tag = model.forward(
                input_seq=batch.input_seq,
                output_seq=batch.output_seq_x,
                input_word_len=batch.input_seq_len,
                output_word_len=batch.output_seq_len,
                proportion_pred_train=proportion_pred_train,
                word_embed_input=batch.input_word)
            forward_time, start = get_timing(start)
        else:
            out = 0
            printing("DATA : \n input Sequence {} \n Target sequence {} ",
                     var=(batch.input_seq, batch.output_seq),
                     verbose=verbose,
                     verbose_level=1)
        if not empty_run:
            loss, loss_details_current = loss_compute(
                x=out,
                y=batch.output_seq_y,
                x_norm_not_norm=norm_not_norm_hidden,
                y_norm_not_norm=batch.output_norm_not_norm,
                y_word=batch.output_word,
                x_word_pred=out_word,
                y_pos=batch.pos,
                x_pos=pos_pred_state,
                pos_batch=pos_batch,
                y_edit=batch.edit,
                pred_edit=edit_state,
                weight_binary_loss=weight_binary_loss,
                weight_pos_loss=weight_pos_loss,
                ponderation_normalize_loss=ponderation_normalize_loss,
                clipping=clipping,
                step=i + step)

            loss_time, start = get_timing(start)
            total_loss += loss.item()
            total_loss_details = update_loss_details_dic(
                total_loss_details, loss_details_current)
            total_tokens += batch.ntokens.type(torch.FloatTensor)
            tokens += batch.ntokens.type(torch.FloatTensor)
            elapsed = torch.from_numpy(np.array(time.time() - _start)).float()
            _start = time.time() if verbose >= 2 else _start
            _loss = loss / float(batch.ntokens)
            printing(
                "Epoch {} Step: {}  Loss: {}  Tokens per Sec: {} , total tokens {} : this batch {} within {} sent",
                var=(i_epoch + 1, i, _loss, tokens / elapsed, tokens,
                     batch.ntokens.type(torch.FloatTensor),
                     batch.input_seq.size(0)),
                verbose=verbose,
                verbose_level=1)
            tokens = 0 if verbose >= 2 else tokens
            if i % log_every_x_batch == 1 and verbose == 1:
                print(
                    "Epoch {} Step: {}  Loss: {}  Tokens per Sec: {} , total tokens {}"
                    .format(i_epoch, i, loss / float(batch.ntokens),
                            tokens / elapsed, tokens))
                _start = time.time()
                tokens = 0
        else:
            total_loss, total_tokens = 0, 1
        batch_time_start = time.time()
        if timing:
            print("run epoch timing : {}".format(
                OrderedDict([("forward_time", forward_time),
                             ("loss_time", loss_time),
                             ("batch_time_start", batch_time_)])))
    if not empty_run:
        printing("INFO : epoch {} done ",
                 var=(n_epochs),
                 verbose=verbose,
                 verbose_level=1)
        printing("Loss epoch {} is  {} total out of {} tokens ",
                 var=(i_epoch, float(total_loss) / int(total_tokens),
                      total_tokens),
                 verbose=verbose,
                 verbose_level=1)

    total_loss_details = divide_loss_details_n_tokens(total_loss_details,
                                                      total_tokens)
    step = step + i

    return float(total_loss) / int(total_tokens), total_loss_details, step
Beispiel #8
0
    def __call__(self,
                 x,
                 y,
                 x_norm_not_norm=None,
                 y_norm_not_norm=None,
                 y_pos=None,
                 x_pos=None,
                 y_word=None,
                 x_word_pred=None,
                 y_edit=None,
                 pred_edit=None,
                 pos_batch=False,
                 weight_binary_loss=1,
                 weight_pos_loss=0,
                 ponderation_normalize_loss=0,
                 clipping=None,
                 step=None):

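        # computes the weighted multi-task loss (sequence generation or word-level normalization,
        # norm/not-norm binary prediction, POS tagging, edit prediction), logs each component and,
        # when an optimizer is attached, runs the backward pass and the optimization step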
        if clipping is not None:
            assert self.model is not None, "Using clipping requires passing the model in the loss"
        loss_details = self.loss_details_template.copy()
        if self.loss_binary is not None:
            assert x_norm_not_norm is not None, \
                "ERROR : auxilliary_task_norm_not_norm was set to True but x_norm_not_norm is {} ".format(x_norm_not_norm)
        printing("LOSS decoding states {} ",
                 var=[x.size()] if x is not None else None,
                 verbose=self.verbose,
                 verbose_level=3)
        start = time.time() if self.timing else None
        x = self.generator(x) if x is not None else None
        generate_time, start = get_timing(start)
        if self.use_gpu:
            printing("LOSS : use gpu is True", self.verbose, verbose_level=3)
        if x is not None:
            printing("LOSS input x candidate scores size {} ",
                     var=[x.size()],
                     verbose=self.verbose,
                     verbose_level=4)
            printing("LOSS input y observations size {} ",
                     var=[y.size()],
                     verbose=self.verbose,
                     verbose_level=4)
            printing("LOSS input x candidate scores   {} ",
                     var=(x),
                     verbose=self.verbose,
                     verbose_level=4)
            printing("LOSS input x candidate scores  reshaped {} ",
                     var=(x.view(-1, x.size(-1))),
                     verbose=self.verbose,
                     verbose_level=4)
            printing("LOSS input y observations {} reshaped {} ",
                     var=(y, y.contiguous().view(-1)),
                     verbose=self.verbose,
                     verbose_level=4)
        # we remove empty words in the gold
        y = y[:, :x.size(1), :] if x is not None else None
        y_edit = y_edit[:, :pred_edit.size(1)] if y_edit is not None else None
        y_norm_not_norm = y_norm_not_norm[:, :x_norm_not_norm.size(
            1)] if y_norm_not_norm is not None else None
        y_word = y_word[:, :x_word_pred.size(
            1)] if y_word is not None and x_word_pred is not None else None
        # kept here in case we later want to adapt the ponderation batch-wise
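        # schedule_training maps the multi-task ponderation into one scheduling coefficient
        # per loss term (normalize, pos, norm_not_norm, edit_prediction)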
        loss_scheduler = schedule_training(
            multi_task_loss_ponderation=self.multi_task_loss_ponderation)
        scheduling_normalize = loss_scheduler["normalize"]
        schedule_pos = loss_scheduler["pos"]
        scheduling_norm_not_norm = loss_scheduler["norm_not_norm"]
        scheduling_edit = loss_scheduler["edit_prediction"]

        if y is not None:
            printing("TYPE  y {} is cuda ",
                     var=(y.is_cuda),
                     verbose=0,
                     verbose_level=5)
        reshaping, start = get_timing(start)
        if self.loss_distance is not None:
            loss_generation = self.loss_distance(
                x.contiguous().view(-1, x.size(-1)),
                y.contiguous().view(-1))
        elif self.loss_distance_word_level is not None:
            loss_generation = self.loss_distance_word_level(
                x_word_pred.contiguous().view(-1, x_word_pred.size(-1)),
                y_word.contiguous().view(-1))
            assert ponderation_normalize_loss is not None
            assert scheduling_normalize is not None
        else:
            loss_generation = 0

        loss_distance_time, start = get_timing(start)
        loss_binary = 0
        if self.loss_binary is not None:
            if y_norm_not_norm is not None:
                loss_binary = self.loss_binary(
                    x_norm_not_norm.contiguous().view(
                        -1, x_norm_not_norm.size(-1)),
                    y_norm_not_norm.contiguous().view(-1))

            assert weight_binary_loss is not None
            assert scheduling_norm_not_norm is not None
        loss_edit = 0
        if self.loss_edit is not None and y_edit is not None:
            # self.loss_edit means the model can predict edits ; y_edit means edit labels were provided (pred_edit could also be used)
            assert pred_edit is not None, "ERROR : pred_edit is None while the model has an edit loss and edit labels were provided"
            loss_edit = self.loss_edit(pred_edit.contiguous().view(-1),
                                       y_edit.contiguous().view(-1))

        if pos_batch and self.loss_distance_pos is not None:
            assert x_pos is not None and y_pos is not None, "ERROR : x_pos and y_pos should be defined"
            y_pos = y_pos[:, :x_pos.size(1)]
            loss_pos = self.loss_distance_pos(
                x_pos.contiguous().view(-1, x_pos.size(-1)),
                y_pos.contiguous().view(-1))
        else:
            loss_pos = 0

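        # total objective : each loss term is scaled by its global weight and by its
        # epoch-level scheduling coefficient before being summed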
        multi_task_loss = ponderation_normalize_loss * scheduling_normalize * loss_generation + \
                          weight_binary_loss * loss_binary * scheduling_norm_not_norm + \
                          schedule_pos * loss_pos * weight_pos_loss + \
                          loss_edit * scheduling_edit

        if False:
            predictions = x.argmax(dim=-1)
            print("STEP ", step)
            print("DEBUG LOSS", multi_task_loss, loss_generation)
            print("PREDICTIONS", predictions)
            print("LABEL", y)
            if step == 60:
                pdb.set_trace()
        loss_details["overall_loss"] = multi_task_loss.item()
        loss_details["loss_seq_prediction"] = loss_generation.item(
        ) if not isinstance(loss_generation, int) else 0
        loss_details["loss_binary"] = loss_binary.item() if not isinstance(
            loss_binary, int) else 0
        loss_details["loss_pos"] = loss_pos.item() if not isinstance(
            loss_pos, int) else 0
        loss_details["loss_edit"] = loss_edit.item() if not isinstance(
            loss_edit, int) else 0

        if not isinstance(loss_binary, int):
            printing("LOSS BINARY loss size {} ",
                     var=(str(loss_binary.size())),
                     verbose=self.verbose,
                     verbose_level=3)
            printing("TYPE  loss_binary {} is cuda ",
                     var=(loss_binary.is_cuda),
                     verbose=0,
                     verbose_level=5)

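        # log every loss component (raw and weighted) to TensorBoard, tagged with the current use (train/dev)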
        if self.writer is not None:
            self.writer.add_scalars(
                "loss-" + self.use, {
                    "loss-{}-seq_pred".format(self.use):
                    loss_generation.clone().cpu().data.numpy()
                    if not isinstance(loss_generation, int) else 0,
                    "loss-{}-seq_pred-ponderation_normalize_loss".format(self.use):
                    loss_generation.clone().cpu().data.numpy() *
                    ponderation_normalize_loss
                    if not isinstance(loss_generation, int) else 0,
                    "loss-{}-multitask".format(self.use):
                    multi_task_loss.clone().cpu().data.numpy(),
                    "loss-{}-loss_binary".format(self.use):
                    loss_binary.clone().cpu().data.numpy()
                    if not isinstance(loss_binary, int) else 0,
                    "loss-{}-loss_pos-schedule_pos".format(self.use):
                    loss_pos.clone().cpu().data.numpy() * schedule_pos *
                    weight_pos_loss if not isinstance(loss_pos, int) else 0,
                    "loss-{}-loss_pos".format(self.use):
                    loss_pos.clone().cpu().data.numpy()
                    if not isinstance(loss_pos, int) else 0,
                    "loss-{}-loss_edit".format(self.use):
                    loss_edit.clone().cpu().data.numpy()
                    if not isinstance(loss_edit, int) else 0,
                    "loss-{}-loss_edit-schedule_edit".format(self.use):
                    scheduling_edit * loss_edit.clone().cpu().data.numpy()
                    if not isinstance(loss_edit, int) else 0,
                }, step)

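        # training mode : zero the gradients, backpropagate the multi-task loss, clip if requested, then step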
        if self.opt is not None:
            self.opt.zero_grad()
            multi_task_loss.backward()
            loss_backward_time, start = get_timing(start)
            if clipping is not None:
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               clipping)
            gradient_clipping, start = get_timing(start)
            printing("Optimizing", self.verbose, verbose_level=3)
            self.opt.step()
            step_opt_time, start = get_timing(start)
            zero_gradtime, start = get_timing(start)
        else:
            printing(
                "WARNING : no optimizer provided : backward and optimization step are skipped (loss.py) ",
                verbose=self.verbose,
                verbose_level=3)
        if self.timing:
            print("run loss timing : {} ".format(
                OrderedDict([("loss_distance_time", loss_distance_time),
                             ("reshaping", reshaping),
                             ("generate_time", generate_time),
                             ("loss_backward_time", loss_backward_time),
                             ("gradient_clipping", gradient_clipping),
                             ("step_opt_time", step_opt_time),
                             ("zero_gradtime", zero_gradtime)])))
        return multi_task_loss, loss_details