    def beam_search_test(self):
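        """Decode the test set with beam search and write the source, predicted,
        and target sequences to their respective output files."""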
        beam_size = self.args.beam_size
        ds = DataSet(self.args)
        test_generator = ds.data_generator('test', 'multi_task')

        def sort_for_each_hyp(hyps, rank_index):
            """Return a list of Hypothesis objects, sorted by descending average log probability"""
            return sorted(hyps, key=lambda h: h.avg_prob[rank_index], reverse=True)

        def get_new_hyps(all_hyps):
            """Keep the beam_size best hypotheses for every sample in the batch."""
            hyp = all_hyps[0]
            batch_size = hyp.batch_size
            tar_len = hyp.tar_len

            new_hyps = []
            for i in range(beam_size):
                hyp = Hypothesis(batch_size, tar_len, ds.start_id, ds.end_id)
                new_hyps.append(hyp)
            for i in range(batch_size):
                # rank based on each sample's probs
                sorted_hyps = sort_for_each_hyp(all_hyps, i)
                for j in range(beam_size):
                    hyp = sorted_hyps[j]
                    new_hyps[j].res_ids[i] = hyp.res_ids[i]
                    new_hyps[j].pred_ids[i] = hyp.pred_ids[i]
                    new_hyps[j].probs[i] = hyp.probs[i]
            return new_hyps

        def update_hyps(all_hyps):
            # all_hyps: beam_size * beam_size current step hyps. 
            new_hyps = get_new_hyps(all_hyps)
            return new_hyps

        def get_final_results(hyps):
            """Pick the best hypothesis for each sample and return its token ids."""
            hyp = hyps[0]
            batch_size = hyp.batch_size
            tar_len = hyp.tar_len

            final_hyp = Hypothesis(batch_size, tar_len, ds.start_id, ds.end_id)
            for i in range(batch_size):
                # rank based on each sample's probs
                sorted_hyps = sort_for_each_hyp(hyps, i)
                hyp = sorted_hyps[0]
                final_hyp.res_ids[i] = hyp.res_ids[i]
                final_hyp.pred_ids[i] = hyp.pred_ids[i]
                final_hyp.probs[i] = hyp.probs[i]
            res = np.asarray(final_hyp.res_ids)
            return res

        # build and compile a fresh model so the trained weights can be loaded into it
        def compile_new_model():
            _model = self.multi_task_model.get_model()
            _model.compile(
                optimizer=keras.optimizers.Adam(lr=self.args.lr),
                loss={
                    'od1': 'sparse_categorical_crossentropy',
                    'od2': 'sparse_categorical_crossentropy',
                    'od3': 'sparse_categorical_crossentropy',
                },
                loss_weights={
                    'od1': 1.,
                    'od2': 1.,
                    'od3': 1.,
                },
            )
            return _model


        # load_model
        print('Loading model from: %s' % self.model_path)
        #custom_dict = get_custom_objects()
        #model = load_model(self.model_path, custom_objects=custom_dict)
        model = compile_new_model()
        model.load_weights(self.model_path)

        src_outobj = open(self.src_out_path, 'w')
        pred_outobj = open(self.pred_out_path, 'w')
        tar_outobj = open(self.tar_out_path, 'w')

        for batch_index, ([src_input, tar_input, fact_tar_input, facts_input],
                          [_, _, _]) in enumerate(test_generator):
            if batch_index > (ds.test_sample_num // self.args.batch_size):
                # all test batches have been predicted
                break

            print('Current batch: {}/{}. '.format(batch_index, ds.test_sample_num // self.args.batch_size))
            cur_batch_size = tar_input.shape[0]
            tar_length = tar_input.shape[1]
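            # Each Hypothesis is assumed to track, per sample in the batch: res_ids
            # (generated ids), pred_ids (the ids fed back to the decoder, with OOV ids
            # mapped to unk_id), probs (per-step probabilities), and avg_prob for ranking.
            # Start the beam with beam_size identical empty hypotheses.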
            hyps = []
            for i in range(beam_size):
                hyp = Hypothesis(cur_batch_size, tar_length, ds.start_id, ds.end_id)
                hyps.append(hyp)

            for t in range(1, tar_length):
                # iterate each sample
                # collect all hyps, basically, it's beam_size * beam_size
                all_hyps = []
                for i in range(beam_size):
                    cur_hyp = hyps[i]
                    results = cur_hyp.get_predictable_vars(ds.pad_id)
                    # preds: (batch_size, tar_len, vocab_size); vocab_size is 60000 here
                    preds, _, _ = model.predict([src_input, np.asarray(results), fact_tar_input, facts_input])

                    # take the prediction for the current step and keep the beam_size best token ids
                    cur_preds = preds[:, t - 1]
                    top_indices = np.argsort(cur_preds)
                    top_indices = top_indices[:, -beam_size:]  # the largest one is at the end
                        
                    top_logits = []
                    for sample_index, sample_logits in enumerate(cur_preds):
                        logits = []
                        for beam_index in range(beam_size):
                            logit = sample_logits[top_indices[sample_index][beam_index]]
                            logits.append(logit)
                        top_logits.append(logits)
                    top_logits = np.asarray(top_logits)
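                    # (equivalently: top_logits = np.take_along_axis(cur_preds, top_indices, axis=1))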
                    #print('top_logits: ', top_logits[0])

                    # iterate over the beam_size candidate tokens, from best (last index) to worst
                    for j in range(beam_size - 1, -1, -1):
                        next_hyp = deepcopy(cur_hyp)
                        # bs, 1
                        top_index = top_indices[:, j]
                        top_logit = top_logits[:, j]

                        for bs_idx, _id in enumerate(top_index):
                            next_hyp.res_ids[bs_idx].append(_id)
                            prob = top_logit[bs_idx]
                            next_hyp.probs[bs_idx].append(prob)

                            # get OOV id
                            token = ds.tar_id_tokens.get(int(_id), config.UNK_TOKEN)
                            if token == config.UNK_TOKEN:
                                cur_pred_id = ds.unk_id
                            else:
                                cur_pred_id = _id
                            next_hyp.pred_ids[bs_idx].append(cur_pred_id)

                        all_hyps.append(next_hyp)

                    # at the first step all hypotheses are identical, so expand only the first one
                    if t == 1:
                        break
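                # prune the beam_size * beam_size candidates (beam_size at t == 1)
                # back down to the beam_size best hypotheses per sample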
                hyps = update_hyps(all_hyps)
            final_results = get_final_results(hyps)

            def output_results(outputs, outobj):
                for result in outputs:
                    seq = []
                    for _id in result:
                        _id = int(_id)
                        if _id == ds.end_id:
                            break
                        if _id != ds.pad_id and _id != ds.start_id:
                        #if _id != ds.pad_id:
                            seq.append(ds.tar_id_tokens.get(_id, config.UNK_TOKEN))
                    write_line = ' '.join(seq)
                    write_line = write_line + '\n'
                    outobj.write(write_line)
                    outobj.flush()
    
            output_results(final_results, pred_outobj)
            output_results(src_input, src_outobj)
            output_results(tar_input, tar_outobj)
    
        src_outobj.close()
        pred_outobj.close()
        tar_outobj.close()
        print(self.pred_out_path)