Example 1
    def __init__(self, source_sentence, samples, model, data_stream,
                 model_name, config,
                 src_vocab=None, n_best=1, track_n_models=1, trg_ivocab=None,
                 patience=10, normalize=True, **kwargs):
        super(perplexityValidation, self).__init__(**kwargs)
        self.model = model
        self.data_stream = data_stream
        self.model_name = model_name
        self.src_vocab = src_vocab
        self.is_synced = False
        self.sampling_fn = model.get_theano_function()

        self.source_sentence = source_sentence
        self.samples = samples
        self.config = config
        self.n_best = n_best
        self.normalize = normalize
        self.patience = patience

        # Helpers
        self.vocab = data_stream.dataset.dictionary
        self.trg_ivocab = trg_ivocab
        self.unk_sym = data_stream.dataset.unk_token
        self.eos_sym = data_stream.dataset.eos_token
        self.unk_idx = self.vocab[self.unk_sym]
        self.eos_idx = self.vocab[self.eos_sym]
        self.src_eos_idx = config['src_vocab_size'] - 1
        self.beam_search = BeamSearch(samples=samples)
Example 2
    def __init__(self, source_sentence, samples, model, data_stream,
                 config, n_best=1, track_n_models=1, trg_ivocab=None,
                 patience=10, normalize=True, **kwargs):
        super(BleuValidator, self).__init__(**kwargs)
        self.source_sentence = source_sentence
        self.samples = samples
        self.model = model
        self.data_stream = data_stream
        self.config = config
        self.n_best = n_best
        self.track_n_models = track_n_models
        self.normalize = normalize
        self.patience = patience

        # Helpers
        self.vocab = data_stream.dataset.dictionary
        self.trg_ivocab = trg_ivocab
        self.unk_sym = data_stream.dataset.unk_token
        self.eos_sym = data_stream.dataset.eos_token
        self.unk_idx = self.vocab[self.unk_sym]
        self.eos_idx = self.vocab[self.eos_sym]
        self.src_eos_idx = config['src_vocab_size'] - 1
        self.best_models = []
        self.beam_search = BeamSearch(samples=samples)
        self.multibleu_cmd = ['perl', self.config['bleu_script'],
                              self.config['val_set_target'], '<']
        self.compbleu_cmd = [self.config['bleu_script_1'], 
                             self.config['val_set_target'],
                             self.config['val_output_repl']]
        self.ap = afterprocesser(config)

        # Create saving directory if it does not exist
        if not os.path.exists(self.config['saveto']):
            os.makedirs(self.config['saveto'])
Example 3
def main() -> None:
    tokenizer = Tokenizer(args.vocab_file)
    vocabulary_size = len(tokenizer)
    dataset = SentenceDataset(args.input_file, tokenizer=tokenizer.encode)
    loader = DataLoader(dataset,
                        args.batch_size,
                        shuffle=False,
                        collate_fn=dataset.collate_fn,
                        drop_last=False)

    searcher = BeamSearch(tokenizer.eos_index, beam_size=args.search_width)

    model = VAE(
        num_embeddings=len(tokenizer),
        dim_embedding=args.dim_embedding,
        dim_hidden=args.dim_hidden,
        dim_latent=args.dim_latent,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional,
        dropout=0.,
        word_dropout=0.,
        dropped_index=tokenizer.unk_index,
    ).to(device)
    model.load_state_dict(torch.load(args.checkpoint_file,
                                     map_location=device))
    model.eval()

    print('Generating sentence...')
    all_hypotheses = []
    with torch.no_grad():
        for s in tqdm(loader):
            s = s.to(device)
            length = torch.sum(s != tokenizer.pad_index, dim=-1)
            bsz = s.shape[0]

            mean, logvar = model.encode(s, length)
            # Decode from the posterior mean for deterministic output;
            # sampling would use: z = model.reparameterize(mean, logvar)
            z = mean

            hidden = model.fc_hidden(z)
            hidden = hidden.view(bsz, -1, model.dim_hidden)
            hidden = hidden.transpose(0, 1).contiguous()

            start_predictions = torch.full((bsz,), tokenizer.bos_index,
                                           dtype=torch.long, device=device)
            start_state = {'hidden': hidden.permute(1, 0, 2)}
            predictions, log_probabilities = searcher.search(
                start_predictions, start_state, model.step)

            for preds in predictions:
                tokens = preds[0]
                tokens = tokens[tokens != tokenizer.eos_index].tolist()
                all_hypotheses.append(tokenizer.decode(tokens))
    print('Done')

    with open(args.output_file, 'w') as f:
        f.write('\n'.join(all_hypotheses))
Example 4
def main() -> None:
    tokenizer = Tokenizer(args.vocab_file)
    vocabulary_size = len(tokenizer)

    searcher = BeamSearch(tokenizer.eos_index, beam_size=args.search_width)

    model = VAE(
        num_embeddings=len(tokenizer),
        dim_embedding=args.dim_embedding,
        dim_hidden=args.dim_hidden,
        dim_latent=args.dim_latent,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional,
        dropout=0.,
        word_dropout=0.,
        dropped_index=tokenizer.unk_index,
    ).to(device)
    model.load_state_dict(torch.load(args.checkpoint_file,
                                     map_location=device))
    model.eval()

    sentence1 = input('Please input sentence1: ')
    sentence2 = input('Please input sentence2: ')

    s1 = ([tokenizer.bos_index] + tokenizer.encode(sentence1)
          + [tokenizer.eos_index])
    s2 = ([tokenizer.bos_index] + tokenizer.encode(sentence2)
          + [tokenizer.eos_index])

    z1, _ = model.encode(
        torch.tensor([s1]).to(device),
        torch.tensor([len(s1)]).to(device))
    z2, _ = model.encode(
        torch.tensor([s2]).to(device),
        torch.tensor([len(s2)]).to(device))

    print("\nGenerate intermediate sentences")
    print("      %s" % sentence1)
    for r in range(1, 10):
        z = (1 - 0.1 * r) * z1 + 0.1 * r * z2
        hidden = model.fc_hidden(z)
        hidden = hidden.view(1, -1, model.dim_hidden)
        hidden = hidden.transpose(0, 1).contiguous()

        start_predictions = torch.full((1,), tokenizer.bos_index,
                                       dtype=torch.long, device=device)
        start_state = {'hidden': hidden.permute(1, 0, 2)}
        predictions, log_probabilities = searcher.search(
            start_predictions, start_state, model.step)

        tokens = predictions[0, 0]
        tokens = tokens[tokens != tokenizer.eos_index].tolist()
        print("[%d:%d] %s" % (10 - r, r, tokenizer.decode(tokens)))
    print("      %s" % sentence2)
Example 5
def translate_model(queue, rqueue, pid, configuration, normalize):

    rng = numpy.random.RandomState(1234)
    enc_dec = EncoderDecoder(rng, **configuration)
    enc_dec.build_sampler()
    enc_dec.load(path=configuration['saveto_best'])
    search_model = BeamSearch(enc_dec=enc_dec,
                              beam_size=configuration['beam_size'],
                              maxlen=3 * configuration['seq_len_src'],
                              stochastic=False,
                              configuration=configuration)

    def _translate(seq):
        outputs, scores = search_model.apply(numpy.array(seq).T)

        # Length-normalize the costs, then pick the lowest-cost hypothesis
        if normalize:
            lengths = numpy.array([len(s) for s in outputs])
            scores = scores / lengths
        sidx = numpy.argmin(scores)

        return outputs[sidx][:-1]

    while True:
        req = queue.get()
        if req is None:
            break

        idx, x = req[0], req[1]
        print(pid, '-', idx)

        seq = _translate(x)
        rqueue.put((idx, seq))

    return
Example 6
    def translate(self, bisents, beam_size, max_output_len, length_norm_alpha,
                  output_file, relative, absolute, local, candidate):
        avg_fan_outs = []
        total_fan_outs = []
        with open(output_file, 'w') as output:
            for i in range(len(bisents)):
                print("Translating sentence", i)
                src_sent = bisents[i][0]
                dy.renew_cg()
                self.encode([src_sent])
                self.decoder.init(
                    dy.affine_transform([
                        dy.parameter(self.b_bridge),
                        dy.parameter(self.W_bridge),
                        self.encoder.final_state()
                    ]))

                beam_search = BeamSearch(beam_size, max_output_len,
                                         length_norm_alpha)
                beam_search.set_pruning_strategy(relative, absolute, local,
                                                 candidate)
                k_best_output, avg_fan_out, total_fan_out, num_pruned = \
                    beam_search.search(self)

                print("pruned:", num_pruned)
                print("avg fan out:", avg_fan_out)
                print("total fan out:", total_fan_out)

                # remove start and end symbols
                if k_best_output[-1] == self.tgt_vocab.eos:
                    words = k_best_output[1:-1]
                else:
                    words = k_best_output[1:]
                output_sent = [self.tgt_vocab.i2w[word] for word in words]
                avg_fan_outs.append(avg_fan_out)
                total_fan_outs.append(total_fan_out)
                output.write(" ".join(output_sent) + '\n')
                if (i + 1) % 100 == 0:
                    output.flush()
        print("avg avg fan out:", sum(avg_fan_outs) / len(avg_fan_outs))
        print("avg total fan out:", sum(total_fan_outs) / len(total_fan_outs))
Example 7
def main() -> None:
    tokenizer = Tokenizer(args.vocab_file)
    vocabulary_size = len(tokenizer)

    searcher = BeamSearch(tokenizer.eos_index, beam_size=args.search_width)

    model = VAE(
        num_embeddings=len(tokenizer),
        dim_embedding=args.dim_embedding,
        dim_hidden=args.dim_hidden,
        dim_latent=args.dim_latent,
        num_layers=args.num_layers,
        bidirectional=args.bidirectional,
        dropout=0.,
        word_dropout=0.,
        dropped_index=tokenizer.unk_index,
    ).to(device)
    model.load_state_dict(torch.load(args.checkpoint_file,
                                     map_location=device))
    model.eval()

    z = torch.randn(args.sample_size, args.dim_latent, device=device)
    hidden = model.fc_hidden(z)
    hidden = hidden.view(args.sample_size, -1, model.dim_hidden)
    hidden = hidden.transpose(0, 1).contiguous()

    start_predictions = torch.full((args.sample_size,), tokenizer.bos_index,
                                   dtype=torch.long, device=device)
    start_state = {'hidden': hidden.permute(1, 0, 2)}
    predictions, log_probabilities = searcher.search(start_predictions,
                                                     start_state, model.step)

    for pred in predictions:
        tokens = pred[0]
        tokens = tokens[tokens != tokenizer.eos_index].tolist()
        print(tokenizer.decode(tokens))
Example 8
    def __init__(self, args, word2index, char2index, device):
        super(MS_Pointer, self).__init__()
        self.args = args
        self.device = device
        self.lr = args.lr
        self.hidden_dim = args.hidden_dim
        self.batch_size = args.batch_size
        self.embedding_dim = args.emb_dim
        self.target_embedding_dim = self.embedding_dim
        self.dropout_rate = 1.0 - args.dropout_keep_prob

        self.word2index = word2index
        self.char2index = char2index
        self.index2word = {v: k for k, v in self.word2index.items()}
        self.index2char = {v: k for k, v in self.char2index.items()}
        self.char_num = len(self.char2index)
        self.word_num = len(self.word2index)

        self.max_seq_len = args.max_seq_len
        self.max_token_len = args.max_token_len
        self.max_decoding_steps = args.max_decoding_steps

        self.flag_use_layernorm = args.flag_use_layernorm
        self.flag_use_dropout = args.flag_use_dropout
        self.flag_use_position_embedding = args.use_position_emb

        self.encoder_output_dim_1 = args.encoder_output_dim_1
        self.encoder_output_dim_2 = args.encoder_output_dim_2
        self.cated_encoder_out_dim = self.encoder_output_dim_1 + self.encoder_output_dim_2

        self.decoder_output_dim = args.decoder_output_dim
        self.decoder_input_dim = self.encoder_output_dim_1 + self.encoder_output_dim_2 + self.target_embedding_dim

        # Word Embedding, Char Embedding and Positional Embeddings.
        self.char_embeddings = nn.Embedding(self.char_num,
                                            self.embedding_dim,
                                            padding_idx=self.args.PAD_idx)
        self.word_embeddings = nn.Embedding(self.word_num,
                                            self.embedding_dim,
                                            padding_idx=self.args.PAD_idx)
        if self.flag_use_position_embedding:
            self.position_embeddings = PositionalEmbedding(
                self.embedding_dim, self.max_seq_len)

        # Char encoder layer and word encoder layer for source1 and source2 respectively.
        self.source1_words_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim, nhead=4)
        self.source1_chars_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim, nhead=4)
        self.source2_words_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim, nhead=4)
        self.source2_chars_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim, nhead=4)

        # Char transformer layer and word transformer layer for source1 and source2 respectively.
        self.source1_words_transformer_encoder = nn.TransformerEncoder(
            self.source1_words_encoder_layer, num_layers=2)
        self.source1_chars_transformer_encoder = nn.TransformerEncoder(
            self.source1_chars_encoder_layer, num_layers=1)
        self.source2_words_transformer_encoder = nn.TransformerEncoder(
            self.source2_words_encoder_layer, num_layers=2)
        self.source2_chars_transformer_encoder = nn.TransformerEncoder(
            self.source2_chars_encoder_layer, num_layers=1)

        self.source1_attention_layer = AdditiveAttention(
            self.hidden_dim, self.encoder_output_dim_1)
        self.source2_attention_layer = AdditiveAttention(
            self.hidden_dim, self.encoder_output_dim_2)

        self.source1_dropout_layer = nn.Dropout(p=self.dropout_rate)
        self.source2_dropout_layer = nn.Dropout(p=self.dropout_rate)
        self.encoder_out_projection_layer = nn.Linear(
            in_features=self.cated_encoder_out_dim,
            out_features=self.decoder_output_dim)

        self.gate_projection_layer = torch.nn.Linear(
            in_features=self.decoder_output_dim + self.decoder_input_dim,
            out_features=1,
            bias=False)
        self.decoder_cell = nn.modules.LSTMCell(
            input_size=self.decoder_input_dim,
            hidden_size=self.decoder_output_dim,
            bias=True)
        self.beam_search = BeamSearch(self.max_seq_len * 2 - 1,
                                      max_steps=self.max_decoding_steps,
                                      beam_size=self.args.beam_size)
        self.test_data_utils = TestDataUtils(self.args, self.word2index,
                                             self.char2index)

        if self.flag_use_layernorm:
            self.source1_encoder_layernorm = nn.LayerNorm(
                normalized_shape=[self.max_seq_len, self.embedding_dim])
            self.source2_encoder_layernorm = nn.LayerNorm(
                normalized_shape=[self.max_seq_len, self.embedding_dim])
            self.decoder_hidden_layernorm = nn.LayerNorm(
                normalized_shape=self.decoder_output_dim)
            self.decoder_cell_layernorm = nn.LayerNorm(
                normalized_shape=self.decoder_output_dim)
Example 9
class MS_Pointer(nn.Module):
    """
    MS-Pointer network (multiple-source pointer network); this demo uses two sources.
    """
    def __init__(self, args, word2index, char2index, device):
        super(MS_Pointer, self).__init__()
        self.args = args
        self.device = device
        self.lr = args.lr
        self.hidden_dim = args.hidden_dim
        self.batch_size = args.batch_size
        self.embedding_dim = args.emb_dim
        self.target_embedding_dim = self.embedding_dim
        self.dropout_rate = 1.0 - args.dropout_keep_prob

        self.word2index = word2index
        self.char2index = char2index
        self.index2word = {v: k for k, v in self.word2index.items()}
        self.index2char = {v: k for k, v in self.char2index.items()}
        self.char_num = len(self.char2index)
        self.word_num = len(self.word2index)

        self.max_seq_len = args.max_seq_len
        self.max_token_len = args.max_token_len
        self.max_decoding_steps = args.max_decoding_steps

        self.flag_use_layernorm = args.flag_use_layernorm
        self.flag_use_dropout = args.flag_use_dropout
        self.flag_use_position_embedding = args.use_position_emb

        self.encoder_output_dim_1 = args.encoder_output_dim_1
        self.encoder_output_dim_2 = args.encoder_output_dim_2
        self.cated_encoder_out_dim = self.encoder_output_dim_1 + self.encoder_output_dim_2

        self.decoder_output_dim = args.decoder_output_dim
        self.decoder_input_dim = self.encoder_output_dim_1 + self.encoder_output_dim_2 + self.target_embedding_dim

        # Word Embedding, Char Embedding and Positional Embeddings.
        self.char_embeddings = nn.Embedding(self.char_num,
                                            self.embedding_dim,
                                            padding_idx=self.args.PAD_idx)
        self.word_embeddings = nn.Embedding(self.word_num,
                                            self.embedding_dim,
                                            padding_idx=self.args.PAD_idx)
        if self.flag_use_position_embedding:
            self.position_embeddings = PositionalEmbedding(
                self.embedding_dim, self.max_seq_len)

        # Char encoder layer and word encoder layer for source1 and source2 respectively.
        self.source1_words_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim, nhead=4)
        self.source1_chars_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim, nhead=4)
        self.source2_words_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim, nhead=4)
        self.source2_chars_encoder_layer = nn.TransformerEncoderLayer(
            d_model=self.embedding_dim, nhead=4)

        # Char transformer layer and word transformer layer for source1 and source2 respectively.
        self.source1_words_transformer_encoder = nn.TransformerEncoder(
            self.source1_words_encoder_layer, num_layers=2)
        self.source1_chars_transformer_encoder = nn.TransformerEncoder(
            self.source1_chars_encoder_layer, num_layers=1)
        self.source2_words_transformer_encoder = nn.TransformerEncoder(
            self.source2_words_encoder_layer, num_layers=2)
        self.source2_chars_transformer_encoder = nn.TransformerEncoder(
            self.source2_chars_encoder_layer, num_layers=1)

        self.source1_attention_layer = AdditiveAttention(
            self.hidden_dim, self.encoder_output_dim_1)
        self.source2_attention_layer = AdditiveAttention(
            self.hidden_dim, self.encoder_output_dim_2)

        self.source1_dropout_layer = nn.Dropout(p=self.dropout_rate)
        self.source2_dropout_layer = nn.Dropout(p=self.dropout_rate)
        self.encoder_out_projection_layer = nn.Linear(
            in_features=self.cated_encoder_out_dim,
            out_features=self.decoder_output_dim)

        self.gate_projection_layer = torch.nn.Linear(
            in_features=self.decoder_output_dim + self.decoder_input_dim,
            out_features=1,
            bias=False)
        self.decoder_cell = nn.modules.LSTMCell(
            input_size=self.decoder_input_dim,
            hidden_size=self.decoder_output_dim,
            bias=True)
        self.beam_search = BeamSearch(self.max_seq_len * 2 - 1,
                                      max_steps=self.max_decoding_steps,
                                      beam_size=self.args.beam_size)
        self.test_data_utils = TestDataUtils(self.args, self.word2index,
                                             self.char2index)

        if self.flag_use_layernorm:
            self.source1_encoder_layernorm = nn.LayerNorm(
                normalized_shape=[self.max_seq_len, self.embedding_dim])
            self.source2_encoder_layernorm = nn.LayerNorm(
                normalized_shape=[self.max_seq_len, self.embedding_dim])
            self.decoder_hidden_layernorm = nn.LayerNorm(
                normalized_shape=self.decoder_output_dim)
            self.decoder_cell_layernorm = nn.LayerNorm(
                normalized_shape=self.decoder_output_dim)

    def get_target_token_embeddings(self, target_token_ids):
        """
        args:
            target_token_ids: (tuple), (target_token_word_ids, target_token_char_ids)
        return:
            target_token_embeddings: (torch.tensor),  Batch x Length x Dim
        """
        def get_char_idx_from_words(target_word_ids):
            target_char_ids = []

            for single_instance_word in target_word_ids.tolist():
                word = self.index2word[single_instance_word]

                if word in [
                        self.args.PAD_token, self.args.BOS_token,
                        self.args.EOS_token, self.args.OOV_token
                ]:
                    char_ids = self.max_token_len * [
                        self.char2index.get(word, self.args.OOV_idx)
                    ]
                else:
                    char_ids = [
                        self.char2index.get(char, self.args.OOV_idx)
                        for char in word
                    ]
                    char_ids = char_ids[0:self.max_token_len] if self.max_token_len < len(char_ids) \
                        else char_ids + (self.max_token_len - len(char_ids)) * [self.args.PAD_idx]
                target_char_ids.append(char_ids)

            return torch.LongTensor(target_char_ids).to(self.device)

        if isinstance(target_token_ids, tuple):
            target_words_ids, target_chars_ids = target_token_ids
        else:
            target_words_ids = target_token_ids
            target_chars_ids = get_char_idx_from_words(target_words_ids)

        target_word_embeddings = self.word_embeddings(target_words_ids)
        target_char_embeddings = self.char_embeddings(target_chars_ids).sum(1)
        target_token_embeddings = target_word_embeddings + target_char_embeddings
        return target_token_embeddings

    def encode(self, batch_input):
        source1_input_words_ids = batch_input["source1_input_words_ids"]
        source1_input_chars_ids = batch_input["source1_input_chars_ids"]
        source2_input_words_ids = batch_input["source2_input_words_ids"]
        source2_input_chars_ids = batch_input["source2_input_chars_ids"]
        source1_input_seq_len = batch_input["source1_input_seq_len"]
        source2_input_seq_len = batch_input["source2_input_seq_len"]

        source1_words_embs = self.word_embeddings(source1_input_words_ids)
        source1_chars_embs = self.char_embeddings(source1_input_chars_ids).sum(
            2)
        source2_words_embs = self.word_embeddings(source2_input_words_ids)
        source2_chars_embs = self.char_embeddings(source2_input_chars_ids).sum(
            2)

        if self.flag_use_position_embedding:
            source1_words_embs = source1_words_embs + self.position_embeddings(
                source1_input_seq_len)
            source1_chars_embs = source1_chars_embs + self.position_embeddings(
                source1_input_seq_len)
            source2_words_embs = source2_words_embs + self.position_embeddings(
                source2_input_seq_len)
            source2_chars_embs = source2_chars_embs + self.position_embeddings(
                source2_input_seq_len)

        source1_words_transformer_output = self.source1_words_transformer_encoder(
            source1_words_embs)
        source1_chars_transformer_output = self.source1_chars_transformer_encoder(
            source1_chars_embs)
        source2_words_transformer_output = self.source2_words_transformer_encoder(
            source2_words_embs)
        source2_chars_transformer_output = self.source2_chars_transformer_encoder(
            source2_chars_embs)

        source1_encoder_output = source1_words_transformer_output + source1_chars_transformer_output
        source2_encoder_output = source2_words_transformer_output + source2_chars_transformer_output

        if self.flag_use_layernorm:
            source1_encoder_output = self.source1_encoder_layernorm(
                source1_encoder_output)
            source2_encoder_output = self.source2_encoder_layernorm(
                source2_encoder_output)

        if self.flag_use_dropout:
            source1_encoder_output = self.source1_dropout_layer(
                source1_encoder_output)
            source2_encoder_output = self.source2_dropout_layer(
                source2_encoder_output)

        initial_decoder_hidden_state = torch.tanh(
            self.encoder_out_projection_layer(
                torch.cat([
                    source1_encoder_output[:, 0, :],
                    source2_encoder_output[:, 0, :]
                ],
                          dim=-1)))
        return source1_encoder_output, source2_encoder_output, initial_decoder_hidden_state

    def get_initial_model_state(self, batch_input):

        model_state = {}
        model_state["merged_source_global_ids"] = batch_input[
            "merged_source_global_ids"]
        model_state["merged_source_local_ids"] = batch_input[
            "merged_source_local_ids"]
        model_state["source1_local_words_ids"] = batch_input[
            "source1_local_words_ids"]
        model_state["source2_local_words_ids"] = batch_input[
            "source2_local_words_ids"]

        batch_size = batch_input["source1_input_words_ids"].shape[0]

        source1_encoder_output, source2_encoder_output, initial_decoder_hidden = self.encode(
            batch_input)
        # initial_decoder_cell = torch.rand(batch_size, self.decoder_output_dim)
        initial_decoder_cell = initial_decoder_hidden.new_zeros(
            batch_size, self.decoder_output_dim)

        model_state["decoder_hidden_state"] = initial_decoder_hidden
        model_state["decoder_hidden_cell"] = initial_decoder_cell
        model_state["source1_encoder_output"] = source1_encoder_output
        model_state["source2_encoder_output"] = source2_encoder_output

        initial_source1_decoder_attention = self.source1_attention_layer(
            initial_decoder_hidden, source1_encoder_output[:, 0:, :])
        initial_source2_decoder_attention = self.source2_attention_layer(
            initial_decoder_hidden, source2_encoder_output[:, 0:, :])

        initial_source1_decoder_attention_score = torch.softmax(
            initial_source1_decoder_attention, -1)
        initial_source2_decoder_attention_score = torch.softmax(
            initial_source2_decoder_attention, -1)

        initial_source1_weighted_context = weighted_sum(
            source1_encoder_output, initial_source1_decoder_attention_score)
        initial_source2_weighted_context = weighted_sum(
            source2_encoder_output, initial_source2_decoder_attention_score)
        model_state[
            "source1_weighted_context"] = initial_source1_weighted_context
        model_state[
            "source2_weighted_context"] = initial_source2_weighted_context

        return model_state

    def decode_step(self, previous_token_ids, model_state):
        # Fetch last timestep values.
        previous_source1_weighted_context = model_state[
            "source1_weighted_context"]
        previous_source2_weighted_context = model_state[
            "source2_weighted_context"]
        previous_decoder_hidden_state = model_state["decoder_hidden_state"]
        previous_decoder_hidden_cell = model_state["decoder_hidden_cell"]
        previous_token_embedding = self.get_target_token_embeddings(
            previous_token_ids)

        # update decoder hidden state of current timestep
        current_decoder_input = torch.cat(
            (previous_token_embedding, previous_source1_weighted_context,
             previous_source2_weighted_context),
            dim=-1)
        decoder_hidden_state, decoder_hidden_cell = self.decoder_cell(
            current_decoder_input,
            (previous_decoder_hidden_state, previous_decoder_hidden_cell))
        # print(decoder_hidden_state.shape, decoder_hidden_cell.shape)
        if self.flag_use_layernorm:
            decoder_hidden_state = self.decoder_hidden_layernorm(
                decoder_hidden_state)
            decoder_hidden_cell = self.decoder_cell_layernorm(
                decoder_hidden_cell)
        model_state["decoder_hidden_state"] = decoder_hidden_state
        model_state["decoder_hidden_cell"] = decoder_hidden_cell

        # Computing decoder's attention score on encoder output.
        source1_encoder_output = model_state["source1_encoder_output"]
        source2_encoder_output = model_state["source2_encoder_output"]
        source1_decoder_attention_output = self.source1_attention_layer(
            decoder_hidden_state, source1_encoder_output)
        source2_decoder_attention_output = self.source2_attention_layer(
            decoder_hidden_state, source2_encoder_output)

        # print("attention dim: ", source1_decoder_attention_output.shape)
        source1_decoder_attention_score = torch.softmax(
            source1_decoder_attention_output, -1)
        source2_decoder_attention_score = torch.softmax(
            source2_decoder_attention_output, -1)
        model_state[
            "source1_decoder_attention_score"] = source1_decoder_attention_score
        model_state[
            "source2_decoder_attention_score"] = source2_decoder_attention_score

        # context vector of source1 and source2, weighted sum of (source encoder output) * decoder attention score.
        # source1_weighted_context = weighted_sum(source1_encoder_output[:,1:, :], source1_decoder_attention_score)
        # source2_weighted_context = weighted_sum(source2_encoder_output[:,1:, :], source2_decoder_attention_score)
        source1_weighted_context = weighted_sum(
            source1_encoder_output, source1_decoder_attention_score)
        source2_weighted_context = weighted_sum(
            source2_encoder_output, source2_decoder_attention_score)
        model_state["source1_weighted_context"] = source1_weighted_context
        model_state["source2_weighted_context"] = source2_weighted_context

        # Compute the current gate score.
        gate_input = torch.cat(
            (previous_token_embedding, source1_weighted_context,
             source2_weighted_context, decoder_hidden_state),
            dim=-1)
        gate_projected = self.gate_projection_layer(gate_input).squeeze(-1)
        gate_score = torch.sigmoid(gate_projected)
        model_state["gate_score"] = gate_score

        return model_state

    def get_batch_loss(self, batch_input):
        # source1_token_mask = batch_input["source1_input_seq_mask"]
        # source2_token_mask = batch_input["source2_input_seq_mask"]
        target_words_ids = batch_input["target_words_ids"]
        target_chars_ids = batch_input["target_chars_ids"]
        # target_mask = batch_input["target_seq_mask"]

        batch_size, target_seq_len = target_words_ids.size()
        num_decoding_steps = target_seq_len - 1
        model_state = self.get_initial_model_state(batch_input)

        step_log_likelihoods = []  # log-likelihood of the target word at each timestep
        for timestep in range(num_decoding_steps):
            previous_token_ids = (target_words_ids[:, timestep],
                                  target_chars_ids[:, timestep, :])

            model_state = self.decode_step(previous_token_ids, model_state)

            target_to_source1 = (batch_input["source1_input_words_ids"] ==
                                 target_words_ids[:,
                                                  timestep + 1].unsqueeze(-1))
            target_to_source2 = (batch_input["source2_input_words_ids"] ==
                                 target_words_ids[:,
                                                  timestep + 1].unsqueeze(-1))

            step_log_likelihood = self.get_negative_log_likelihood(
                model_state["source1_decoder_attention_score"],
                model_state["source2_decoder_attention_score"],
                target_to_source1, target_to_source2,
                model_state["gate_score"])

            step_log_likelihoods.append(step_log_likelihood.unsqueeze(-1))

        # Concatenate the per-timestep log-likelihoods into one tensor
        # shape: (batch_size, num_decoding_steps = target_seq_len - 1)
        log_likelihoods = torch.cat(step_log_likelihoods, -1)

        # Drop the first token (START), which never appears as a target
        # shape: (batch_size, num_decoding_steps = target_seq_len - 1)
        # target_mask = target_mask[:, 1:].float()

        # Sum the (optionally masked) per-timestep log-likelihoods to get the
        # log-likelihood of the whole sequence
        # log_likelihood = (log_likelihoods * target_mask)  # .sum(dim=-1)
        log_likelihood = log_likelihoods.sum(dim=-1)
        batch_loss = -log_likelihood.sum()
        mean_loss = batch_loss / batch_size

        return {"mean_loss": mean_loss, "batch_loss": batch_loss}

    def get_negative_log_likelihood(self, source1_decoder_attention_score,
                                    source2_decoder_attention_score,
                                    target_to_source1, target_to_source2,
                                    gate_score):

        # shape: (batch_size, seq_max_len_1)
        combined_log_probs_1 = ((source1_decoder_attention_score *
                                 target_to_source1.float()).sum(-1) +
                                1e-20).log()

        # shape: (batch_size, seq_max_len_2)
        combined_log_probs_2 = ((source2_decoder_attention_score *
                                 target_to_source2.float()).sum(-1) +
                                1e-20).log()

        # Compute log(p1 * gate + p2 * (1 - gate))
        log_gate_score_1 = (gate_score + 1e-20).log()  # shape: (batch_size,)
        log_gate_score_2 = (1 - gate_score +
                            1e-20).log()  # shape: (batch_size,)

        item_1 = (log_gate_score_1 + combined_log_probs_1).unsqueeze(-1)
        item_2 = (log_gate_score_2 + combined_log_probs_2).unsqueeze(-1)
        step_log_likelihood = logsumexp(torch.cat((item_1, item_2), -1))
        return step_log_likelihood

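    # A minimal sanity check (hypothetical, not part of the model) of the
    # identity the method above relies on, written with torch.logsumexp in
    # place of the project's own logsumexp helper:
    #   log(g*p1 + (1-g)*p2) == logsumexp(log g + log p1, log(1-g) + log p2)
    # >>> g, p1, p2 = torch.tensor([0.3]), torch.tensor([0.25]), torch.tensor([0.10])
    # >>> direct = torch.log(g * p1 + (1 - g) * p2)
    # >>> via_lse = torch.logsumexp(torch.stack(
    # ...     [g.log() + p1.log(), (1 - g).log() + p2.log()], dim=-1), dim=-1)
    # >>> bool(torch.allclose(direct, via_lse))
    # True
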
    def merge_final_log_probs(self, source1_decoder_attention_score,
                              source2_decoder_attention_score,
                              source1_local_words_ids, source2_local_words_ids,
                              gate_score):
        """
        Combine the three probability terms into log-likelihoods over the full vocabulary.
        """
        # group_size and the lengths of the two source sequences
        group_size, seq_max_len_1 = source1_decoder_attention_score.size()
        group_size, seq_max_len_2 = source2_decoder_attention_score.size()

        # Gate probability multiplied into source1, shape: (group_size, seq_max_len_1)
        gate_1 = gate_score.expand(seq_max_len_1, -1).t()
        # Gate probability multiplied into source2, shape: (group_size, seq_max_len_2)
        gate_2 = (1 - gate_score).expand(seq_max_len_2, -1).t()

        # Gated source1 scores, shape: (group_size, seq_max_len_1)
        source1_decoder_attention_score = source1_decoder_attention_score * gate_1
        # Gated source2 scores, shape: (group_size, seq_max_len_2)
        source2_decoder_attention_score = source2_decoder_attention_score * gate_2

        # shape: (group_size, seq_max_len_1)
        log_probs_1 = (source1_decoder_attention_score + 1e-45).log()
        # shape: (group_size, seq_max_len_2)
        log_probs_2 = (source2_decoder_attention_score + 1e-45).log()

        # Initialize the full-vocabulary probabilities to (near) zero,
        # shape: (group_size, target_vocab_size)
        final_log_probs = (source1_decoder_attention_score.new_zeros(
            (group_size, 2 * self.max_seq_len)) + 1e-45).log()

        for i in range(seq_max_len_1):  # iterate over every source1 position
            # predicted probability at this position, shape: (group_size, 1)
            log_probs_slice = log_probs_1[:, i].unsqueeze(-1)
            # token ids at this position, shape: (group_size, 1)
            source_to_target_slice = source1_local_words_ids[:, i].unsqueeze(-1)

            # existing vocabulary log-probabilities at the positions to
            # update, shape: (group_size, 1)
            selected_log_probs = final_log_probs.gather(
                -1, source_to_target_slice)
            # merged value (existing mixed with new), shape: (group_size, 1)
            combined_scores = logsumexp(
                torch.cat((selected_log_probs, log_probs_slice),
                          dim=-1)).unsqueeze(-1)
            # write combined_scores back into final_log_probs
            final_log_probs = final_log_probs.scatter(-1,
                                                      source_to_target_slice,
                                                      combined_scores)

        # Do the same for source2
        for i in range(seq_max_len_2):
            log_probs_slice = log_probs_2[:, i].unsqueeze(-1)
            source_to_target_slice = source2_local_words_ids[:, i].unsqueeze(-1)
            selected_log_probs = final_log_probs.gather(
                -1, source_to_target_slice)
            combined_scores = logsumexp(
                torch.cat((selected_log_probs, log_probs_slice),
                          dim=-1)).unsqueeze(-1)
            final_log_probs = final_log_probs.scatter(-1,
                                                      source_to_target_slice,
                                                      combined_scores)

        return final_log_probs

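    # Toy illustration (hypothetical values) of the gather/logsumexp/scatter
    # merge above: two source positions pointing at the same vocabulary id
    # have their probabilities accumulated rather than overwritten.
    # >>> fl = (torch.zeros(1, 5) + 1e-45).log()          # "empty" vocab row
    # >>> ids = torch.tensor([[3]])
    # >>> for lp in (torch.tensor([[0.2]]).log(), torch.tensor([[0.3]]).log()):
    # ...     sel = fl.gather(-1, ids)
    # ...     fl = fl.scatter(-1, ids, torch.logsumexp(
    # ...         torch.cat((sel, lp), dim=-1), dim=-1, keepdim=True))
    # >>> float(fl[0, 3].exp())  # ~0.5 == 0.2 + 0.3
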
    def take_search_step(self, previous_token_ids, model_state):
        # Advance the decoder state by one step
        # model_state = self.get_initial_model_state(batch_input)
        model_state = self.decode_step(previous_token_ids, model_state)

        # Merge the log-likelihoods from the two sources
        final_log_probs = self.merge_final_log_probs(
            model_state["source1_decoder_attention_score"],
            model_state["source2_decoder_attention_score"],
            model_state["source1_local_words_ids"],
            model_state["source2_local_words_ids"], model_state["gate_score"])
        return final_log_probs, model_state

    def forward_beam_search(self, batch_input, model_state):
        source1_input_words_ids = batch_input["source1_input_words_ids"]
        # merged_source_local_ids = batch_input["merged_source_local_ids"]

        batch_size = source1_input_words_ids.size()[0]
        start_token_ids = source1_input_words_ids.new_full(
            (batch_size, ), fill_value=self.args.BOS_idx)

        all_top_k_predictions, log_probabilities = self.beam_search.search(
            start_token_ids, batch_input, model_state, self.take_search_step)

        return {
            "predicted_log_probs": log_probabilities,
            "predicted_token_ids": all_top_k_predictions
        }

    def get_predicted_tokens(self, predicted_token_ids,
                             merged_source_word_list):

        word_list_len = merged_source_word_list.shape[1]
        batch_size, beam_size, target_len = predicted_token_ids.shape

        expanded_word_list = merged_source_word_list.reshape(
            batch_size, 1, word_list_len)
        expanded_word_list = np.tile(expanded_word_list, (1, beam_size, 1))

        dim0_indexer = np.tile(
            np.array(range(batch_size)).reshape(batch_size, 1, 1),
            (1, beam_size, target_len))
        dim1_indexer = np.tile(
            np.array(range(beam_size)).reshape(1, beam_size, 1),
            (batch_size, 1, target_len))
        dim2_indexer = predicted_token_ids.cpu()

        predicted_tokens = expanded_word_list[dim0_indexer, dim1_indexer,
                                              dim2_indexer]

        return predicted_tokens

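    # NumPy fancy-indexing illustration (hypothetical shapes) of
    # get_predicted_tokens above: each predicted local id selects a word
    # from that instance's merged source word list.
    # >>> wl = np.array([["<s>", "a", "b", "</s>"]])   # (batch=1, list_len=4)
    # >>> ids = np.array([[[1, 2, 3], [2, 1, 3]]])     # (batch=1, beam=2, len=3)
    # >>> wl[np.zeros_like(ids), ids][0, 1].tolist()
    # ['b', 'a', '</s>']
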
    def predict_single_instance(self, instance):
        """
        instance: List, [[source1 words], [source2 words]]
        """
        raise NotImplementedError

    def predict_single_batch(self, batch_input):
        self.eval()

        with torch.no_grad():
            model_state = self.get_initial_model_state(batch_input)
            pred_result = self.forward_beam_search(batch_input, model_state)
            predicted_token_ids = pred_result["predicted_token_ids"]
            predicted_log_probs = pred_result["predicted_log_probs"]

            # merged_source_word_list = batch_input["merged_source_word_list"]
            predicted_tokens = self.get_predicted_tokens(
                predicted_token_ids, batch_input["merged_source_word_list"])

        return predicted_tokens, predicted_log_probs

    def predict(self, raw_test_data):
        """
        args:
            raw_test_data
                raw_test_data should be a three-order list:
                [
                    [[instance-1 source-1 words], [instance-1 source-2 words]],
                    [[instance-2 source-1 words], [instance-2 source-2 words]],
                    ... ...
                    [[instance-n source-1 words], [instance-n source-2 words]],
                ]
        return:
            all_batch_predicted_tokens:
                a three-order list with shape (batch_size, beam_size, target_length)
            all_batch_predicted_probs :
                a three-order list with shape (batch_size, beam_size, target_length)
        """

        # Put the module in eval mode during prediction so dropout and
        # normalization layers stop updating their running statistics
        self.eval()

        all_batch_test_data = self.test_data_utils.get_batch_formatted_test_data(
            raw_test_data, device=self.device)
        with torch.no_grad():
            all_batch_predicted_tokens = []
            all_batch_predicted_probs = []

            for batch_input in all_batch_test_data:
                predicted_tokens, predicted_log_probs = self.predict_single_batch(
                    batch_input)
                all_batch_predicted_tokens += predicted_tokens.tolist()
                all_batch_predicted_probs += predicted_log_probs.softmax(
                    -1).tolist()

        return all_batch_predicted_tokens, all_batch_predicted_probs

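    # Hypothetical usage sketch of predict() (argument values are
    # illustrative, not from the repo):
    # >>> tokens, probs = model.predict(
    # ...     [[["red", "dress"], ["summer", "sale"]]])
    # >>> tokens[0][0]  # best beam for the first instance, as a word list
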
    def valid_single_batch(self, batch_input, need_pred_result):
        self.eval()
        with torch.no_grad():
            valid_loss = self.get_batch_loss(batch_input)

            if need_pred_result:
                predicted_tokens, predicted_log_probs = self.predict_single_batch(
                    batch_input)
            else:
                predicted_tokens, predicted_log_probs = None, None

        return valid_loss, predicted_tokens, predicted_log_probs

    def validation(self, all_batch_data, need_pred_result):
        all_batch_predicted_tokens = []
        all_batch_predicted_probs = []
        num_batches = len(all_batch_data)
        all_batch_loss = 0.0
        all_batch_bleu = 0.0

        batch_generator = tqdm(all_batch_data, ncols=100)

        if need_pred_result:
            for idx, batch in enumerate(batch_generator):
                batch_start_time = time.time()
                valid_loss, predicted_tokens, predicted_log_probs = self.valid_single_batch(
                    batch, need_pred_result)

                mean_loss = valid_loss["mean_loss"].detach().cpu().item()
                all_batch_predicted_tokens += predicted_tokens.tolist()
                all_batch_predicted_probs += predicted_log_probs.tolist()
                all_batch_loss += valid_loss["batch_loss"].detach().cpu().item()

                pred_corpus = [[
                    word_list[0:2]
                ] for word_list in predicted_tokens[:, 0, :].tolist()]
                bleu_score = corpus_bleu(pred_corpus,
                                         batch["target_word_list"],
                                         weights=[0.5, 0.5])
                all_batch_bleu += bleu_score
                batch_elapsed_time = round(time.time() - batch_start_time, 2)

                info = f"{color('[Valid]', 1)} Batch:{color(idx, 2)}  BLEU:{color(round(bleu_score, 5), 1)} " \
                       f"Loss:{color(round(mean_loss, 5), 1)} Time:{color(batch_elapsed_time, 2)}"
                batch_generator.set_description(desc=info, refresh=True)
        else:
            for idx, batch in enumerate(batch_generator):
                batch_start_time = time.time()
                valid_loss, predicted_tokens, predicted_log_probs = self.valid_single_batch(
                    batch, need_pred_result)
                mean_loss = valid_loss["mean_loss"].detach().cpu().item()
                all_batch_loss += valid_loss["batch_loss"].detach().cpu().item()

                bleu_score = "None"
                batch_elapsed_time = round(time.time() - batch_start_time, 2)

                info = f"{color('[Valid]', 1)} Batch:{color(idx, 2)}  BLEU:{color(bleu_score, 1)} " \
                       f"Loss:{color(round(mean_loss, 5), 1)} Time:{color(batch_elapsed_time, 2)}"
                batch_generator.set_description(desc=info, refresh=True)

        mean_bleu = all_batch_bleu / num_batches
        return all_batch_loss, mean_bleu, all_batch_predicted_tokens, all_batch_predicted_probs

    def __call__(self, raw_test_data):
        """
        Call the predict function.
        """
        return self.predict(raw_test_data)

    def __enter__(self, *args, **kwargs):
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        return
Example 10
def main(mode, config, use_bokeh=False):

    # Construct model
    logger.info('Building RNN encoder-decoder')
    encoder = BidirectionalEncoder(config['src_vocab_size'],
                                   config['enc_embed'], config['enc_nhids'])
    decoder = Decoder(config['trg_vocab_size'], config['dec_embed'],
                      config['dec_nhids'], config['enc_nhids'] * 2,
                      config['topical_embedding_dim'])
    topical_transformer = topicalq_transformer(config['topical_vocab_size'],
                                               config['topical_embedding_dim'],
                                               config['enc_nhids'],
                                               config['topical_word_num'],
                                               config['batch_size'])

    if mode == "train":

        # Create Theano variables
        logger.info('Creating theano variables')
        source_sentence = tensor.lmatrix('source')
        source_sentence_mask = tensor.matrix('source_mask')
        target_sentence = tensor.lmatrix('target')
        target_sentence_mask = tensor.matrix('target_mask')
        sampling_input = tensor.lmatrix('input')
        source_topical_word = tensor.lmatrix('source_topical')
        source_topical_mask = tensor.matrix('source_topical_mask')

        # Get training and development set streams
        tr_stream = get_tr_stream_with_topicalq(**config)
        dev_stream = get_dev_stream_with_topicalq(**config)
        topic_embedding = topical_transformer.apply(source_topical_word)
        # Get cost of the model
        representation = encoder.apply(source_sentence, source_sentence_mask)
        tw_representation = topical_transformer.look_up.apply(
            source_topical_word.T)
        content_embedding = representation[0, :,
                                           (representation.shape[2] / 2):]

        cost = decoder.cost(representation, source_sentence_mask,
                            tw_representation, source_topical_mask,
                            target_sentence, target_sentence_mask,
                            topic_embedding, content_embedding)

        logger.info('Creating computational graph')
        cg = ComputationGraph(cost)

        # Initialize model
        logger.info('Initializing model')
        encoder.weights_init = decoder.weights_init = IsotropicGaussian(
            config['weight_scale'])
        encoder.biases_init = decoder.biases_init = Constant(0)
        encoder.push_initialization_config()
        decoder.push_initialization_config()
        encoder.bidir.prototype.weights_init = Orthogonal()
        decoder.transition.weights_init = Orthogonal()
        encoder.initialize()
        decoder.initialize()
        topical_transformer.weights_init = IsotropicGaussian(
            config['weight_scale'])
        topical_transformer.biases_init = Constant(0)
        topical_transformer.push_allocation_config()
        # TODO: verify whether this initialization is actually required
        topical_transformer.look_up.weights_init = Orthogonal()
        topical_transformer.transformer.weights_init = Orthogonal()
        topical_transformer.initialize()
        word_topical_embedding = cPickle.load(
            open(config['topical_embeddings'], 'rb'))
        np_word_topical_embedding = numpy.array(word_topical_embedding,
                                                dtype='float32')
        topical_transformer.look_up.W.set_value(np_word_topical_embedding)
        topical_transformer.look_up.W.tag.role = []

        # apply dropout for regularization
        if config['dropout'] < 1.0:
            # dropout is applied to the output of maxout in ghog
            logger.info('Applying dropout')
            dropout_inputs = [
                x for x in cg.intermediary_variables
                if x.name == 'maxout_apply_output'
            ]
            cg = apply_dropout(cg, dropout_inputs, config['dropout'])

        # Apply weight noise for regularization
        if config['weight_noise_ff'] > 0.0:
            logger.info('Applying weight noise to ff layers')
            enc_params = Selector(encoder.lookup).get_params().values()
            enc_params += Selector(encoder.fwd_fork).get_params().values()
            enc_params += Selector(encoder.back_fork).get_params().values()
            dec_params = Selector(
                decoder.sequence_generator.readout).get_params().values()
            dec_params += Selector(
                decoder.sequence_generator.fork).get_params().values()
            dec_params += Selector(decoder.state_init).get_params().values()
            cg = apply_noise(cg, enc_params + dec_params,
                             config['weight_noise_ff'])

        # Print shapes
        shapes = [param.get_value().shape for param in cg.parameters]
        logger.info("Parameter shapes: ")
        for shape, count in Counter(shapes).most_common():
            logger.info('    {:15}: {}'.format(shape, count))
        logger.info("Total number of parameters: {}".format(len(shapes)))

        # Print parameter names
        enc_dec_param_dict = merge(
            Selector(encoder).get_parameters(),
            Selector(decoder).get_parameters())
        logger.info("Parameter names: ")
        for name, value in enc_dec_param_dict.items():
            logger.info('    {:15}: {}'.format(value.get_value().shape, name))
        logger.info("Total number of parameters: {}".format(
            len(enc_dec_param_dict)))

        # Set up training model
        logger.info("Building model")
        training_model = Model(cost)

        # Set extensions
        logger.info("Initializing extensions")
        extensions = [
            FinishAfter(after_n_batches=config['finish_after']),
            TrainingDataMonitoring([cost], after_batch=True),
            Printing(after_batch=True),
            CheckpointNMT(config['saveto'],
                          every_n_batches=config['save_freq'])
        ]
        '''
        # Set up beam search and sampling computation graphs if necessary
        if config['hook_samples'] >= 1 or config['bleu_script'] is not None:
            logger.info("Building sampling model")
            sampling_representation = encoder.apply(
                sampling_input, tensor.ones(sampling_input.shape))
            generated = decoder.generate(
                sampling_input, sampling_representation)
            search_model = Model(generated)
            _, samples = VariableFilter(
                bricks=[decoder.sequence_generator], name="outputs")(
                    ComputationGraph(generated[1]))

        # Add sampling
        if config['hook_samples'] >= 1:
            logger.info("Building sampler")
            extensions.append(
                Sampler(model=search_model, data_stream=tr_stream,
                        hook_samples=config['hook_samples'],
                        every_n_batches=config['sampling_freq'],
                        src_vocab_size=config['src_vocab_size']))

        # Add early stopping based on bleu
        if config['bleu_script'] is not None:
            logger.info("Building bleu validator")
            extensions.append(
                BleuValidator(sampling_input, samples=samples, config=config,
                              model=search_model, data_stream=dev_stream,
                              normalize=config['normalized_bleu'],
                              every_n_batches=config['bleu_val_freq']))
        '''

        # Reload model if necessary
        if config['reload']:
            extensions.append(LoadNMT(config['saveto']))

        # Plot cost in bokeh if necessary
        if use_bokeh and BOKEH_AVAILABLE:
            extensions.append(
                Plot('Cs-En',
                     channels=[['decoder_cost_cost']],
                     after_batch=True))

        # Set up training algorithm
        logger.info("Initializing training algorithm")
        algorithm = GradientDescent(cost=cost,
                                    parameters=cg.parameters,
                                    on_unused_sources='warn',
                                    step_rule=CompositeRule([
                                        StepClipping(config['step_clipping']),
                                        eval(config['step_rule'])()
                                    ]))

        # Initialize main loop
        logger.info("Initializing main loop")
        main_loop = MainLoop(model=training_model,
                             algorithm=algorithm,
                             data_stream=tr_stream,
                             extensions=extensions)

        # Train!
        main_loop.run()

    elif mode == 'translate':

        # Create Theano variables
        logger.info('Creating theano variables')
        source_sentence = tensor.lmatrix('source')
        source_topical_word = tensor.lmatrix('source_topical')

        # Get test set stream
        test_stream = get_dev_stream_with_topicalq(
            config['test_set'], config['src_vocab'], config['src_vocab_size'],
            config['topical_test_set'], config['topical_vocab'],
            config['topical_vocab_size'], config['unk_id'])
        ftrans = open(config['test_set'] + '.trans.out', 'w')

        # Helper utilities
        sutils = SamplingBase()
        unk_idx = config['unk_id']
        src_eos_idx = config['src_vocab_size'] - 1
        trg_eos_idx = config['trg_vocab_size'] - 1

        # Get beam search
        logger.info("Building sampling model")
        topic_embedding = topical_transformer.apply(source_topical_word)
        representation = encoder.apply(source_sentence,
                                       tensor.ones(source_sentence.shape))
        tw_representation = topical_transformer.look_up.apply(
            source_topical_word.T)
        content_embedding = representation[0, :,
                                           (representation.shape[2] / 2):]
        generated = decoder.generate(source_sentence,
                                     representation,
                                     tw_representation,
                                     topical_embedding=topic_embedding,
                                     content_embedding=content_embedding)

        _, samples = VariableFilter(
            bricks=[decoder.sequence_generator], name="outputs")(
                ComputationGraph(generated[1]))  # generated[1] is next_outputs
        beam_search = BeamSearch(samples=samples)

        logger.info("Loading the model..")
        model = Model(generated)
        loader = LoadNMT(config['saveto'])
        loader.set_model_parameters(model, loader.load_parameters())

        # Get target vocabulary
        trg_vocab = _ensure_special_tokens(pickle.load(
            open(config['trg_vocab'], 'rb')),
                                           bos_idx=0,
                                           eos_idx=trg_eos_idx,
                                           unk_idx=unk_idx)
        trg_ivocab = {v: k for k, v in trg_vocab.items()}

        logger.info("Started translation: ")
        total_cost = 0.0

        for i, line in enumerate(test_stream.get_epoch_iterator()):

            seq = sutils._oov_to_unk(line[0], config['src_vocab_size'],
                                     unk_idx)
            seq2 = line[1]
            input_ = numpy.tile(seq, (config['beam_size'], 1))
            input_topical = numpy.tile(seq2, (config['beam_size'], 1))

            # draw sample, checking to ensure we don't get an empty string back
            trans, costs = \
                beam_search.search(
                    input_values={source_sentence: input_, source_topical_word: input_topical},
                    max_length=10*len(seq), eol_symbol=src_eos_idx,
                    ignore_first_eol=True)
            '''
            # normalize costs according to the sequence lengths
            if config['normalized_bleu']:
                lengths = numpy.array([len(s) for s in trans])
                costs = costs / lengths
            '''
            #best = numpy.argsort(costs)[0]
            best = numpy.argsort(costs)[0:config['beam_size']]
            for b in best:
                try:
                    total_cost += costs[b]
                    trans_out = trans[b]

                    # convert idx to words
                    trans_out = sutils._idx_to_word(trans_out, trg_ivocab)

                except ValueError:
                    logger.info(
                        "Can NOT find a translation for line: {}".format(i +
                                                                         1))
                    trans_out = '<UNK>'

                print(trans_out, file=ftrans)

            if i != 0 and i % 100 == 0:
                logger.info("Translated {} lines of test set...".format(i))

        logger.info("Total cost of the test: {}".format(total_cost))
        ftrans.close()
    elif mode == 'rerank':
        # Create Theano variables
        ftrans = open(config['val_set'] + '.scores.out', 'w')
        logger.info('Creating theano variables')
        source_sentence = tensor.lmatrix('source')
        source_sentence_mask = tensor.matrix('source_mask')
        target_sentence = tensor.lmatrix('target')
        target_sentence_mask = tensor.matrix('target_mask')

        config['src_data'] = config['val_set']
        config['trg_data'] = config['val_set_grndtruth']
        config['batch_size'] = 1
        config['sort_k_batches'] = 1
        test_stream = get_tr_stream_unsorted(**config)
        logger.info("Building sampling model")
        representations = encoder.apply(source_sentence, source_sentence_mask)
        costs = decoder.cost(representations, source_sentence_mask,
                             target_sentence, target_sentence_mask)
        logger.info("Loading the model..")
        model = Model(costs)
        loader = LoadNMT(config['saveto'])
        loader.set_model_parameters(model, loader.load_parameters())

        costs_computer = function([
            source_sentence, source_sentence_mask, target_sentence,
            target_sentence_mask
        ], costs)
        iterator = test_stream.get_epoch_iterator()

        scores = []
        for i, (src, src_mask, trg, trg_mask) in enumerate(iterator):
            costs = costs_computer(*[src, src_mask, trg, trg_mask])
            cost = costs.sum()
            print(i, cost)
            scores.append(cost)
            ftrans.write(str(cost) + "\n")
        ftrans.close()
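
A note on the two decoding branches above: the translate branch keeps its length normalization commented out, while the rerank branch emits one raw model score per sentence pair. A minimal, self-contained sketch of normalize-then-pick reranking (the function name and toy data are illustrative, not from the example):

import numpy

def rerank(candidates, costs, normalize=True):
    # candidates: token-id sequences from beam search
    # costs: total model cost per candidate (lower is better)
    costs = numpy.asarray(costs, dtype='float64')
    if normalize:
        # per-token cost, as in the commented-out block above
        lengths = numpy.array([len(s) for s in candidates])
        costs = costs / lengths
    best = numpy.argsort(costs)[0]
    return candidates[best], costs[best]

# prints ([5, 7, 9, 2], 1.8): the lowest per-token cost wins
print(rerank([[5, 7, 2], [5, 7, 9, 2], [5, 2]], [6.0, 7.2, 4.5]))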
Example n. 11
    src = T.lmatrix()
    src_mask = T.matrix()
    trg = T.lmatrix()
    trg_mask = T.matrix()

    rng = numpy.random.RandomState(1234)

    enc_dec = EncoderDecoder(rng, **configuration)
    enc_dec.build_trainer(src, src_mask, trg, trg_mask)
    enc_dec.build_sampler()

    if configuration['reload']:
        enc_dec.load()

    sample_search = BeamSearch(enc_dec=enc_dec,
                               configuration=configuration,
                               beam_size=1,
                               maxlen=configuration['seq_len_src'], stochastic=True)
    valid_search = BeamSearch(enc_dec=enc_dec, 
                              configuration=configuration,
                              beam_size=configuration['beam_size'],
                              maxlen=3*configuration['seq_len_src'], stochastic=False)

    sampler = Sampler(sample_search, **configuration)
    bleuvalidator = BleuValidator(valid_search, **configuration)

    # train function
    train_fn = enc_dec.train_fn
    if configuration.get('with_layernorm', False):
        update_fn = enc_dec.update_fn

    # train data
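
This snippet builds two searchers over the same model: beam_size=1 with stochastic=True for the Sampler, and a wide deterministic beam for BLEU validation. A toy numpy sketch of that difference, using a made-up next-token distribution rather than the EncoderDecoder API:

import numpy

rng = numpy.random.RandomState(1234)
probs = numpy.array([0.5, 0.3, 0.2])  # toy distribution over 3 tokens

# stochastic=True: draw from the distribution (runs differ)
sampled = rng.choice(len(probs), p=probs)
# stochastic=False with beam_size=1: always the argmax (greedy)
greedy = int(numpy.argmax(probs))
print('sampled:', sampled, 'greedy:', greedy)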
Example n. 12
        configuration.update(eval(open(args.state).read()))
    logger.info("\nModel options:\n{}".format(pprint.pformat(configuration)))

    enc_dec = EncoderDecoder(**configuration)
    enc_dec.build_sampler()

    if args.model:
        enc_dec.load(path=args.model)
    else:
        enc_dec.load(path=configuration['saveto_best'])

    beam_size = configuration['beam_size']
    if args.beam:
        beam_size = args.beam

    test_search = BeamSearch(enc_dec=enc_dec,
                             configuration=configuration,
                             beam_size=beam_size,
                             maxlen=3 * configuration['seq_len_src'], stochastic=False)
    bleuvalidator = BleuValidator(search_model=test_search,
                                  test_src=args.source,
                                  test_ref=args.target,
                                  **configuration)

    # test data
    ts = get_devtest_stream(data_type='test', input_file=args.source, **configuration)
    test_bleu = bleuvalidator.apply(ts, args.trans, True)

    logger.info('test bleu %.4f' % test_bleu)
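
bleuvalidator.apply hides the score extraction; Example n. 15 below recovers it from the scoring script's output with a regex. A small sketch of that parsing step, assuming a multi-bleu-style output line (the exact format and scale depend on the configured bleu_script):

import re

line = 'BLEU = 27.53, 60.1/33.2/20.4/13.1 (BP=0.98)'  # hypothetical output
match = re.match(r'BLEU = [-.0-9]+', line)
assert match is not None
bleu_score = float(match.group()[6:])  # drop the 'BLEU =' prefix
print(bleu_score)  # 27.53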

Example n. 13
def main(mode, config, use_bokeh=False):

    # Construct model
    logger.info('Building RNN encoder-decoder')
    encoder = BidirectionalEncoder(
        config['src_vocab_size'], config['enc_embed'], config['enc_nhids'],name='word_encoder')
    decoder = Decoder(vocab_size=config['trg_vocab_size'],
                      embedding_dim=config['dec_embed'],
                      state_dim=config['dec_nhids'],
                      representation_dim=config['enc_nhids'] * 2,
                      match_function=config['match_function'],
                      use_doubly_stochastic=config['use_doubly_stochastic'],
                      lambda_ds=config['lambda_ds'],
                      use_local_attention=config['use_local_attention'],
                      window_size=config['window_size'],
                      use_step_decay_cost=config['use_step_decay_cost'],
                      use_concentration_cost=config['use_concentration_cost'],
                      lambda_ct=config['lambda_ct'],
                      use_stablilizer=config['use_stablilizer'],
                      lambda_st=config['lambda_st'])
    # here attended dim (representation_dim) of decoder is 2*enc_nhinds
    # because the context given by the encoder is a bidirectional context

    if mode == "train":

        # Create Theano variables
        logger.info('Creating theano variables')
        context_sentences = []
        context_sentence_masks = []
        for i in range(config['ctx_num']):
            context_sentences.append(tensor.lmatrix('context_' + str(i)))
            context_sentence_masks.append(
                tensor.matrix('context_' + str(i) + '_mask'))
        source_sentence = tensor.lmatrix('source')
        source_sentence_mask = tensor.matrix('source_mask')
        target_sentence = tensor.lmatrix('target')
        target_sentence_mask = tensor.matrix('target_mask')
        sampling_input = tensor.lmatrix('input')
        dev_source = tensor.lmatrix('dev_source')
        dev_target = tensor.lmatrix('dev_target')

        # Get training and development set streams
        tr_stream = get_tr_stream_withContext(**config)
        dev_stream = get_dev_stream_with_grdTruth(**config)

        # Get cost of the model
        sentence_representations_list = encoder.apply(
            source_sentence, source_sentence_mask)
        sentence_representations_list = \
            sentence_representations_list.dimshuffle(['x', 0, 1, 2])
        sentence_masks_list = source_sentence_mask.T.dimshuffle(['x', 0, 1])
        for i in range(config['ctx_num']):
            tmp_rep = encoder.apply(context_sentences[i],
                                    context_sentence_masks[i])
            tmp_rep = tmp_rep.dimshuffle(['x', 0, 1, 2])
            sentence_representations_list = tensor.concatenate(
                [sentence_representations_list, tmp_rep], axis=0)
            sentence_masks_list = tensor.concatenate(
                [sentence_masks_list,
                 context_sentence_masks[i].T.dimshuffle(['x', 0, 1])],
                axis=0)


        cost = decoder.cost(sentence_representations_list,
                            sentence_masks_list,
                            target_sentence,
                            target_sentence_mask)

        logger.info('Creating computational graph')
        perplexity = tensor.exp(cost)
        perplexity.name = 'perplexity'
        costs_computer = function(context_sentences+context_sentence_masks+[target_sentence,
                                   target_sentence_mask,
                                   source_sentence,
                                   source_sentence_mask], (perplexity))
        cg = ComputationGraph(cost)

        # Initialize model
        logger.info('Initializing model')
        encoder.weights_init = decoder.weights_init = IsotropicGaussian(
            config['weight_scale'])
        encoder.biases_init = decoder.biases_init = Constant(0)
        encoder.push_initialization_config()
        decoder.push_initialization_config()
        encoder.bidir.prototype.weights_init = Orthogonal()
        decoder.transition.weights_init = Orthogonal()
        encoder.initialize()
        decoder.initialize()

        # apply dropout for regularization
        if config['dropout'] < 1.0:
            # dropout is applied to the output of maxout in ghog
            logger.info('Applying dropout')
            dropout_inputs = [x for x in cg.intermediary_variables
                              if x.name == 'maxout_apply_output']
            cg = apply_dropout(cg, dropout_inputs, config['dropout'])

        # Apply weight noise for regularization
        if config['weight_noise_ff'] > 0.0:
            logger.info('Applying weight noise to ff layers')
            enc_params = Selector(encoder.lookup).get_params().values()
            enc_params += Selector(encoder.fwd_fork).get_params().values()
            enc_params += Selector(encoder.back_fork).get_params().values()
            dec_params = Selector(
                decoder.sequence_generator.readout).get_params().values()
            dec_params += Selector(
                decoder.sequence_generator.fork).get_params().values()
            dec_params += Selector(decoder.state_init).get_params().values()
            cg = apply_noise(
                cg, enc_params+dec_params, config['weight_noise_ff'])


        # Print shapes
        shapes = [param.get_value().shape for param in cg.parameters]
        logger.info("Parameter shapes: ")
        for shape, count in Counter(shapes).most_common():
            logger.info('    {:15}: {}'.format(shape, count))
        logger.info("Total number of parameters: {}".format(len(shapes)))

        # Print parameter names
        enc_dec_param_dict = merge(Selector(encoder).get_parameters(),
                                   Selector(decoder).get_parameters())
        logger.info("Parameter names: ")
        for name, value in enc_dec_param_dict.items():
            logger.info('    {:15}: {}'.format(value.get_value().shape, name))
        logger.info("Total number of parameters: {}"
                    .format(len(enc_dec_param_dict)))


        # Set up training model
        logger.info("Building model")
        training_model = Model(cost)

        # Set extensions
        logger.info("Initializing extensions")
        extensions = [
            FinishAfter(after_n_batches=config['finish_after']),
            TrainingDataMonitoring([perplexity], after_batch=True),
            CheckpointNMT(config['saveto'],
                          config['model_name'],
                          every_n_batches=config['save_freq'])
        ]

        # Set up beam search and sampling computation graphs if necessary
        if config['hook_samples'] >= 1 or config['bleu_script'] is not None:
            logger.info("Building sampling model")
            sampling_representation = encoder.apply(
                sampling_input, tensor.ones(sampling_input.shape))
            generated = decoder.generate(
                sampling_input, sampling_representation)
            search_model = Model(generated)
            _, samples = VariableFilter(
                bricks=[decoder.sequence_generator], name="outputs")(
                    ComputationGraph(generated[1]))

        # Add sampling
        if config['hook_samples'] >= 1:
            logger.info("Building sampler")
            extensions.append(
                Sampler(model=search_model, data_stream=tr_stream,
                        model_name=config['model_name'],
                        hook_samples=config['hook_samples'],
                        every_n_batches=config['sampling_freq'],
                        src_vocab_size=config['src_vocab_size']))

        # Add early stopping based on bleu
        if False:
            logger.info("Building bleu validator")
            extensions.append(
                BleuValidator(sampling_input, samples=samples, config=config,
                              model=search_model, data_stream=dev_stream,
                              normalize=config['normalized_bleu'],
                              every_n_batches=config['bleu_val_freq'],
                              n_best=3,
                              track_n_models=6))

        logger.info("Building perplexity validator")
        extensions.append(
                pplValidation(dev_source, dev_target, config=config,
                              model=costs_computer, data_stream=dev_stream,
                              model_name=config['model_name'],
                              every_n_batches=config['sampling_freq']))


        # Plot cost in bokeh if necessary
        if use_bokeh and BOKEH_AVAILABLE:
            extensions.append(
                Plot('Cs-En', channels=[['decoder_cost_cost']],
                     after_batch=True))

        # Reload model if necessary
        if config['reload']:
            extensions.append(LoadNMT(config['saveto']))

        initial_learning_rate = config['initial_learning_rate']
        log_path = os.path.join(config['saveto'], 'log')
        if config['reload'] and os.path.exists(log_path):
            with open(log_path, 'rb') as source:
                log = cPickle.load(source)
                last = max(log.keys()) - 1
                if 'learning_rate' in log[last]:
                    initial_learning_rate = log[last]['learning_rate']

        # Set up training algorithm
        logger.info("Initializing training algorithm")
        algorithm = GradientDescent(
            cost=cost, parameters=cg.parameters,
            step_rule=CompositeRule([Scale(initial_learning_rate),
                                     StepClipping(config['step_clipping']),
                                     eval(config['step_rule'])()]))

        _learning_rate = algorithm.step_rule.components[0].learning_rate
        if config['learning_rate_decay']:
            extensions.append(
                LearningRateHalver(record_name='validation_cost',
                                   comparator=lambda x, y: x > y,
                                   learning_rate=_learning_rate,
                                   patience_default=3))
        else:
            extensions.append(OldModelRemover(saveto=config['saveto']))

        if config['learning_rate_grow']:
            extensions.append(
                LearningRateDoubler(record_name='validation_cost',
                                    comparator=lambda x, y: x < y,
                                    learning_rate=_learning_rate,
                                    patience_default=3))

        extensions.append(
            SimplePrinting(config['model_name'], after_batch=True))

        # Initialize main loop
        logger.info("Initializing main loop")
        main_loop = MainLoop(
            model=training_model,
            algorithm=algorithm,
            data_stream=tr_stream,
            extensions=extensions
        )

        # Train!
        main_loop.run()

    elif mode == 'ppl':
        # Create Theano variables
        logger.info('Creating theano variables')
        context_sentences = []
        context_sentence_masks = []
        for i in range(config['ctx_num']):
            context_sentences.append(tensor.lmatrix('context_' + str(i)))
            context_sentence_masks.append(
                tensor.matrix('context_' + str(i) + '_mask'))
        source_sentence = tensor.lmatrix('source')
        source_sentence_mask = tensor.matrix('source_mask')
        target_sentence = tensor.lmatrix('target')
        target_sentence_mask = tensor.matrix('target_mask')

        # Get training and development set streams
        #tr_stream = get_tr_stream_withContext(**config)
        dev_stream = get_dev_stream_withContext_grdTruth(**config)

        # Get cost of the model
        sentence_representations_list = encoder.apply(
            source_sentence, source_sentence_mask)
        sentence_representations_list = \
            sentence_representations_list.dimshuffle(['x', 0, 1, 2])
        sentence_masks_list = source_sentence_mask.T.dimshuffle(['x', 0, 1])
        for i in range(config['ctx_num']):
            tmp_rep = encoder.apply(context_sentences[i],
                                    context_sentence_masks[i])
            tmp_rep = tmp_rep.dimshuffle(['x', 0, 1, 2])
            sentence_representations_list = tensor.concatenate(
                [sentence_representations_list, tmp_rep], axis=0)
            sentence_masks_list = tensor.concatenate(
                [sentence_masks_list,
                 context_sentence_masks[i].T.dimshuffle(['x', 0, 1])],
                axis=0)


        cost = decoder.cost(sentence_representations_list,
                            sentence_masks_list,
                            target_sentence,
                            target_sentence_mask)

        logger.info('Creating computational graph')
        costs_computer = function(context_sentences+context_sentence_masks+[target_sentence,
                                   target_sentence_mask,
                                   source_sentence,
                                   source_sentence_mask], (cost))


        logger.info("Loading the model..")
        model = Model(cost)
        #loader = LoadNMT(config['saveto'])
        loader = LoadNMT(config['validation_load'])
        loader.set_model_parameters(model, loader.load_parameters_default())
        logger.info("Started Validation: ")

        ts = dev_stream.get_epoch_iterator()
        total_cost = 0.0
        total_tokens = 0.0
        # pbar = ProgressBar(max_value=len(ts)).start()  # modified
        pbar = ProgressBar(max_value=10000).start()
        for i, (ctx_0, ctx_0_mask, ctx_1, ctx_1_mask, ctx_2, ctx_2_mask,
                src, src_mask, trg, trg_mask) in enumerate(ts):
            costs = costs_computer(*[ctx_0, ctx_1, ctx_2,
                                     ctx_0_mask, ctx_1_mask, ctx_2_mask,
                                     trg, trg_mask, src, src_mask])
            cost = costs.sum()
            total_cost += cost
            total_tokens += trg_mask.sum()
            pbar.update(i + 1)
        total_cost /= total_tokens
        pbar.finish()
        # dev_stream.reset()

        # run afterprocess
        # self.ap.main()
        total_cost = 2 ** total_cost
        print("Validation perplexity: " + str(total_cost))
    elif mode == 'translate':

        logger.info('Creating theano variables')
        context_sentences = []
        context_sentence_masks = []
        for i in range(config['ctx_num']):
            context_sentences.append(tensor.lmatrix('context_' + str(i)))
            context_sentence_masks.append(
                tensor.matrix('context_' + str(i) + '_mask'))
        source_sentence = tensor.lmatrix('source')
        source_sentence_mask = tensor.matrix('source_mask')

        sutils = SamplingBase()
        unk_idx = config['unk_id']
        src_eos_idx = config['src_vocab_size'] - 1
        trg_eos_idx = config['trg_vocab_size'] - 1
        trg_vocab = _ensure_special_tokens(
            cPickle.load(open(config['trg_vocab'], 'rb')), bos_idx=0,
            eos_idx=trg_eos_idx, unk_idx=unk_idx)
        trg_ivocab = {v: k for k, v in trg_vocab.items()}
        config['batch_size'] = 1

        sentence_representations_list = encoder.apply(
            source_sentence, source_sentence_mask)
        sentence_representations_list = \
            sentence_representations_list.dimshuffle(['x', 0, 1, 2])
        sentence_masks_list = source_sentence_mask.T.dimshuffle(['x', 0, 1])
        for i in range(config['ctx_num']):
            tmp_rep = encoder.apply(context_sentences[i],
                                    context_sentence_masks[i])
            tmp_rep = tmp_rep.dimshuffle(['x', 0, 1, 2])
            sentence_representations_list = tensor.concatenate(
                [sentence_representations_list, tmp_rep], axis=0)
            sentence_masks_list = tensor.concatenate(
                [sentence_masks_list,
                 context_sentence_masks[i].T.dimshuffle(['x', 0, 1])],
                axis=0)
        generated = decoder.generate(sentence_representations_list,
                                     sentence_masks_list)
        _, samples = VariableFilter(
            bricks=[decoder.sequence_generator], name="outputs")(
                ComputationGraph(generated[1]))  # generated[1] is next_outputs
        beam_search = BeamSearch(samples=samples)

        logger.info("Loading the model..")
        model = Model(generated)
        #loader = LoadNMT(config['saveto'])
        loader = LoadNMT(config['validation_load'])
        loader.set_model_parameters(model, loader.load_parameters_default())

        logger.info("Started translation: ")
        test_stream = get_dev_stream_withContext(**config)
        ts = test_stream.get_epoch_iterator()
        rts = open(config['val_set_source']).readlines()
        ftrans_original = open(config['val_output_orig'], 'w')
        saved_weights = []
        total_cost = 0.0

        pbar = ProgressBar(max_value=len(rts)).start()
        for i, (line, line_raw) in enumerate(zip(ts, rts)):
            trans_in = line_raw[3].split()
            seqs = []
            input_ = []
            input_mask = []
            for j in range(config['ctx_num'] + 1):
                seqs.append(sutils._oov_to_unk(
                    line[2 * j][0], config['src_vocab_size'], unk_idx))
                input_mask.append(numpy.tile(line[2 * j + 1][0],
                                             (config['beam_size'], 1)))
                input_.append(numpy.tile(seqs[j], (config['beam_size'], 1)))
            #v=costs_computer(input_[0]);
            # draw sample, checking to ensure we don't get an empty string back
            trans, costs, attendeds, weights = \
                beam_search.search(
                    input_values={
                        source_sentence: input_[3],
                        source_sentence_mask: input_mask[3],
                        context_sentences[0]: input_[0],
                        context_sentence_masks[0]: input_mask[0],
                        context_sentences[1]: input_[1],
                        context_sentence_masks[1]: input_mask[1],
                        context_sentences[2]: input_[2],
                        context_sentence_masks[2]: input_mask[2],
                    },
                    max_length=3 * len(seqs[2]), eol_symbol=trg_eos_idx,
                    ignore_first_eol=True)

            # normalize costs according to the sequence lengths
            if config['normalized_bleu']:
                lengths = numpy.array([len(s) for s in trans])
                costs = costs / lengths

            b = numpy.argsort(costs)[0]
            #best=numpy.argsort(costs)[0:config['beam_size']];
            #for b in best:
            try:
                total_cost += costs[b]
                trans_out = trans[b]
                totalLen = 4 * len(line[0][0])
                # weight = weights[b][:, :totalLen]
                weight = weights
                trans_out = sutils._idx_to_word(trans_out, trg_ivocab)
            except ValueError:
                logger.info(
                    "Can NOT find a translation for line: {}".format(i+1))
                trans_out = '<UNK>'
            saved_weights.append(weight)
            print(' '.join(trans_out), file=ftrans_original)
            pbar.update(i + 1)

        pbar.finish()
        logger.info("Total cost of the test: {}".format(total_cost))
        cPickle.dump(saved_weights, open(config['attention_weights'], 'wb'))
        ftrans_original.close()
        ap = afterprocesser(config)
        ap.main()
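
The 'ppl' branch above reports 2 ** (total_cost / total_tokens), i.e. a per-token perplexity, which assumes decoder.cost returns a base-2 log loss. The same computation in isolation, on illustrative numbers:

import numpy

batch_costs = [120.0, 95.5, 130.2]  # summed cost per batch (illustrative)
batch_tokens = [40.0, 31.0, 45.0]   # target tokens per batch (illustrative)
total_cost = float(numpy.sum(batch_costs))
total_tokens = float(numpy.sum(batch_tokens))
perplexity = 2 ** (total_cost / total_tokens)
print('perplexity: %.2f' % perplexity)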
Example n. 14
def main(configuration, is_chief=False):

    l1_reg_weight = configuration['l1_reg_weight']
    l2_reg_weight = configuration['l2_reg_weight']
    #  time_steps*nb_samples
    src = K.placeholder(shape=(None, None), dtype='int32')
    src_mask = K.placeholder(shape=(None, None))
    trg = K.placeholder(shape=(None, None), dtype='int32')
    trg_mask = K.placeholder(shape=(None, None))

    # for fast training of new parameters
    ite = K.placeholder(ndim=0)

    enc_dec = EncoderDecoder(**configuration)

    softmax_output_num_sampled = configuration['softmax_output_num_sampled']

    enc_dec.build_trainer(
        src,
        src_mask,
        trg,
        trg_mask,
        ite,
        l1_reg_weight=l1_reg_weight,
        l2_reg_weight=l2_reg_weight,
        softmax_output_num_sampled=softmax_output_num_sampled)

    enc_dec.build_sampler()

    # Chief is responsible for initializing and loading model states

    if is_chief:
        init_op = tf.initialize_all_variables()
        init_fn = K.function(inputs=[], outputs=[init_op])
        init_fn([])

        if configuration['reload']:
            enc_dec.load()

    sample_search = BeamSearch(enc_dec=enc_dec,
                               configuration=configuration,
                               beam_size=1,
                               maxlen=configuration['seq_len_src'],
                               stochastic=True)

    valid_search = BeamSearch(enc_dec=enc_dec,
                              configuration=configuration,
                              beam_size=configuration['beam_size'],
                              maxlen=3 * configuration['seq_len_src'],
                              stochastic=False)

    sampler = Sampler(sample_search, **configuration)
    bleuvalidator = BleuValidator(valid_search, **configuration)

    # train function
    train_fn = enc_dec.train_fn

    if configuration['with_reconstruction'] and configuration[
            'with_fast_training']:
        fast_train_fn = enc_dec.fast_train_fn

    # train data
    ds = DStream(**configuration)

    # valid data
    vs = get_devtest_stream(data_type='valid',
                            input_file=None,
                            **configuration)

    iters = args.start
    valid_bleu_best = -1
    epoch_best = -1
    iters_best = -1
    max_epochs = configuration['finish_after']

    # TODO: use global iter and only the chief can save the model
    for epoch in range(max_epochs):
        for x, x_mask, y, y_mask in ds.get_iterator():
            last_time = time.time()
            if configuration['with_reconstruction'] and configuration[
                    'with_fast_training'] and iters < configuration[
                        'fast_training_iterations']:
                if configuration['fix_base_parameters'] and not configuration[
                        'with_tied_weights']:
                    tc = fast_train_fn([x.T, x_mask.T, y.T, y_mask.T])
                else:
                    tc = fast_train_fn([x.T, x_mask.T, y.T, y_mask.T, iters])
            else:
                tc = train_fn([x.T, x_mask.T, y.T, y_mask.T])
            cur_time = time.time()
            iters += 1
            logger.info(
                'epoch %d \t updates %d train cost %.4f use time %.4f' %
                (epoch, iters, tc[0], cur_time - last_time))

            if iters % configuration['save_freq'] == 0:
                enc_dec.save()

            if iters % configuration['sample_freq'] == 0:
                sampler.apply(x, y)

            if iters < configuration['val_burn_in']:
                continue

            if (iters <= configuration['val_burn_in_fine'] and iters % configuration['valid_freq'] == 0) \
               or (iters > configuration['val_burn_in_fine'] and iters % configuration['valid_freq_fine'] == 0):
                valid_bleu = bleuvalidator.apply(vs,
                                                 configuration['valid_out'])
                os.system('mkdir -p results/%d' % iters)
                os.system('mv %s* %s results/%d' %
                          (configuration['valid_out'], configuration['saveto'],
                           iters))
                logger.info(
                    'valid_test \t epoch %d \t updates %d valid_bleu %.4f' %
                    (epoch, iters, valid_bleu))
                if valid_bleu > valid_bleu_best:
                    valid_bleu_best = valid_bleu
                    epoch_best = epoch
                    iters_best = iters
                    enc_dec.save(path=configuration['saveto_best'])

    logger.info('final result: epoch %d \t updates %d valid_bleu_best %.4f' %
                (epoch_best, iters_best, valid_bleu_best))
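
The loop above validates on two schedules: every valid_freq updates until val_burn_in_fine, then every valid_freq_fine updates, and never before val_burn_in. Factored into a standalone predicate for clarity (a sketch; the keys mirror the configuration names):

def should_validate(iters, cfg):
    # replicates the scheduling condition in the training loop above
    if iters < cfg['val_burn_in']:
        return False
    if iters <= cfg['val_burn_in_fine']:
        return iters % cfg['valid_freq'] == 0
    return iters % cfg['valid_freq_fine'] == 0

cfg = {'val_burn_in': 1000, 'val_burn_in_fine': 5000,
       'valid_freq': 500, 'valid_freq_fine': 100}
print([i for i in range(0, 7000, 100) if should_validate(i, cfg)][:8])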
Example n. 15
class BleuValidator(SimpleExtension, SamplingBase):

    def __init__(self, source_sentence, samples, model, data_stream,
                 config, n_best=1, track_n_models=1, trg_ivocab=None,
                 patience=10, normalize=True, **kwargs):
        super(BleuValidator, self).__init__(**kwargs)
        self.source_sentence = source_sentence
        self.samples = samples
        self.model = model
        self.data_stream = data_stream
        self.config = config
        self.n_best = n_best
        self.track_n_models = track_n_models
        self.normalize = normalize
        self.patience = patience

        # Helpers
        self.vocab = data_stream.dataset.dictionary
        self.trg_ivocab = trg_ivocab
        self.unk_sym = data_stream.dataset.unk_token
        self.eos_sym = data_stream.dataset.eos_token
        self.unk_idx = self.vocab[self.unk_sym]
        self.eos_idx = self.vocab[self.eos_sym]
        self.src_eos_idx = config['src_vocab_size'] - 1
        self.best_models = []
        self.beam_search = BeamSearch(samples=samples)
        self.multibleu_cmd = ['perl', self.config['bleu_script'],
                              self.config['val_set_target'], '<']
        self.compbleu_cmd = [self.config['bleu_script_1'], 
                             self.config['val_set_target'],
                             self.config['val_output_repl']]
        self.ap = afterprocesser(config)

        # Create saving directory if it does not exist
        if not os.path.exists(self.config['saveto']):
            os.makedirs(self.config['saveto'])

    def do(self, which_callback, *args):

        # Track validation burn in
        if self.main_loop.status['iterations_done'] <= \
                self.config['val_burn_in']:
            return

        # Evaluate and save if necessary
        bleu, cost = self._evaluate_model()
        self._save_model(bleu, cost)
        self._stop()

    def _stop(self):
        def get_last_max(l):
            # index of the last occurrence of the maximum value in l
            t = 0
            r = 0
            for i, j in enumerate(l):
                if j >= t:
                    t = j
                    r = i
            return r

    def _evaluate_model(self):

        logger.info("Started Validation: ")

        if not self.trg_ivocab:
            sources = self._get_attr_rec(self.main_loop, 'data_stream')
            trg_vocab = sources.data_streams[1].dataset.dictionary
            self.trg_ivocab = {v: k for k, v in trg_vocab.items()}

        ts = self.data_stream.get_epoch_iterator()
        rts = open(self.config['val_set_source']).readlines()
        ftrans_original = open(self.config['val_output_orig'], 'w')
        saved_weights = []
        total_cost = 0.0

        pbar = ProgressBar(max_value=len(rts)).start()
        for i, (line, line_raw) in enumerate(zip(ts, rts)):
            trans_in = line_raw.split()
            seq = self._oov_to_unk(
                line[0], self.config['src_vocab_size'], self.unk_idx)
            input_ = numpy.tile(seq, (self.config['beam_size'], 1))

            # draw sample, checking to ensure we don't get an empty string back
            trans, costs, attendeds, weights = \
                self.beam_search.search(
                    input_values={self.source_sentence: input_},
                    max_length=3*len(seq), eol_symbol=self.src_eos_idx,
                    ignore_first_eol=True)

            # normalize costs according to the sequence lengths
            if self.normalize:
                lengths = numpy.array([len(s) for s in trans])
                costs = costs / lengths

            best = numpy.argsort(costs)[0]
            try:
                total_cost += costs[best]
                trans_out = trans[best]
                weight = weights[best][:, :len(trans_in)]
                trans_out = self._idx_to_word(trans_out, self.trg_ivocab)
            except ValueError:
                logger.info(
                    "Can NOT find a translation for line: {}".format(i+1))
                trans_out = '<UNK>'

            saved_weights.append(weight)
            print(' '.join(trans_out), file=ftrans_original)
            pbar.update(i + 1)

        pbar.finish()
        ftrans_original.close()
        cPickle.dump(saved_weights, open(self.config['attention_weights'], 'wb'))
        self.data_stream.reset()

        # run afterprocess
        # self.ap.main()

        # calculate bleu
        bleu_subproc = Popen(self.compbleu_cmd, stdout=PIPE)
        while True:
            line = bleu_subproc.stdout.readline()
            if line != '':
                if 'BLEU' in line:
                    stdout = line
            else:
                break
        bleu_subproc.terminate()
        out_parse = re.match(r'BLEU = [-.0-9]+', stdout)
        assert out_parse is not None

        # extract the score
        bleu_score = float(out_parse.group()[6:]) * 100
        logger.info('BLEU: ' + str(bleu_score))
        self.main_loop.log.current_row['validation_bleu'] = bleu_score
        self.main_loop.log.current_row['validation_cost'] = total_cost

        return bleu_score, total_cost

    def _is_valid_to_save(self, bleu_score):
        if not self.best_models or min(self.best_models,
           key=operator.attrgetter('score')).score < bleu_score:
            return True
        return False

    def _save_model(self, bleu_score, total_cost):
        if self._is_valid_to_save(bleu_score):
            model = ModelInfo(bleu_score, 'bleu', self.config['saveto'])

            # Manage n-best model list first
            if len(self.best_models) >= self.track_n_models:
                old_model = self.best_models[0]
                if old_model.path and os.path.isfile(old_model.path):
                    logger.info("Deleting old model %s" % old_model.path)
                    os.remove(old_model.path)
                self.best_models.remove(old_model)

            self.best_models.append(model)
            self.best_models.sort(key=operator.attrgetter('score'))

            # Save the model here
            s = signal.signal(signal.SIGINT, signal.SIG_IGN)
            logger.info("Saving new model {}".format(model.path))
            self.dump_parameters(self.main_loop, model.path)
            signal.signal(signal.SIGINT, s)

    def dump_parameters(self, main_loop, path):
        params_to_save = main_loop.model.get_parameter_values()
        param_values = {name.replace("/", "-"): param
                        for name, param in params_to_save.items()}
        outfile_path = path + '.' + str(main_loop.status['iterations_done'])
        with open(outfile_path, 'wb') as outfile:
            numpy.savez(outfile, **param_values)
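
dump_parameters writes an .npz archive with '/' in parameter names replaced by '-'. A matching loader has to invert that renaming; a sketch (set_parameter_values is the usual Blocks Model API, but verify it against your version):

import numpy

def load_dumped_parameters(path):
    # invert the name mangling done by dump_parameters above
    with open(path, 'rb') as infile:
        archive = numpy.load(infile)
        return {name.replace('-', '/'): value
                for name, value in archive.items()}

# e.g. main_loop.model.set_parameter_values(load_dumped_parameters(p))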
Example n. 16
class perplexityValidation(SimpleExtension, SamplingBase):
    """Random Sampling from model."""

    def __init__(self,source_sentence,samples, model, data_stream, model_name,config,
                 src_vocab=None, n_best=1, track_n_models=1, trg_ivocab=None,
                 patience=10, normalize=True, **kwargs):
        super(perplexityValidation, self).__init__(**kwargs)
        self.model = model
        self.data_stream = data_stream
        self.model_name = model_name
        self.src_vocab = src_vocab
        self.trg_ivocab = trg_ivocab
        self.is_synced = False
        self.sampling_fn = model.get_theano_function()

        self.source_sentence = source_sentence
        self.samples = samples
        self.config = config
        self.n_best = n_best
        self.normalize = normalize
        self.patience = patience

        # Helpers
        self.vocab = data_stream.dataset.dictionary
        self.trg_ivocab = trg_ivocab
        self.unk_sym = data_stream.dataset.unk_token
        self.eos_sym = data_stream.dataset.eos_token
        self.unk_idx = self.vocab[self.unk_sym]
        self.eos_idx = self.vocab[self.eos_sym]
        self.src_eos_idx = config['src_vocab_size'] - 1
        self.beam_search = BeamSearch(samples=samples)

    def do(self, which_callback, *args):

        print()
        # Evaluate and save if necessary
        cost = self._evaluate_model()
        print("Average validation cost: " + str(cost));

    def _evaluate_model(self):

        logger.info("Started Validation: ")

        if not self.trg_ivocab:
            sources = self._get_attr_rec(self.main_loop, 'data_stream')
            trg_vocab = sources.data_streams[1].dataset.dictionary
            self.trg_ivocab = {v: k for k, v in trg_vocab.items()}

        ts = self.data_stream.get_epoch_iterator()
        ftrans_original = open(self.config['val_output_orig'], 'w')
        total_cost = 0.0

        pbar = ProgressBar(max_value=len(ts)).start()  # modified
        for i, line in enumerate(ts):
            seq = self._oov_to_unk(
                line[0], self.config['src_vocab_size'], self.unk_idx)
            input_ = numpy.tile(seq, (self.config['beam_size'], 1))

            # draw sample, checking to ensure we don't get an empty string back
            trans, costs, attendeds, weights = \
                self.beam_search.search(
                    input_values={self.source_sentence: input_},
                    max_length=3*len(seq), eol_symbol=self.src_eos_idx,
                    ignore_first_eol=True)

            # normalize costs according to the sequence lengths
            if self.normalize:
                lengths = numpy.array([len(s) for s in trans])
                costs = costs / lengths

            best = numpy.argsort(costs)[0]
            try:
                total_cost += costs[best]
                trans_out = trans[best]
                trans_out = self._idx_to_word(trans_out, self.trg_ivocab)
            except ValueError:
                logger.info(
                    "Can NOT find a translation for line: {}".format(i+1))
                trans_out = '<UNK>'

            print(' '.join(trans_out), file=ftrans_original)
            pbar.update(i + 1)

        pbar.finish()
        ftrans_original.close()
        self.data_stream.reset()

        # run afterprocess
        # self.ap.main()
        self.main_loop.log.current_row['validation_cost'] = total_cost

        return total_cost
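
perplexityValidation only writes 'validation_cost' into the main loop log; in Example n. 13 that record drives LearningRateHalver (comparator x > y, patience_default=3), whose source is not shown here. A guess at its core patience logic, as a plain class:

class HalverState(object):
    # minimal patience logic; an assumption about what LearningRateHalver does
    def __init__(self, patience=3):
        self.patience = patience
        self.bad = 0
        self.best = float('inf')

    def update(self, validation_cost, learning_rate):
        if validation_cost > self.best:  # comparator x > y: cost got worse
            self.bad += 1
            if self.bad >= self.patience:
                self.bad = 0
                return learning_rate / 2.0
        else:
            self.best = validation_cost
            self.bad = 0
        return learning_rate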
Example n. 17
def main(mode, config, use_bokeh=False):

    # Construct model
    logger.info('Building RNN encoder-decoder')
    encoder = BidirectionalEncoder(
        config['src_vocab_size'], config['enc_embed'], config['enc_nhids'])
    decoder = Decoder(
        config['trg_vocab_size'], config['dec_embed'], config['dec_nhids'],
        config['enc_nhids'] * 2,config['topical_embedding_dim'])
    topical_transformer = topicalq_transformer(
        config['topical_vocab_size'], config['topical_embedding_dim'],
        config['enc_nhids'], config['topical_word_num'], config['batch_size'])

    if mode == "train":

        # Create Theano variables
        logger.info('Creating theano variables')
        source_sentence = tensor.lmatrix('source')
        source_sentence_mask = tensor.matrix('source_mask')
        target_sentence = tensor.lmatrix('target')
        target_sentence_mask = tensor.matrix('target_mask')
        sampling_input = tensor.lmatrix('input')
        source_topical_word = tensor.lmatrix('source_topical')
        source_topical_mask = tensor.matrix('source_topical_mask')

        # Get training and development set streams
        tr_stream = get_tr_stream_with_topicalq(**config)
        dev_stream = get_dev_stream_with_topicalq(**config)
        topic_embedding = topical_transformer.apply(source_topical_word)
        # Get cost of the model
        representation = encoder.apply(source_sentence, source_sentence_mask)
        tw_representation = topical_transformer.look_up.apply(
            source_topical_word.T)
        content_embedding = representation[0, :,
                                           (representation.shape[2] // 2):]

        cost = decoder.cost(
            representation, source_sentence_mask, tw_representation,
            source_topical_mask, target_sentence, target_sentence_mask,
            topic_embedding, content_embedding)

        logger.info('Creating computational graph')
        cg = ComputationGraph(cost)

        # Initialize model
        logger.info('Initializing model')
        encoder.weights_init = decoder.weights_init = IsotropicGaussian(
            config['weight_scale'])
        encoder.biases_init = decoder.biases_init = Constant(0)
        encoder.push_initialization_config()
        decoder.push_initialization_config()
        encoder.bidir.prototype.weights_init = Orthogonal()
        decoder.transition.weights_init = Orthogonal()
        encoder.initialize()
        decoder.initialize()
        topical_transformer.weights_init = IsotropicGaussian(
            config['weight_scale'])
        topical_transformer.biases_init = Constant(0)
        # unclear whether this allocation step is required before initialize
        topical_transformer.push_allocation_config()
        topical_transformer.look_up.weights_init = Orthogonal()
        topical_transformer.transformer.weights_init = Orthogonal()
        topical_transformer.initialize()
        word_topical_embedding = cPickle.load(
            open(config['topical_embeddings'], 'rb'))
        np_word_topical_embedding = numpy.array(word_topical_embedding,
                                                dtype='float32')
        topical_transformer.look_up.W.set_value(np_word_topical_embedding)
        topical_transformer.look_up.W.tag.role = []


        # apply dropout for regularization
        if config['dropout'] < 1.0:
            # dropout is applied to the output of maxout in ghog
            logger.info('Applying dropout')
            dropout_inputs = [x for x in cg.intermediary_variables
                              if x.name == 'maxout_apply_output']
            cg = apply_dropout(cg, dropout_inputs, config['dropout'])

        # Apply weight noise for regularization
        if config['weight_noise_ff'] > 0.0:
            logger.info('Applying weight noise to ff layers')
            enc_params = Selector(encoder.lookup).get_params().values()
            enc_params += Selector(encoder.fwd_fork).get_params().values()
            enc_params += Selector(encoder.back_fork).get_params().values()
            dec_params = Selector(
                decoder.sequence_generator.readout).get_params().values()
            dec_params += Selector(
                decoder.sequence_generator.fork).get_params().values()
            dec_params += Selector(decoder.state_init).get_params().values()
            cg = apply_noise(
                cg, enc_params+dec_params, config['weight_noise_ff'])

        # Print shapes
        shapes = [param.get_value().shape for param in cg.parameters]
        logger.info("Parameter shapes: ")
        for shape, count in Counter(shapes).most_common():
            logger.info('    {:15}: {}'.format(shape, count))
        logger.info("Total number of parameters: {}".format(len(shapes)))

        # Print parameter names
        enc_dec_param_dict = merge(Selector(encoder).get_parameters(),
                                   Selector(decoder).get_parameters())
        logger.info("Parameter names: ")
        for name, value in enc_dec_param_dict.items():
            logger.info('    {:15}: {}'.format(value.get_value().shape, name))
        logger.info("Total number of parameters: {}"
                    .format(len(enc_dec_param_dict)))

        # Set up training model
        logger.info("Building model")
        training_model = Model(cost)

        # Set extensions
        logger.info("Initializing extensions")
        extensions = [
            FinishAfter(after_n_batches=config['finish_after']),
            TrainingDataMonitoring([cost], after_batch=True),
            Printing(after_batch=True),
            CheckpointNMT(config['saveto'],
                          every_n_batches=config['save_freq'])
        ]
        '''
        # Set up beam search and sampling computation graphs if necessary
        if config['hook_samples'] >= 1 or config['bleu_script'] is not None:
            logger.info("Building sampling model")
            sampling_representation = encoder.apply(
                sampling_input, tensor.ones(sampling_input.shape))
            generated = decoder.generate(
                sampling_input, sampling_representation)
            search_model = Model(generated)
            _, samples = VariableFilter(
                bricks=[decoder.sequence_generator], name="outputs")(
                    ComputationGraph(generated[1]))

        # Add sampling
        if config['hook_samples'] >= 1:
            logger.info("Building sampler")
            extensions.append(
                Sampler(model=search_model, data_stream=tr_stream,
                        hook_samples=config['hook_samples'],
                        every_n_batches=config['sampling_freq'],
                        src_vocab_size=config['src_vocab_size']))

        # Add early stopping based on bleu
        if config['bleu_script'] is not None:
            logger.info("Building bleu validator")
            extensions.append(
                BleuValidator(sampling_input, samples=samples, config=config,
                              model=search_model, data_stream=dev_stream,
                              normalize=config['normalized_bleu'],
                              every_n_batches=config['bleu_val_freq']))
        '''

        # Reload model if necessary
        if config['reload']:
            extensions.append(LoadNMT(config['saveto']))

        # Plot cost in bokeh if necessary
        if use_bokeh and BOKEH_AVAILABLE:
            extensions.append(
                Plot('Cs-En', channels=[['decoder_cost_cost']],
                     after_batch=True))

        # Set up training algorithm
        logger.info("Initializing training algorithm")
        algorithm = GradientDescent(
            cost=cost, parameters=cg.parameters, on_unused_sources='warn',
            step_rule=CompositeRule([StepClipping(config['step_clipping']),
                                     eval(config['step_rule'])()]))

        # Initialize main loop
        logger.info("Initializing main loop")
        main_loop = MainLoop(
            model=training_model,
            algorithm=algorithm,
            data_stream=tr_stream,
            extensions=extensions
        )

        # Train!
        main_loop.run()

    elif mode == 'translate':

        # Create Theano variables
        logger.info('Creating theano variables')
        source_sentence = tensor.lmatrix('source')
        source_topical_word = tensor.lmatrix('source_topical')

        # Get test set stream
        test_stream = get_dev_stream_with_topicalq(
            config['test_set'], config['src_vocab'], config['src_vocab_size'],
            config['topical_test_set'], config['topical_vocab'],
            config['topical_vocab_size'], config['unk_id'])
        ftrans = open(config['test_set'] + '.trans.out', 'w')

        # Helper utilities
        sutils = SamplingBase()
        unk_idx = config['unk_id']
        src_eos_idx = config['src_vocab_size'] - 1
        trg_eos_idx = config['trg_vocab_size'] - 1

        # Get beam search
        logger.info("Building sampling model")
        topic_embedding = topical_transformer.apply(source_topical_word)
        representation = encoder.apply(source_sentence,
                                       tensor.ones(source_sentence.shape))
        tw_representation = topical_transformer.look_up.apply(
            source_topical_word.T)
        content_embedding = representation[0, :,
                                           (representation.shape[2] // 2):]
        generated = decoder.generate(source_sentence,
                                     representation,
                                     tw_representation,
                                     topical_embedding=topic_embedding,
                                     content_embedding=content_embedding)

        _, samples = VariableFilter(
            bricks=[decoder.sequence_generator], name="outputs")(
                ComputationGraph(generated[1]))  # generated[1] is next_outputs
        beam_search = BeamSearch(samples=samples)

        logger.info("Loading the model..")
        model = Model(generated)
        loader = LoadNMT(config['saveto'])
        loader.set_model_parameters(model, loader.load_parameters())

        # Get target vocabulary
        trg_vocab = _ensure_special_tokens(
            pickle.load(open(config['trg_vocab'], 'rb')), bos_idx=0,
            eos_idx=trg_eos_idx, unk_idx=unk_idx)
        trg_ivocab = {v: k for k, v in trg_vocab.items()}

        logger.info("Started translation: ")
        total_cost = 0.0

        for i, line in enumerate(test_stream.get_epoch_iterator()):

            seq = sutils._oov_to_unk(
                line[0], config['src_vocab_size'], unk_idx)
            seq2 = line[1]
            input_ = numpy.tile(seq, (config['beam_size'], 1))
            input_topical = numpy.tile(seq2, (config['beam_size'], 1))

            # draw sample, checking to ensure we don't get an empty string back
            trans, costs = \
                beam_search.search(
                    input_values={source_sentence: input_, source_topical_word: input_topical},
                    max_length=10*len(seq), eol_symbol=src_eos_idx,
                    ignore_first_eol=True)
            '''
            # normalize costs according to the sequence lengths
            if config['normalized_bleu']:
                lengths = numpy.array([len(s) for s in trans])
                costs = costs / lengths
            '''
            #best = numpy.argsort(costs)[0]
            best = numpy.argsort(costs)[0:config['beam_size']]
            for b in best:
                try:
                    total_cost += costs[b]
                    trans_out = trans[b]

                    # convert idx to words
                    trans_out = sutils._idx_to_word(trans_out, trg_ivocab)

                except ValueError:
                    logger.info(
                        "Can NOT find a translation for line: {}".format(i+1))
                    trans_out = '<UNK>'

                print(trans_out, file=ftrans)

            if i != 0 and i % 100 == 0:
                logger.info(
                    "Translated {} lines of test set...".format(i))

        logger.info("Total cost of the test: {}".format(total_cost))
        ftrans.close()
    elif mode == 'rerank':
        # Create Theano variables
        ftrans = open(config['val_set'] + '.scores.out', 'w')
        logger.info('Creating theano variables')
        source_sentence = tensor.lmatrix('source')
        source_sentence_mask = tensor.matrix('source_mask')
        target_sentence = tensor.lmatrix('target')
        target_sentence_mask = tensor.matrix('target_mask')

        config['src_data'] = config['val_set']
        config['trg_data'] = config['val_set_grndtruth']
        config['batch_size'] = 1
        config['sort_k_batches'] = 1
        test_stream = get_tr_stream_unsorted(**config)
        logger.info("Building sampling model")
        representations = encoder.apply(source_sentence, source_sentence_mask)
        costs = decoder.cost(representations, source_sentence_mask,
                             target_sentence, target_sentence_mask)
        logger.info("Loading the model..")
        model = Model(costs)
        loader = LoadNMT(config['saveto'])
        loader.set_model_parameters(model, loader.load_parameters())

        costs_computer = function([source_sentence, source_sentence_mask,
                                   target_sentence,
                                   target_sentence_mask], costs)
        iterator = test_stream.get_epoch_iterator()

        scores = []
        for i, (src, src_mask, trg, trg_mask) in enumerate(iterator):
            costs = costs_computer(*[src, src_mask, trg, trg_mask])
            cost = costs.sum()
            print(i, cost)
            scores.append(cost)
            ftrans.write(str(cost) + "\n")
        ftrans.close()
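
Both translate branches first map out-of-vocabulary ids to unk_idx and then tile the sequence to beam_size identical rows. _oov_to_unk lives in SamplingBase and is not shown; a plausible sketch of the pair of steps (the helper body is an assumption):

import numpy

def oov_to_unk(seq, vocab_size, unk_idx):
    # replace ids outside the vocabulary with the <UNK> id (a sketch)
    return [w if w < vocab_size else unk_idx for w in seq]

seq = oov_to_unk([3, 7, 30005, 12], vocab_size=30000, unk_idx=1)
input_ = numpy.tile(seq, (12, 1))  # one row per beam, as in the examples
print(seq, input_.shape)           # [3, 7, 1, 12] (12, 4)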
Example n. 18
def train_lstm(
    dim_proj=10,  # word embedding dimension and LSTM number of hidden units.
    layers=2,  # the number of layers for lstm encoder and decoder
    patience=10,  # Number of epochs to wait before early stop if no progress
    max_epochs=5000,  # The maximum number of epochs to run
    dispFreq=10,  # Display to stdout the training progress every N updates
    decay_c=0.,  # Weight decay for the classifier applied to the U weights.
    begin_valid=0,  # when to begin evaluating performance on the dev and test sets.
    save_on_the_fly=0,
    lrate=10,  # Learning rate for sgd (not used for adadelta and rmsprop)
    encoder='lstm',  # TODO: can be removed; must be lstm.
    saveto='lstm_model',  # The best model will be saved there
    validFreq=-1,  # 370,  # Compute the validation error after this number of updates.
    saveFreq=1110,  # Save the parameters after every saveFreq updates
    maxlen=100,  # Sequences longer than this get ignored
    batch_size=3,  # The batch size during training.
    valid_batch_size=64,  # The batch size used for validation/test set.
    dataset='database',

    # Parameter for extra option
    noise_std=0.,
    use_dropout=False,  # True,  # if False slightly faster, but worse test error
    # This frequently needs a bigger model.
    reload_model=None,  # Path to a saved model we want to start from.
    test_size=-1,  # If >0, we keep only this number of test example.
    src_file="3.zh",
    trg_file="3.en",
    align_file="3.align",
    src_dict='src.dict.pkl',
    trg_dict='trg.dict.pkl',
    dev_file='',
    dev_ref='',
    dev_xml='',
    tst_file='',
    tst_ref='',
    tst_xml='',
    reverse_src=True,
    reverse_trg=False,
    bi_train=False,
    bi_reg=1.0,
    beamsize=12,
):
    # Model options
    model_options = locals().copy()
    print 'Loading data'
    train, n_words_x, n_words_y = preprocess_data(src_file, trg_file)

    srcdict, srcdict_rev = load_dict(src_dict, True)
    trgdict, trgdict_rev = load_dict(trg_dict, True)

    print 'the total number of examples is %d' % len(train[0])
    if reverse_src:
        for i in xrange(len(train[0])):
            train[0][i] = train[0][i][::-1]

    numpy.savez('train_data.npz', train)
    print "n_words_x", n_words_x
    print "n_words_y", n_words_y

    #_, valid, test = split_train(train)
    valid, test = train, train
    model_options['n_words_x'] = n_words_x
    model_options['n_words_y'] = n_words_y

    print "model options", model_options
    logger.debug("Building model")

    ## generate the target from left to right
    enc_dec = EncoderDecoder(model_options, prefix='l2r')

    if reload_model:
        enc_dec.load(saveto + '_f.npz')

    inputs, cost, use_noise = enc_dec.BuildEncDec()

    rev_model_options = locals().copy()
    rev_model_options['n_words_x'] = n_words_x
    rev_model_options['n_words_y'] = n_words_y
    rev_model_options['reverse_src'] = False
    rev_model_options['reverse_trg'] = True

    print 'building the right to left model'
    ## generate the target from right to left; it does not share params with enc_dec
    rev_enc_dec = EncoderDecoder(rev_model_options, prefix='r2l')
    if reload_model:
        rev_enc_dec.load(saveto + '_r.npz')

    rev_inputs, rev_cost, rev_use_noise = rev_enc_dec.BuildEncDec()

    inputs = inputs + rev_inputs
    cost = bi_reg * cost + rev_cost
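    # joint bidirectional objective: bi_reg weights the left-to-right cost against the right-to-left cost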
    params = enc_dec.params + rev_enc_dec.params

    f_bi_cost = theano.function(inputs, cost, on_unused_input='ignore')
    print 'total params: ', params

    beamsearch = BeamSearch(enc_dec, beamsize, trgdict, trgdict_rev, srcdict,
                            srcdict_rev)
    beamsearch_rev = BeamSearch(rev_enc_dec, beamsize, trgdict, trgdict_rev,
                                srcdict, srcdict_rev)
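    # two independent beam-search decoders, one per generation direction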
    print 'the initialized params'
    print_tparams(beamsearch.enc_dec.params)
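    # optional L2 weight decay on the l2r decoder's U weights, added to the joint cost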
    if decay_c > 0.:
        decay_c = theano.shared(numpy_floatX(decay_c), name='decay_c')
        weight_decay = 0.
        weight_decay += (enc_dec.decoder.U**2).sum()
        weight_decay *= decay_c
        cost += weight_decay

    f_cost = enc_dec.f_cost  #theano.function(inputs, cost, name='f_cost',on_unused_input='ignore')
    f_pred = theano.function(inputs, enc_dec.decoder.argpred, \
                             name='f_pred',on_unused_input='ignore')

    #print enc_dec.params
    lr = tensor.scalar(name='lr')
    trainer = Trainer()

    logger.debug("Compiling grads")
    f_grad_shared, f_update = trainer.SetupTainer(lr, inputs, cost, params)
    logger.debug("Compiling grads over")

    kf_valid = get_minibatches_idx(len(valid[0]), valid_batch_size)
    kf_test = get_minibatches_idx(len(test[0]), valid_batch_size)

    print "%d train examples" % len(train[0])
    print "%d valid examples" % len(valid[0])
    print "%d test examples" % len(test[0])

    history_errs = []
    best_p = None
    bad_count = 0

    kf_fix = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)

    iter_epoches = len(train[0]) / batch_size
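    # if validFreq is unset, validate every 370 updates or once per epoch, whichever is less frequent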
    if validFreq == -1:
        validFreq = 370 if iter_epoches < 370 else iter_epoches
    if saveFreq == -1:
        saveFreq = len(train[0]) / batch_size
    # x is the token list of one sentence
    rev_seq_f = lambda x: x[::-1]
    # x is the token list of one sentence; the final EOS stays in place, for generating y right-to-left
    rev_seq_f_eos = lambda x: x[:-1][::-1] + [x[-1]]
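    # e.g. rev_seq_f_eos([w1, w2, w3, EOS]) == [w3, w2, w1, EOS]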
    # xx is a minibatch, i.e. a list of lists
    reverse_lst_f = lambda xx: map(rev_seq_f, xx)
    reverse_lst_f_eos = lambda xx: map(rev_seq_f_eos, xx)

    print "validFreq, saveFreq", validFreq, saveFreq
    uidx = 0  # the number of update done
    estop = False  # early stop
    start_time = time.clock()
    model_f_iter, model_r_iter = None, None
    try:
        for eidx in xrange(max_epochs):
            tot_cost = 0
            n_samples = 0

            model_f_iter = saveto + '_f_iter%i.npz' % eidx
            model_r_iter = saveto + '_r_iter%i.npz' % eidx
            if model_options['save_on_the_fly']:
                rev_enc_dec.save(model_r_iter)
                enc_dec.save(model_f_iter)

            # Get new shuffled index for the training set.
            kf = get_minibatches_idx(len(train[0]), batch_size, shuffle=True)
            #kf = kf_fix[:1]
            for _, train_index in kf:
                uidx += 1
                use_noise.set_value(1.)
                # Select the random examples for this minibatch
                xx = [train[0][t] for t in train_index]
                yy = [train[1][t] for t in train_index]
                #align = [train[2][t] for t in train_index]
                #label = [train[3][t] for t in train_index]

                x, x_mask, y, y_mask, _, _ = \
                    prepare_reorderdata_minibatch(xx, yy)
                n_samples += x.shape[1]

                r_x, r_x_m, r_y, r_y_m, _, _ = \
                    prepare_reorderdata_minibatch(reverse_lst_f(xx),reverse_lst_f_eos(yy))
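                # the same minibatch for the r2l model: source fully reversed, target reversed except for the final EOS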

                f_inputs = [x, x_mask, y, y_mask, r_x, r_x_m, r_y, r_y_m]
                b_time = time.time()
                cost = 0
                cost = f_grad_shared(*f_inputs)
                e_time = time.time()
                f_update(lrate)
                tot_cost += cost
                ## update after testing
                #f_update(lrate)

                if numpy.isnan(cost) or numpy.isinf(cost):
                    print 'NaN detected'
                    return 1., 1., 1.
                if numpy.mod(uidx, dispFreq) == 0:
                    print 'Epoch ', eidx, 'Update ', uidx, 'Cost ', cost
                    print '----------------l2r'
                    beamsearch.minibatch(x, x_mask, y, y_mask)
                    print '----------------r2l'
                    beamsearch_rev.minibatch(r_x, r_x_m, r_y, r_y_m)
                    logger.debug(
                        "Current speed (including parameter updates) is {} s per sentence for {} sentences"
                        .format((e_time - b_time) / len(x[0]), len(x[0])))

                if saveto and numpy.mod(uidx, saveFreq) == 0:
                    print 'Saving...',
                    '''
                    if best_p is not None:
                        params = best_p
                    else:
                        params = enc_dec.unzip()
                    '''
                    enc_dec.save(saveto + '_f.npz', history_errs=history_errs)
                    rev_enc_dec.save(saveto + '_r.npz',
                                     history_errs=history_errs)
                    pkl.dump(model_options, open('%s.pkl' % saveto, 'wb'), -1)
                    print 'Done'

            print 'Seen %d samples, tot_cost %f' % (n_samples, tot_cost)
            if (eidx + 1) % 5 == 0:
                sys.stdout.flush()
            if os.path.isfile(dev_file) and eidx > begin_valid:
                os.system('echo evaluate for %d iterations >>eval.txt' % eidx)
                os.system('echo ------------- >>eval.txt')
                test_eval_combine(beamsearch,beamsearch_rev,src_file=dev_file,\
                                  trg_file=dev_ref,src_xml=dev_xml,modelfile=model_f_iter)
                if os.path.isfile(tst_file):
                    test_eval_combine(beamsearch,beamsearch_rev,src_file=tst_file,\
                              trg_file=tst_ref,src_xml=tst_xml,isdev=False,modelfile=model_f_iter)
            if estop:
                break

    except KeyboardInterrupt:
        print "Training interupted"

    return  #train_err, valid_err, test_err
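
A minimal usage sketch for train_lstm above; the corpus and dictionary paths are the placeholder defaults from the signature, not verified files:

if __name__ == '__main__':
    train_lstm(
        src_file='3.zh',          # source-side training corpus
        trg_file='3.en',          # target-side training corpus
        src_dict='src.dict.pkl',  # pickled source vocabulary
        trg_dict='trg.dict.pkl',  # pickled target vocabulary
        max_epochs=50,
        batch_size=32,
        bi_reg=1.0,               # weight on the left-to-right cost
        beamsize=12,
    )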