def load_data(config):
    with open(config.data_file, 'rb') as f:
        train, dev, test, embeddings, vocab = pickle.load(f)
    trainset, devset, testset = DataSet(train), DataSet(dev), DataSet(test)
    vocab = dict([(v['index'],k) for k,v in vocab.items()])
    trainset.sort(reverse=False)
    train_batches = trainset.get_batches(config.batch_size, config.epochs, rand=False)
    dev_batches = devset.get_batches(config.batch_size, 1, rand=False)
    test_batches = testset.get_batches(config.batch_size, 1, rand=False)
    temp_train = trainset.get_batches(config.batch_size, config.epochs, rand=True)
    dev_batches = list(dev_batches)
    test_batches = list(test_batches)
    temp_train = list(temp_train)
    return len(train), train_batches, dev_batches, test_batches, embeddings, vocab, temp_train
    def decode(self, query):
        ds = DataSet(args)

        encoder = Encoder(ds.src_vocab_size, self.args.embedding_dim,
                          self.args.hidden_dim, self.args.batch_size)
        decoder = Decoder(ds.tar_vocab_size, self.args.embedding_dim,
                          self.args.hidden_dim, self.args.batch_size)
        optimizer = tf.train.AdamOptimizer()
        checkpoint = tf.contrib.eager.Checkpoint(optimizer=optimizer,
                                                 encoder=encoder,
                                                 decoder=decoder)
        checkpoint.restore(tf.train.latest_checkpoint(self.checkpoints_dir))

        src_pad_length = self.args.src_seq_length
        tar_pad_length = self.args.tar_seq_length
        attn_plot = np.zeros((tar_pad_length, src_pad_length))

        query = tf.keras.backend.expand_dims(query, 0)
        hidden_state = encoder.initial_hidden_state(query.shape[0])
        enc_output, enc_state = encoder(query, hidden_state)
        dec_state = enc_state

        results = []
        dec_input = tf.keras.backend.expand_dims([ds.start_id], 0)
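        # Greedy decoding: at each step, feed the most recent prediction back in as the next decoder input.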
        for t in range(1, tar_pad_length):
            preds, dec_state, attn_weights = decoder(dec_input, dec_state,
                                                     enc_output)

            # store the attention weights to plot later
            attn_weights = tf.keras.backend.reshape(attn_weights, (-1, ))
            attn_plot[t] = attn_weights.numpy()

            pred_id = tf.keras.backend.argmax(preds[0]).numpy()
            if pred_id == ds.end_id:
                break
            if pred_id != ds.pad_id and pred_id != ds.start_id:
                results.append(ds.tar_id_tokens.get(pred_id, config.UNK_TOKEN))

            dec_input = tf.keras.backend.expand_dims([pred_id], 0)
        return results, attn_plot
Example no. 3
    def train(self, tag):
        ds = DataSet(self.args)
        _, train_src_ids, train_tar_ids, train_tar_loss_ids, _, train_facts_ids = \
        ds.read_file('train', 
                     max_src_len=self.args.src_seq_length, 
                     max_tar_len=self.args.tar_seq_length, 
                     max_fact_len=self.args.fact_seq_length, 
                     get_fact=True, 
                     get_one_hot=False)
    
        _, valid_src_ids, valid_tar_ids, valid_tar_loss_ids, _, valid_facts_ids = \
        ds.read_file('valid', 
                     max_src_len=self.args.src_seq_length, 
                     max_tar_len=self.args.tar_seq_length, 
                     max_fact_len=self.args.fact_seq_length, 
                     get_fact=True, 
                     get_one_hot=False)
    
        
        if tag == 'train':
            model, encoder_model, decoder_model = MemNNModel(args).get_model()
        elif tag == 'retrain':
            custom_dict = get_custom_objects()
            model = load_model(self.model_path, custom_objects=custom_dict, compile=False)
            encoder_model = load_model(self.encoder_model_path)
            decoder_model = load_model(self.decoder_model_path)
        # When using sparse_categorical_crossentropy your labels should be of shape (batch_size, seq_length, 1) instead of simply (batch_size, seq_length).
        opt = tf.keras.optimizers.Adam(lr=0.003, beta_1=0.9, beta_2=0.999, epsilon=None, decay=0.0)
        model.compile(optimizer=opt,
                      loss=tf.keras.losses.sparse_categorical_crossentropy, 
        )
    
        verbose = 1
        earlystopper = EarlyStopping(monitor='val_loss', patience=args.early_stop_patience, verbose=verbose)
        ckpt_name = 'model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'
        ckpt_path = os.path.join(self.exp_dir, ckpt_name)
        checkpoint = ModelCheckpoint(ckpt_path, monitor='val_loss', verbose=verbose, save_best_only=True, mode='min')
        lrate = ReduceLROnPlateau(
                                  monitor='val_loss', 
                                  factor=0.5, 
                                  patience=args.lr_decay_patience, 
                                  verbose=verbose, 
                                  mode='auto', 
                                  epsilon=0.0001, 
                                  cooldown=0, 
                                  min_lr=args.lr_min
                                  )
        #callback_list = [earlystopper, checkpoint, lrate]
        callback_list = [lrate]
    
        inp_y = np.expand_dims(np.asarray(train_tar_loss_ids), axis=-1)
        inp_valid_y = np.expand_dims(np.asarray(valid_tar_loss_ids), axis=-1)
        hist = model.fit(x=[
                            np.asarray(train_src_ids),
                            np.asarray(train_facts_ids),
                            np.asarray(train_tar_ids),
                           ],
                         y=inp_y,
                         epochs=args.epochs,
                         batch_size=args.batch_size,
                         callbacks=callback_list, 
                         validation_data=([
                                           np.asarray(valid_src_ids), 
                                           np.asarray(valid_facts_ids),
                                           np.asarray(valid_tar_ids),
                                          ], 
                                          inp_valid_y)
                         )
        with open(self.history_path,'w') as f:
            f.write(str(hist.history))
        # Note: Keras save/load_model can fail here with serialization errors for custom objects.
        model.save(self.model_path)

        return model
Example no. 4
    def test(self, model):
        # load_model
        ds = DataSet(args)
        _, test_src_ids, test_tar_ids, test_tar_loss_ids, _, test_facts_ids = \
        ds.read_file('test', 
                     max_src_len=self.args.src_seq_length, 
                     max_tar_len=self.args.tar_seq_length, 
                     max_fact_len=self.args.fact_seq_length, 
                     get_fact=True, 
                     get_one_hot=False)

        src_outobj = open(self.src_out_path, 'w')
        pred_outobj = open(self.pred_out_path, 'w')
        tar_outobj = open(self.tar_out_path, 'w')
    
        def __get_batch():
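            # Yields full batches of batch_size samples; the trailing partial batch is
            # yielded last but is skipped by the batch-count check in the loop below.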
            batch_src = []
            batch_facts = []
            batch_tar = []
            for (src_input, facts_input, tar_input) in zip(test_src_ids, test_facts_ids, test_tar_ids):
                batch_src.append(src_input)
                batch_facts.append(facts_input)
                batch_tar.append(tar_input)
                if len(batch_src) == self.args.batch_size:
                    res = (np.asarray(batch_src), np.asarray(batch_facts), np.asarray(batch_tar))
                    batch_src = []
                    batch_facts = []
                    batch_tar = []
                    yield res[0], res[1], res[2]
            yield np.asarray(batch_src), np.asarray(batch_facts), np.asarray(batch_tar)

        for (batch, (src_input, facts_input, tar_input)) in enumerate(__get_batch()):
            if batch >= (ds.test_sample_num // self.args.batch_size):
                # finish all of the prediction
                break
            print('Current batch: {}/{}. '.format(batch, len(test_src_ids) // self.args.batch_size))
            cur_batch_size = tar_input.shape[0]
            tar_length = tar_input.shape[1]

            results = np.zeros((cur_batch_size, tar_length), dtype='int32')
            results[:, 0] = ds.start_id
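            # Autoregressive decoding with the full training model: at step t, re-run
            # predict on everything generated so far; the output at position t - 1 is the
            # prediction for target position t.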

            for t in range(1, tar_length):
                preds = model.predict([src_input, facts_input, results]) # shape: (batch_size, tar_length, vocab_size)
                pred_id = np.argmax(preds, axis=-1)
                results[:, t] = pred_id[:, t - 1]

            def output_results(outputs, outobj):
                for result in outputs:
                    seq = []
                    for _id in result:
                        _id = int(_id)
                        if _id == ds.end_id:
                            break
                        if _id != ds.pad_id and _id != ds.start_id:
                            seq.append(ds.tar_id_tokens.get(_id, config.UNK_TOKEN))
                    write_line = ' '.join(seq)
                    write_line = write_line + '\n'
                    outobj.write(write_line)
    
            output_results(results, pred_outobj)
            output_results(src_input, src_outobj)
            output_results(tar_input, tar_outobj)
    
        src_outobj.close()
        pred_outobj.close()
        tar_outobj.close()
    def train(self):
        ds = DataSet(self.args)
        print('*' * 100)
        print('train sample number: ', ds.train_sample_num)
        print('valid sample number: ', ds.valid_sample_num)
        print('test sample number: ', ds.test_sample_num)
        print('*' * 100)

        train_generator = ds.data_generator(
            'train',
            'transformer',
            max_src_len=self.args.src_seq_length,
            max_tar_len=self.args.tar_seq_length,
        )

        valid_generator = ds.data_generator(
            'valid',
            'transformer',
            max_src_len=self.args.src_seq_length,
            max_tar_len=self.args.tar_seq_length,
        )

        def compile_new_model():
            _model = self.transformer_model.get_model(ds.pad_id)
            _model.compile(
                optimizer=keras.optimizers.Adam(lr=self.args.lr),
                loss=keras.losses.sparse_categorical_crossentropy,
            )
            return _model

        if os.path.exists(self.model_path):
            print('Loading model from: %s' % self.model_path)
            custom_dict = get_custom_objects()
            model = load_model(self.model_path, custom_objects=custom_dict)
        else:
            print('Compile new model...')
            model = compile_new_model()

        #model.summary()
        #plot_model(model, to_file='model_structure.png',show_shapes=True)

        verbose = 1
        earlystopper = EarlyStopping(monitor='val_loss',
                                     patience=self.args.early_stop_patience,
                                     verbose=verbose)
        ckpt_name = 'model-ep{epoch:03d}-loss{loss:.3f}-val_loss{val_loss:.3f}.h5'
        ckpt_path = os.path.join(self.exp_dir, ckpt_name)
        checkpoint = ModelCheckpoint(ckpt_path,
                                     monitor='val_loss',
                                     verbose=verbose,
                                     save_best_only=True,
                                     mode='min')
        lrate = keras.callbacks.ReduceLROnPlateau(
            monitor='val_loss',
            factor=0.5,
            patience=self.args.lr_decay_patience,
            verbose=verbose,
            mode='auto',
            min_delta=0.0001,
            cooldown=0,
            min_lr=self.args.lr_min,
        )

        callback_list = [earlystopper, checkpoint, lrate]

        hist = model.fit_generator(
            generator=train_generator,
            steps_per_epoch=(ds.train_sample_num // self.args.batch_size),
            epochs=self.args.epochs,
            callbacks=callback_list,
            validation_data=valid_generator,
            validation_steps=(ds.valid_sample_num // self.args.batch_size),
        )
        with open(self.history_path, 'w') as f:
            f.write(str(hist.history))

        model.save(self.model_path)
    def test(self):
        # load_model
        print('Loading model from: %s' % self.model_path)
        custom_dict = get_custom_objects()
        model = load_model(self.model_path, custom_objects=custom_dict)

        ds = DataSet(args)
        test_generator = ds.data_generator(
            'test',
            'transformer',
            max_src_len=self.args.src_seq_length,
            max_tar_len=self.args.tar_seq_length,
        )

        src_outobj = open(self.src_out_path, 'w')
        pred_outobj = open(self.pred_out_path, 'w')
        tar_outobj = open(self.tar_out_path, 'w')

        for batch, ([src_input,
                     tar_input], tar_loss_input) in enumerate(test_generator):
            if batch > (ds.test_sample_num // self.args.batch_size):
                # finish all of the prediction
                break
            print('Current batch: {}/{}. '.format(
                batch, ds.test_sample_num // self.args.batch_size))
            cur_batch_size = tar_input.shape[0]
            tar_length = tar_input.shape[1]

            results = np.zeros_like(tar_input)
            results[:, 0] = ds.start_id
            results[:, 1:] = ds.pad_id

            for t in range(1, tar_length):
                preds = model.predict([
                    src_input, np.asarray(results)
                ])  # shape: (batch_size, tar_length, vocab_size)
                pred_id = np.argmax(preds, axis=-1)
                results[:, t] = pred_id[:, t - 1]

            def output_results(outputs, outobj):
                for result in outputs:
                    seq = []
                    for _id in result:
                        _id = int(_id)
                        if _id == ds.end_id:
                            break
                        if _id != ds.pad_id and _id != ds.start_id:
                            seq.append(
                                ds.tar_id_tokens.get(_id, config.UNK_TOKEN))
                    write_line = ' '.join(seq)
                    write_line = write_line + '\n'
                    outobj.write(write_line)

            output_results(results, pred_outobj)
            output_results(src_input, src_outobj)
            output_results(tar_input, tar_outobj)

        src_outobj.close()
        pred_outobj.close()
        tar_outobj.close()
Example no. 7
    def beam_search_test(self):
        beam_size = self.args.beam_size
        ds = DataSet(args)
        test_generator = ds.data_generator('test', 'ted')
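        # Note: the Hypothesis class used below is not included in this excerpt. From its
        # usage it keeps, per sample in the batch, the generated ids (res_ids), the ids fed
        # back to the model (pred_ids, with unknown tokens mapped to unk_id) and their
        # probabilities (probs), exposes batch_size, tar_len and a per-sample ranking score
        # avg_prob, and provides get_predictable_vars(pad_id), which pads each partial
        # sequence to the target length before it is passed to model.predict.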

        def sort_for_each_hyp(hyps, rank_index):
            """Return a list of Hypothesis objects, sorted by descending average log probability"""
            return sorted(hyps,
                          key=lambda h: h.avg_prob[rank_index],
                          reverse=True)

        def get_new_hyps(all_hyps):
            hyp = all_hyps[0]
            batch_size = hyp.batch_size
            tar_len = hyp.tar_len

            new_hyps = []
            for i in range(beam_size):
                hyp = Hypothesis(batch_size, tar_len, ds.start_id, ds.end_id)
                new_hyps.append(hyp)
            for i in range(batch_size):
                # rank based on each sample's probs
                sorted_hyps = sort_for_each_hyp(all_hyps, i)
                for j in range(beam_size):
                    hyp = sorted_hyps[j]
                    new_hyps[j].res_ids[i] = hyp.res_ids[i]
                    new_hyps[j].pred_ids[i] = hyp.pred_ids[i]
                    new_hyps[j].probs[i] = hyp.probs[i]
            return new_hyps

        def update_hyps(all_hyps):
            # all_hyps: beam_size * beam_size current step hyps.
            new_hyps = get_new_hyps(all_hyps)
            return new_hyps

        def get_final_results(hyps):
            hyp = hyps[0]
            batch_size = hyp.batch_size
            tar_len = hyp.tar_len

            final_hyp = Hypothesis(batch_size, tar_len, ds.start_id, ds.end_id)
            for i in range(batch_size):
                # rank based on each sample's probs
                sorted_hyps = sort_for_each_hyp(hyps, i)
                hyp = sorted_hyps[0]
                final_hyp.res_ids[i] = hyp.res_ids[i]
                final_hyp.pred_ids[i] = hyp.pred_ids[i]
                final_hyp.probs[i] = hyp.probs[i]
            res = np.asarray(final_hyp.res_ids)
            return res

        # load_model
        def compile_new_model():
            _model = self.transformer_model.get_model(ds.pad_id)
            _model.compile(
                optimizer=keras.optimizers.Adam(lr=self.args.lr),
                loss=keras.losses.sparse_categorical_crossentropy,
            )
            return _model

        # load_model
        print('Loading model from: %s' % self.model_path)
        custom_dict = get_custom_objects()
        model = load_model(self.model_path, custom_objects=custom_dict)
        #model = compile_new_model()
        #model.load_weights(self.model_path)

        src_outobj = open(self.src_out_path, 'w')
        pred_outobj = open(self.pred_out_path, 'w')
        tar_outobj = open(self.tar_out_path, 'w')

        for batch_index, ([src_input, tar_input, facts_input],
                          tar_loss_input) in enumerate(test_generator):
            if batch_index > (ds.test_sample_num // self.args.batch_size):
                # finish all of the prediction
                break

            print('Current batch: {}/{}. '.format(
                batch_index, ds.test_sample_num // self.args.batch_size))
            cur_batch_size = tar_input.shape[0]
            tar_length = tar_input.shape[1]
            hyps = []
            for i in range(beam_size):
                hyp = Hypothesis(cur_batch_size, tar_length, ds.start_id,
                                 ds.end_id)
                hyps.append(hyp)

            for t in range(1, tar_length):
                # iterate each sample
                # collect all hyps, basically, it's beam_size * beam_size
                all_hyps = []
                for i in range(beam_size):
                    cur_hyp = hyps[i]
                    results = cur_hyp.get_predictable_vars(ds.pad_id)
                    # bs, tar_len, 60000
                    preds = model.predict([
                        src_input, np.asarray(results), facts_input
                    ])  # shape: (batch_size, tar_length, vocab_size)

                    # get the current step prediction
                    cur_preds = preds[:, t - 1]

                    # cur_preds shape: (batch_size, vocab_size)
                    top_indices = np.argsort(cur_preds)
                    top_indices = top_indices[:,
                                              -beam_size:]  # the largest one is at the end

                    top_logits = []
                    for sample_index, sample_logits in enumerate(cur_preds):
                        logits = []
                        for beam_index in range(beam_size):
                            logit = sample_logits[top_indices[sample_index]
                                                  [beam_index]]
                            logits.append(logit)
                        top_logits.append(logits)
                    top_logits = np.asarray(top_logits)

                    # expand with each of the beam_size candidates, best first (argsort is ascending)
                    for j in range(beam_size - 1, -1, -1):
                        next_hyp = deepcopy(cur_hyp)
                        # bs, 1
                        top_index = top_indices[:, j]
                        top_logit = top_logits[:, j]

                        for bs_idx, _id in enumerate(top_index):
                            next_hyp.res_ids[bs_idx].append(_id)
                            prob = top_logit[bs_idx]
                            next_hyp.probs[bs_idx].append(prob)

                            # get OOV id
                            token = ds.tar_id_tokens.get(
                                int(_id), config.UNK_TOKEN)
                            if token == config.UNK_TOKEN:
                                cur_pred_id = ds.unk_id
                            else:
                                cur_pred_id = _id
                            next_hyp.pred_ids[bs_idx].append(cur_pred_id)

                        all_hyps.append(next_hyp)

                    # at the first step all beams are identical, so expand only one of them
                    if t == 1:
                        break
                hyps = update_hyps(all_hyps)
            final_results = get_final_results(hyps)

            def output_results(outputs, outobj):
                for result in outputs:
                    seq = []
                    for _id in result:
                        _id = int(_id)
                        if _id == ds.end_id:
                            break
                        if _id != ds.pad_id and _id != ds.start_id:
                            seq.append(
                                ds.tar_id_tokens.get(_id, config.UNK_TOKEN))
                    write_line = ' '.join(seq)
                    write_line = write_line + '\n'
                    outobj.write(write_line)
                    outobj.flush()

            output_results(final_results, pred_outobj)
            output_results(src_input, src_outobj)
            output_results(tar_input, tar_outobj)

        src_outobj.close()
        pred_outobj.close()
        tar_outobj.close()

        preds = []
        with open(self.pred_out_path) as f:
            for line in f:
                preds.append(line.strip())
        tgts = []
        with open(self.tar_out_path) as f:
            for line in f:
                tgts.append(line.strip())
        exact_scores, f1_scores = metrics.get_raw_scores(tgts, preds)
        em_f1 = metrics.make_eval_dict(exact_scores, f1_scores)
        print(em_f1)
Example no. 8
    def test(self):
        ds = DataSet(args)
        test_generator = ds.data_generator('test', 'ted')

        def compile_new_model():
            _model = self.transformer_model.get_model(ds.pad_id)
            _model.compile(
                optimizer=keras.optimizers.Adam(lr=self.args.lr),
                loss=keras.losses.sparse_categorical_crossentropy,
            )
            return _model

        # load_model
        print('Loading model from: %s' % self.model_path)
        custom_dict = get_custom_objects()
        model = load_model(self.model_path, custom_objects=custom_dict)
        #model = compile_new_model()
        #model.load_weights(self.model_path)

        src_outobj = open(self.src_out_path, 'w')
        pred_outobj = open(self.pred_out_path, 'w')
        tar_outobj = open(self.tar_out_path, 'w')

        for batch, ([src_input, tar_input, facts_input],
                    tar_loss_input) in enumerate(test_generator):
            if batch > (ds.test_sample_num // self.args.batch_size):
                # finish all of the prediction
                break
            print('Current batch: {}/{}. '.format(
                batch, ds.test_sample_num // self.args.batch_size))
            cur_batch_size = tar_input.shape[0]
            tar_length = tar_input.shape[1]

            results = np.zeros_like(tar_input)
            results[:, 0] = ds.start_id
            results[:, 1:] = ds.pad_id

            for t in range(1, tar_length):
                preds = model.predict([
                    src_input, np.asarray(results), facts_input
                ])  # shape: (batch_size, tar_length, vocab_size)
                pred_id = np.argmax(preds, axis=-1)
                results[:, t] = pred_id[:, t - 1]

            def output_results(outputs, outobj):
                for result in outputs:
                    seq = []
                    for _id in result:
                        _id = int(_id)
                        if _id == ds.end_id:
                            break
                        if _id != ds.pad_id and _id != ds.start_id:
                            seq.append(
                                ds.tar_id_tokens.get(_id, config.UNK_TOKEN))
                    write_line = ' '.join(seq)
                    write_line = write_line + '\n'
                    outobj.write(write_line)
                    outobj.flush()

            output_results(results, pred_outobj)
            output_results(src_input, src_outobj)
            output_results(tar_input, tar_outobj)

        src_outobj.close()
        pred_outobj.close()
        tar_outobj.close()
        print(self.pred_out_path)
    def train(self, _type):
        ds = DataSet(args)
        _, train_src_ids, train_tar_ids, _, _, _ = \
        ds.read_file('train',
                     max_src_len=self.args.src_seq_length,
                     max_tar_len=self.args.tar_seq_length,
                    )

        dataset = tf.data.Dataset.from_tensor_slices(
            (train_src_ids, train_tar_ids))
        dataset = dataset.batch(self.args.batch_size)
        n_batch = len(train_src_ids) // self.args.batch_size

        _, valid_src_ids, valid_tar_ids, _, _, _ = \
        ds.read_file('valid',
                     max_src_len=self.args.src_seq_length,
                     max_tar_len=self.args.tar_seq_length,
                    )

        valid_dataset = tf.data.Dataset.from_tensor_slices(
            (valid_src_ids, valid_tar_ids))
        valid_dataset = valid_dataset.batch(self.args.batch_size)

        encoder = Encoder(ds.src_vocab_size, self.args.embedding_dim,
                          self.args.hidden_dim, self.args.batch_size)
        decoder = Decoder(ds.tar_vocab_size, self.args.embedding_dim,
                          self.args.hidden_dim, self.args.batch_size)

        optimizer = tf.train.AdamOptimizer()
        checkpoint = tf.contrib.eager.Checkpoint(optimizer=optimizer,
                                                 encoder=encoder,
                                                 decoder=decoder)
        if _type == 'retrain':
            checkpoint.restore(tf.train.latest_checkpoint(
                self.checkpoints_dir))

        min_valid_loss = math.inf
        improve_num = 0
        summary_writer = tf.contrib.summary.create_file_writer(
            self.tensorboard_dir)
        with summary_writer.as_default(
        ), tf.contrib.summary.always_record_summaries():
            for epoch in range(self.args.epochs):
                start = time.time()
                total_loss = 0
                for (batch, (src_input, tar_input)) in enumerate(dataset):
                    loss = 0
                    hidden_state = encoder.initial_hidden_state(
                        src_input.shape[0])
                    with tf.GradientTape() as tape:
                        enc_output, enc_state = encoder(
                            src_input, hidden_state)
                        dec_state = enc_state

                        dec_input = tf.keras.backend.expand_dims(
                            [ds.start_id] * src_input.shape[0], 1)
                        for t in range(1, tar_input.shape[1]):
                            # teacher - forcing. feeding the target as the next input
                            preds, dec_state, _ = decoder(
                                dec_input, dec_state, enc_output)
                            loss += self.__loss_function(
                                tar_input[:, t], preds)
                            dec_input = tf.keras.backend.expand_dims(
                                tar_input[:, t], 1)

                    batch_loss = (loss / int(tar_input.shape[1]))
                    total_loss += batch_loss
                    variables = encoder.variables + decoder.variables
                    gradients = tape.gradient(loss, variables)
                    optimizer.apply_gradients(zip(gradients, variables))

                    if batch % self.args.display_step == 0:
                        print('Epoch {}/{}, Batch {}/{}, Batch Loss {:.4f}'.
                              format(epoch + 1, self.args.epochs, batch,
                                     n_batch, batch_loss.numpy()))
                    tf.contrib.summary.scalar("total_loss", total_loss)

                valid_total_loss = self.valid(valid_dataset, encoder, decoder,
                                              ds.start_id)
                print('Epoch {}, Train Loss {:.4f}, Valid Loss {:.4f}'.format(
                    epoch + 1, total_loss / n_batch, valid_total_loss))
                if valid_total_loss < min_valid_loss:
                    improve_num = 0
                    print('Valid loss improves from {}, to {}'.format(
                        min_valid_loss, valid_total_loss))
                    min_valid_loss = valid_total_loss
                    checkpoint.save(file_prefix=self.checkpoint_prefix)
                else:
                    improve_num += 1
                    print('Valid loss did not improve from {}'.format(
                        min_valid_loss))
                    if improve_num >= self.args.early_stop_patience:
                        break
                print('Time taken for epoch {}: {} sec \n'.format(
                    epoch + 1,
                    time.time() - start))
        checkpoint.save(file_prefix=self.checkpoint_prefix)
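
    # __loss_function is referenced in train() above but not included in this excerpt.
    # A minimal sketch, assuming the pad id is 0 and should be masked out of the loss;
    # the original implementation may differ:
    def __loss_function(self, real, preds):
        # per-token cross-entropy between integer targets and decoder outputs (logits assumed)
        loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=real, logits=preds)
        # zero out the contribution of padded target positions
        mask = tf.cast(tf.not_equal(real, 0), dtype=loss.dtype)
        return tf.reduce_mean(loss * mask)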
    def test(self):
        ds = DataSet(args)
        indexes, test_src_ids, test_tar_ids, _, _, _ = \
        ds.read_file('test',
                     max_src_len=self.args.src_seq_length,
                     max_tar_len=self.args.tar_seq_length,
                    )

        dataset = tf.data.Dataset.from_tensor_slices(
            (indexes, test_src_ids, test_tar_ids))
        dataset = dataset.batch(self.args.batch_size)
        n_batch = len(test_src_ids) // self.args.batch_size
        print('*' * 100)
        print('Test set size: %d' % len(test_src_ids))
        print('*' * 100)

        encoder = Encoder(ds.src_vocab_size, self.args.embedding_dim,
                          self.args.hidden_dim, self.args.batch_size)
        decoder = Decoder(ds.tar_vocab_size, self.args.embedding_dim,
                          self.args.hidden_dim, self.args.batch_size)
        optimizer = tf.train.AdamOptimizer()
        checkpoint = tf.contrib.eager.Checkpoint(optimizer=optimizer,
                                                 encoder=encoder,
                                                 decoder=decoder)
        checkpoint.restore(tf.train.latest_checkpoint(self.checkpoints_dir))

        src_outobj = open(self.src_out_path, 'w')
        pred_outobj = open(self.pred_out_path, 'w')
        tar_outobj = open(self.tar_out_path, 'w')

        for (batch, (index_input, src_input, tar_input)) in enumerate(dataset):
            print('Current batch {}/{}. '.format(batch, n_batch))
            hidden_state = encoder.initial_hidden_state(src_input.shape[0])
            enc_output, enc_state = encoder(src_input, hidden_state)
            dec_state = enc_state

            results = np.zeros((tar_input.shape[0], tar_input.shape[1]))
            dec_input = tf.keras.backend.expand_dims([ds.start_id] *
                                                     src_input.shape[0], 1)
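            # Step through the target positions greedily, feeding each argmax prediction back as the next input.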
            for t in range(1, tar_input.shape[1]):
                preds, dec_state, _ = decoder(dec_input, dec_state, enc_output)
                pred_id = tf.keras.backend.argmax(preds, axis=1).numpy()
                results[:, t] = pred_id
                dec_input = tf.reshape(pred_id, (-1, 1))

            def output_results(outputs, outobj, is_output_index=False):
                for idx, result in enumerate(outputs):
                    seq = []
                    for _id in result:
                        _id = int(_id)
                        if _id == ds.end_id:
                            break
                        if _id != ds.pad_id and _id != ds.start_id:
                            seq.append(
                                ds.tar_id_tokens.get(_id, config.UNK_TOKEN))
                    write_line = ' '.join(seq)

                    write_line = write_line + '\n'
                    outobj.write(write_line)

            output_results(results, pred_outobj)
            output_results(src_input, src_outobj)
            output_results(tar_input, tar_outobj)

        src_outobj.close()
        pred_outobj.close()
        tar_outobj.close()