Example #1
def extract_minimal_nlu(dialogue_manager):

    turn = dialogue_manager.turn

    if 'trans' not in turn:
        user_utterance = turn['user']['utt']
    else:
        user_utterance = turn['trans']

    clean_utterance = utils.remove_punctuation(user_utterance)

    if clean_utterance in ['yes', 'yeah', 'okay']:
        return {'sem': ['acknowledgment']}
    elif clean_utterance in ['no']:
        return {'sem': ['reject']}
    elif clean_utterance in ['uav']:
        # forcing the NLU output for the UAV case
        return {
            'd_act': ['instruction'],
            'robot':
            rig_utils.parse_sentence_for_robots(
                user_utterance, dialogue_manager.rig_db.dynamic_mission_db)
        }
    elif turn['user']['nlu'] == 'no_nlu':
        if DEBUG:
            utils.print_dict(turn)
        logging.error('No NLU input')
        input()
    else:
        nlu_dacts = extract_user_dialogue_acts(turn['user']['nlu'])
        if nlu_dacts == ['ACKNOWLEDGMENT']:
            return {'sem': ['acknowledgment']}
        else:
            return extract_semantic_frames(turn['user']['nlu'],
                                           dialogue_manager)

    return None
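
The same shortcut pattern, as a minimal self-contained sketch (the punctuation helper here is a stand-in assumption, not the project's utils.remove_punctuation):

import string

def _strip_punctuation(text):
    # stand-in for utils.remove_punctuation: drop punctuation, trim, lowercase
    return text.translate(str.maketrans('', '', string.punctuation)).strip().lower()

def classify_short_reply(utterance):
    # map bare confirmations/rejections to semantic frames without running
    # the full NLU pipeline, mirroring the shortcut branches above
    clean = _strip_punctuation(utterance)
    if clean in ('yes', 'yeah', 'okay'):
        return {'sem': ['acknowledgment']}
    if clean == 'no':
        return {'sem': ['reject']}
    return None

print(classify_short_reply('Yes!'))   # {'sem': ['acknowledgment']}
print(classify_short_reply('No.'))    # {'sem': ['reject']}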
Example #2
    def on_data(self, data):
        if 'results' in data:
            for result in data['results']:
                routing_key = 'asr.data' if result['final'] else 'asr.incremental_data'
                trans, final = self._parse_speech_data(data)
                if DEBUG:
                    utils.print_dict({'transcript': trans, 'final': final})
                asr_pub.send((json.dumps({'transcript': trans, 'final': final}),
                              '{}.mic'.format(routing_key)))
        if 'speaker_labels' in data:
            for result in data['speaker_labels']:
                routing_key = 'asr.speaker_labels' if result['final'] else 'asr.incremental_speaker_labels'
                asr_pub.send((json.dumps(data), '{}.mic'.format(routing_key)))
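
A hedged sketch of the routing-key convention the callback appears to follow; the message format and publisher call are assumptions, not the actual asr_pub API:

import json

def build_asr_message(result, transcript, final):
    # final hypotheses go to 'asr.data', partial ones to 'asr.incremental_data';
    # the '.mic' suffix identifies the audio source
    routing_key = 'asr.data' if result.get('final') else 'asr.incremental_data'
    payload = json.dumps({'transcript': transcript, 'final': final})
    return payload, '{}.mic'.format(routing_key)

print(build_asr_message({'final': True}, 'inspect the east tower', True))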
Example #3
    def get_best_transition(self,last_turn,dialogue_manager):

        if 'user' in last_turn and 'nlu' in last_turn['user']:
            #look for frame semantics
            if 'user_turn' in self.transitions:
                if DEBUG:
                    utils.print_dict(last_turn['user']['nlu'])
                return self._state_high_prob(self.transitions,dialogue_manager,nlu=last_turn['user']['nlu'])
            elif 'system_turn' in self.transitions:
                return self._state_high_prob(self.transitions,dialogue_manager)
        elif 'system_turn' in self.transitions:
            return self._state_high_prob(self.transitions,dialogue_manager)
        else:
            return self._get_highest_prob_all_states()
Example #4
    def get_nlu_transitions(self,all_transitions, nlu):
        '''
        Combines d_acts and frame semantics to find the transitions for a given nlu
        :param all_transitions:
        :param nlu:
        :return:
        '''

        tr_nlu = []
        candidate_nlu = []

        transitions = {}

        if 'd_act' in nlu:
            for d_act in nlu['d_act']:
                if d_act not in tr_nlu:
                    tr_nlu.append(d_act)
                if d_act in all_transitions:
                    candidate_nlu.append(d_act)
                    transitions.update(all_transitions[d_act])

        if 'sem' in nlu:
            for frame in nlu['sem']:
                if frame not in tr_nlu:
                    tr_nlu.append(frame)
                if frame in all_transitions:
                    candidate_nlu.append(frame)
                    transitions.update(all_transitions[frame])

        if '|'.join(tr_nlu) in all_transitions:
            return all_transitions['|'.join(tr_nlu)]
        elif 'other' in all_transitions:
            logging.warning(f"combined transition {'|'.join(tr_nlu)} could not be found")
            utils.print_dict(all_transitions)
            transitions.update(all_transitions['other'])
        elif len(candidate_nlu) == 0:
            logging.debug('No NLU match')
            transitions.update({'request_repeat':1.0})

        return transitions
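
A toy run of the same lookup logic with a made-up transition table (keys and probabilities are illustrative only):

toy_transitions = {
    'instruction': {'confirm_instruction': 0.6, 'request_repeat': 0.4},
    'instruction|motion': {'inform_moving': 1.0},
    'other': {'request_repeat': 1.0},
}

# the combined d_act + frame key is tried first ...
nlu = {'d_act': ['instruction'], 'sem': ['motion']}
key = '|'.join(nlu['d_act'] + nlu['sem'])
print(toy_transitions.get(key))                              # {'inform_moving': 1.0}

# ... and the 'other' bucket acts as a fallback when it is missing
nlu = {'d_act': ['instruction'], 'sem': ['being_located']}
key = '|'.join(nlu['d_act'] + nlu['sem'])
print(toy_transitions.get(key, toy_transitions['other']))    # {'request_repeat': 1.0}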
Example #5
def get_all_utterances():

    all_utt = []

    for d in all_dialogues:
        if all_dialogues[d]['dialogue_id'] in dialogue_subsets['test']:
            #don't use the utterance in the test set to build the bow model
            logger.debug(
                f"skipping {all_dialogues[d]['dialogue_id']} since it belongs to the test set"
            )
            continue
        try:
            all_utt += [
                turn['user'].translate(str_trans).lower()
                if 'user' in turn else '<SIL>'
                for turn in all_dialogues[d]['dialogue']
            ]
        except Exception:
            ### DBG ##
            utils.print_dict(all_dialogues[d]['dialogue'])
            input()

    return all_utt
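
The per-utterance normalisation relies on str.translate with a precomputed table; a standalone version, assuming str_trans strips punctuation as sketched here:

import string

# assumed equivalent of the module-level str_trans table
str_trans = str.maketrans('', '', string.punctuation)

turns = [{'user': 'Send the UAV, please!'}, {'prompt': 'Okay.'}]
utterances = [
    turn['user'].translate(str_trans).lower() if 'user' in turn else '<SIL>'
    for turn in turns
]
print(utterances)   # ['send the uav please', '<SIL>']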
Example #6
    def _get_highest_prob_all_states(self):
        '''
        If the NLU was not observable, sample from all possible transitions
        :return:
        '''

        sys_trans = {k: v/2 for k,v in self.transitions['system_turn'].items()}

        if DEBUG:
            utils.print_dict(sys_trans)

        usr_nlu_trans = {}
        # normalising over the NLU hypotheses of the user turn
        for nlu in self.transitions['user_turn']:
            usr_nlu_trans[nlu] = {
                k: v / len(self.transitions['user_turn'])
                for k, v in self.transitions['user_turn'][nlu].items()
            }

        # summing the per-NLU user distributions into a single distribution
        usr_trans = {}
        for nlu in usr_nlu_trans:
            for tr in usr_nlu_trans[nlu]:
                if tr in usr_trans:
                    usr_trans[tr] += usr_nlu_trans[nlu][tr]
                else:
                    usr_trans[tr] = usr_nlu_trans[nlu][tr]

        usr_trans = {k: v / 2 for k, v in usr_trans.items()}

        #merging
        trans = sys_trans
        for tr in usr_trans:
            if tr in trans:
                trans[tr] += usr_trans[tr]
            else:
                trans[tr] = usr_trans[tr]

        return self._state_high_prob(trans)
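
The merge gives the system-turn and user-turn halves equal weight; a compact standalone sketch with toy numbers:

system_turn = {'confirm': 0.8, 'request_repeat': 0.2}
user_turn = {
    'acknowledgment': {'confirm': 1.0},
    'reject': {'request_repeat': 0.5, 'confirm': 0.5},
}

# halve the system distribution
merged = {k: v / 2 for k, v in system_turn.items()}

# average the user distributions over NLU hypotheses, then halve
for nlu_probs in user_turn.values():
    for state, p in nlu_probs.items():
        merged[state] = merged.get(state, 0.0) + p / len(user_turn) / 2

print(merged)   # {'confirm': 0.775, 'request_repeat': 0.225}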
Example #7
    def _state_high_prob(self, dict_transitions, dialogue_manager, nlu=None):
        '''
        Given a dictionary of transition probabilities, returns the state with the highest probability
        :param dict_transitions:
        :param dialogue_manager:
        :param nlu:
        :return:
        '''

        if nlu is not None:
            transitions = self.get_nlu_transitions(copy(dict_transitions['user_turn']), nlu)
        else:
            transitions = copy(dict_transitions['system_turn'])
            if DEBUG:
                utils.print_dict(transitions)

        for st in list(set().union(dialogue_manager.used_no_transitions,dialogue_manager.subtask_travelled_states)):
            if st in transitions:
                try:
                    del transitions[st]
                    logging.debug(f'{st} removed from list of possible transitions')
                except KeyError:
                    logging.error(f'{st} transition not available in the current state, in subtask {dialogue_manager.dm_manager.subtask}')
                    input()

        # masking transitions
        allowed_transitions = mask_transitions(list(transitions.keys()),
                                               dialogue_manager,
                                               dialogue_manager.nlg_states)

        for st in list(set(list(transitions.keys()))-set(allowed_transitions)):
            del transitions[st]

        if len(transitions) > 0:
            return max(transitions.items(), key=operator.itemgetter(1))[0]
        elif len(allowed_transitions) > 0:
            return random.choice(allowed_transitions)
        else:
            logging.error(f'No transitions for state {self.name}, in subtask {dialogue_manager.subtask}')
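
The final selection is a plain argmax over the surviving transition probabilities, with a random fallback when the mask removes everything; the same pattern in isolation (toy values):

import operator
import random

transitions = {'inform_moving': 0.5, 'confirm_instruction': 0.3, 'request_repeat': 0.2}
allowed = ['inform_moving', 'confirm_instruction']

# drop transitions that are not allowed by the mask
filtered = {st: p for st, p in transitions.items() if st in allowed}

if filtered:
    best = max(filtered.items(), key=operator.itemgetter(1))[0]
elif allowed:
    best = random.choice(allowed)
else:
    best = None

print(best)   # inform_moving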
Example #8
    def __init__(self, message):

        audio_config = message[0]
        self.routing_key = message

        self._rate = audio_config['rate']
        self.chunk_size = audio_config['chunk']
        self._num_channels = 1
        self._buff = queue.Queue()
        self.closed = True
        self.start_time = get_current_time()
        self.restart_counter = 0
        self.audio_input = []
        self.last_audio_input = []
        self.result_end_time = 0
        self.is_final_end_time = 0
        self.final_request_end_time = 0
        self.bridging_offset = 0
        self.last_transcript_was_final = False
        self.new_stream = True

        self.address = audio_config['address']

        utils.print_dict(audio_config)
Example #9
    def write_transitions(self, output_dir):

        for st in self.states_nlg:
            state_dict = collections.OrderedDict({
                'name': self.states_nlg[st].name,
                'common': {
                    'formulations': self.states_nlg[st].utterances,
                    'transition_states': {
                        'system_turn':
                        self.states_nlg[st].transitions_system_turn
                    }
                }
            })

            if len(self.states_nlg[st].transitions_user_turn) > 0:
                state_dict['common']['transition_states'][
                    'user_turn'] = self.states_nlg[st].transitions_user_turn

            utils.print_dict(state_dict)

            utils.setup_yaml()
            with open(os.path.join(output_dir, '{}.yaml'.format(st)),
                      'w') as of:
                yaml.dump(state_dict, of)
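
yaml.dump cannot serialise collections.OrderedDict out of the box, which is presumably what utils.setup_yaml takes care of; a hedged sketch of such a representer (the helper name and behaviour are assumptions):

import collections
import yaml

def setup_yaml():
    # teach PyYAML to emit OrderedDict as a plain mapping, preserving key order
    yaml.add_representer(
        collections.OrderedDict,
        lambda dumper, data: dumper.represent_dict(data.items()))

setup_yaml()
state_dict = collections.OrderedDict(
    [('name', 'confirm_instruction'),
     ('common', {'formulations': ['Shall I send it?']})])
print(yaml.dump(state_dict, default_flow_style=False))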
Example #10
def extract_features_subset(subset):

    for d, dial in enumerate(subset):
        # creating entities dict for current dialogue

        feature_file = f"{os.path.join(output_feat_dir,str(dial))}.hdf5"

        if os.path.isfile(feature_file):
            logger.warning(f"feature file {feature_file} already exists")
            continue

        target_data = np.zeros((max_dialogue_len, 1))
        df = {}

        if 'context_concat' in features_sizes:
            context_vector = rig_utils.create_orca_context_vector(
                rig_db.mission_db, rig_db.environment, skip_common=True)
        else:
            context_vector = rig_utils.create_orca_context_vector(
                rig_db.mission_db, rig_db.environment, gazebo_state=True)

        user_turns = []
        previous_states = []

        if 'convert_dialogue' in args.features:
            dialogue_history = []

        for f in features_sizes:
            df[f] = np.ones((max_dialogue_len, features_sizes[f]))

        for t, turn in enumerate(all_dialogues[dial]['dialogue']):

            if 'action_mask' in turn:
                action_mask_available = True

            try:
                target_data[t] = valid_states.index(turn['current_state'])
            except ValueError:
                print(f"{turn['current_state']} not found")
                for da in target_data:
                    print(valid_states[int(da[0])])
                sys.exit()
            feat_values = {}

            if 'user' not in turn:
                turn['user'] = '******'

            user_turns.append(turn['user'])

            if 'wrd_emb' in args.features:
                # where the user turn has some content

                if args.embeddings_file:
                    if emb_type == 'glove':
                        feat_values[
                            'wrd_emb_glove'] = emb_utils.get_utterance_embedding(
                                glove_model, turn['user'], glove_dim, wrd_idx)
                    else:
                        feat_values[
                            'wrd_emb_gn'] = emb_utils.get_utterance_embedding_w2v(
                                emb, turn['user'], emb.vector_size)

            if 'bow' in args.features:
                feat_values['bow'] = bow.transform(
                    [turn['user']]).toarray().reshape(
                        (features_sizes['bow'], ))

            if 'context' in args.features:
                feat_values['context'] = rig_utils.get_orca_context_features(
                    context_vector,
                    turn['time'],
                    turn['situation_db'],
                    turn['current_state'],
                    turn,
                    gazebo_state=True,
                    init_mission_db=rig_db.mission_db)
                context_vector = feat_values['context']

            if 'previous_action' in args.features:
                feat_values['previous_action'] = get_previous_action(
                    all_dialogues[dial]['dialogue'], t - 1)

            if 'api' in args.features:
                logging.info('No api features available for orca')
                continue

            if 'action_mask' in args.features:
                feat_values['action_mask'] = np.zeros((len(valid_states), ))
                mask_indexes = [
                    valid_states.index(a) for a in turn['action_mask']
                ]
                for m in mask_indexes:
                    feat_values['action_mask'][m] = 1

            if 'nlu' in args.features:
                if 'nlu' in turn:
                    if isinstance(turn['nlu'], str):
                        nlu_tracker.get_nlu_vector(json.loads(turn['nlu']))
                    else:
                        nlu_tracker.get_nlu_vector(turn['nlu'])
                else:
                    #returns a vector with zeros
                    nlu_tracker.get_default_vector()
                feat_values['nlu'] = nlu_tracker.nlu_vect

            for f in feat_values:
                #if f == 'context':
                #   print(feat_values[f])
                if isinstance(feat_values[f], list):
                    if len(feat_values[f]) != features_sizes[f]:
                        logging.error(
                            'Feature size does not match for {}'.format(f))
                    df[f][t, :] = np.array(feat_values[f], dtype=float)

                elif isinstance(feat_values[f], dict):
                    if len(feat_values[f]) != features_sizes[f]:
                        logging.error(
                            'Feature size does not match for {}'.format(f))
                        utils.print_dict(feat_values[f])
                        if f == 'context':
                            utils.print_dict(turn['situation_db'])
                        if f == 'nlu':
                            nlu_tracker.get_default_vector()
                            utils.print_dict(nlu_tracker.nlu_vect)
                        input()
                    try:
                        df[f][t, :] = np.array(list(feat_values[f].values()),
                                               dtype=float)
                    except Exception:
                        utils.print_dict(feat_values[f])
                        input(f)
                else:
                    df[f][t, :] = feat_values[f]

            previous_states.append(turn['current_state'])

        logger.info(f'Creating file {feature_file}')
        if DEBUG:
            input()
        hcn_utils.data_to_h5(feature_file, df, target_data)
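
hcn_utils.data_to_h5 is project code; a hedged guess at an equivalent writer built on h5py (dataset names are an assumption):

import h5py
import numpy as np

def data_to_h5(feature_file, features, targets):
    # one dataset per feature matrix plus the target sequence;
    # naming the datasets after the feature keys is an assumption
    with h5py.File(feature_file, 'w') as h5:
        for name, matrix in features.items():
            h5.create_dataset(name, data=matrix)
        h5.create_dataset('target', data=targets)

df = {'bow': np.zeros((10, 50)), 'action_mask': np.ones((10, 12))}
data_to_h5('example.hdf5', df, np.zeros((10, 1)))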
Example #11
        features_sizes['action_mask'] = len(valid_states)

    if 'previous_action' in args.features:
        features_sizes['previous_action'] = len(valid_states)

    if 'nlu' in args.features:
        nlu_tracker = nlu_utils.NLU_feat(rig_db.robots, rig_db.environment)
        features_sizes['nlu'] = len(nlu_tracker.nlu_vect)
        nlu_debug_file = os.path.join(aux_path, 'nlu_feat.txt')
        if not os.path.isfile(nlu_debug_file):
            with open(nlu_debug_file, 'w') as nlu_fp:
                nlu_fp.write('{}'.format('\n'.join(
                    list(nlu_tracker.nlu_vect.keys()))))

    if DEBUG:
        utils.print_dict(features_sizes)
        input()

    total_feature_size = sum(features_sizes.values())

    logger.info('Complete feature size {}'.format(total_feature_size))

    for subset in dialogue_subsets:
        if len(dialogue_subsets[subset]) > 0:
            features_list = list(features_sizes.keys())
            features_list.sort()
            output_feat_dir = os.path.join(aux_path,
                                           f"{'.'.join(features_list)}",
                                           subset)

            if not os.path.isdir(output_feat_dir):
Example #12
    def print_transitions(self):

        for st in self.states_nlg:
            utils.print_dict(self.states_nlg[st].transitions_user_turn)
            utils.print_dict(self.states_nlg[st].transitions_system_turn)
Example #13
    dact_lm = None

if model_info.max_dial_len == 0:
    logger.error('No dialogue found')
    sys.exit()

logger.info('Loading actions set')
try:
    # pickle needs the file opened in binary mode
    with open(os.path.join(task_dir, 'orca_action_set.txt'), 'rb') as lfp:
        available_actions = pickle.load(lfp)
except Exception:
    # fall back to reading the action set as plain text, one action per line
    with open(os.path.join(task_dir, 'orca_action_set.txt'), 'r') as lfp:
        available_actions = lfp.read().splitlines()

if DEBUG:
    utils.print_dict(available_actions)
    input()

if args.features is not None:
    model_info.feature_set = args.features
else:
    # jchiyah fix
    for fragment in args.datadir.split(os.sep):
        if '.' in fragment:
            model_info.feature_set = fragment.split('.')
            args.features = model_info.feature_set
            break
    models_dir = os.path.join(args.datadir, model_info.config)

if args.test_dir is None:
    args.test_dir = task_dir
Example #14
def extract_semantic_frames(nlu, dialogue_manager):

    nlu_dict = {'d_act': [], 'sem': []}

    for d_act in nlu['dialogue_acts']:
        nlu_dict['d_act'].append(d_act['dialogue_act'].lower())

    for frame in nlu['frame_semantics']:
        nlu_dict['sem'].append(frame['frame'].lower())
        if frame['frame'] in ['Sending', 'Motion', 'Being_located']:
            if 'frame_elements' in frame:
                theme_tokens = []
                goal_tokens = []
                cotheme_tokens = []
                for elements in frame['frame_elements']:
                    if elements['frame_element'] == 'Theme':
                        theme_tokens += elements['tokens']
                    if elements['frame_element'] == 'Goal':
                        goal_tokens += elements['tokens']
                    if elements['frame_element'] == 'Cotheme':
                        cotheme_tokens += elements['tokens']

                if theme_tokens != []:
                    theme_tokens = utils.intersect_lists(
                        theme_tokens, d_act['tokens'])
                    nlu_dict['robot'] = rig_utils.parse_sentence_for_robots(
                        get_surface_tokens(theme_tokens, nlu['tokens']),
                        dialogue_manager.rig_db.dynamic_mission_db)

                if goal_tokens != []:
                    goal_tokens = utils.intersect_lists(
                        goal_tokens, d_act['tokens'])
                    nlu_dict['location'] = get_location_name(
                        nlu, goal_tokens, dialogue_manager)

                if cotheme_tokens != []:
                    cotheme_tokens = utils.intersect_lists(
                        cotheme_tokens, d_act['tokens'])
                    nlu_dict['robot'] += rig_utils.parse_sentence_for_robots(
                        get_surface_tokens(cotheme_tokens, nlu['tokens']),
                        dialogue_manager.rig_db.dynamic_mission_db)

        if frame['frame'] == 'Putting_out_fire':
            if 'frame_elements' in frame:
                place_tokens = []
                agent_tokens = []
                for elements in frame['frame_elements']:
                    if elements['frame_element'] == 'Place':
                        place_tokens += elements['tokens']
                    if elements['frame_element'] == 'Agent':
                        agent_tokens += elements['tokens']

                if place_tokens != []:
                    place_tokens = utils.intersect_lists(
                        place_tokens, d_act['tokens'])
                    nlu_dict['location'] = get_location_name(
                        nlu, place_tokens, dialogue_manager)

                if agent_tokens != []:
                    agent_tokens = utils.intersect_lists(
                        agent_tokens, d_act['tokens'])
                    nlu_dict['robot'] = rig_utils.parse_sentence_for_robots(
                        get_surface_tokens(agent_tokens, nlu['tokens']),
                        dialogue_manager.rig_db.dynamic_mission_db)

        if frame['frame'] == 'Inspecting':
            if 'frame_elements' in frame:
                ground_tokens = []
                inspector_tokens = []
                for elements in frame['frame_elements']:
                    if elements['frame_element'] == 'Ground':
                        ground_tokens += elements['tokens']
                    if elements['frame_element'] == 'Inspector':
                        inspector_tokens += elements['tokens']

                if ground_tokens != []:
                    ground_tokens = utils.intersect_lists(
                        ground_tokens, d_act['tokens'])
                    nlu_dict['location'] = get_location_name(
                        nlu, ground_tokens, dialogue_manager)

                if inspector_tokens != []:
                    inspector_tokens = utils.intersect_lists(
                        inspector_tokens, d_act['tokens'])
                    nlu_dict['robot'] = rig_utils.parse_sentence_for_robots(
                        get_surface_tokens(inspector_tokens, nlu['tokens']),
                        dialogue_manager.rig_db.dynamic_mission_db)

        if frame['frame'] == 'Perception_active':
            lexical_unit_lemmas = get_lemma_tokens(frame['lexical_unit'],
                                                   nlu['tokens'])
            if 'hear' in lexical_unit_lemmas and 'see' not in lexical_unit_lemmas:
                # 'hear' always relates to something the user did not understand
                continue
            if 'frame_elements' in frame:
                agent_tokens = []
                for elements in frame['frame_elements']:
                    if elements['frame_element'] == 'Perceiver_agentive':
                        agent_tokens += elements['tokens']

                if agent_tokens != []:
                    agent_tokens = utils.intersect_lists(
                        agent_tokens, d_act['tokens'])
                    nlu_dict['robot'] = rig_utils.parse_sentence_for_robots(
                        get_surface_tokens(agent_tokens, nlu['tokens']),
                        dialogue_manager.rig_db.dynamic_mission_db)

        if frame['frame'] == 'Using':
            if 'frame_elements' in frame:
                for elements in frame['frame_elements']:
                    if elements['frame_element'] == 'Instrument':
                        instrument_tokens = utils.intersect_lists(
                            elements['tokens'], d_act['tokens'])
                        nlu_dict[
                            'robot'] = rig_utils.parse_sentence_for_robots(
                                get_surface_tokens(instrument_tokens,
                                                   nlu['tokens']),
                                dialogue_manager.rig_db.dynamic_mission_db)

        if frame['frame'] == 'Emptying':
            if 'frame_elements' in frame:
                place_tokens = []
                for elements in frame['frame_elements']:
                    if elements['frame_element'] == 'Place':
                        place_tokens += elements['tokens']

                if place_tokens != []:
                    place_tokens = utils.intersect_lists(
                        place_tokens, d_act['tokens'])
                    nlu_dict['location'] = get_location_name(
                        nlu, place_tokens, dialogue_manager)

        if frame['frame'] == 'Bringing':
            if 'frame_elements' in frame:
                theme_tokens = []
                for elements in frame['frame_elements']:
                    if elements['frame_element'] == 'Theme':
                        theme_tokens += elements['tokens']

                if theme_tokens != []:
                    theme_tokens = utils.intersect_lists(
                        theme_tokens, d_act['tokens'])
                    nlu_dict['robot'] = rig_utils.parse_sentence_for_robots(
                        get_surface_tokens(theme_tokens, nlu['tokens']),
                        dialogue_manager.rig_db.dynamic_mission_db)

        if frame['frame'] == 'Being_in_category':
            if 'frame_elements' in frame:
                item_tokens = []
                for elements in frame['frame_elements']:
                    item_tokens += elements['tokens']

                if item_tokens != []:
                    item_tokens = utils.intersect_lists(
                        item_tokens, d_act['tokens'])
                    nlu_dict['robot'] = rig_utils.parse_sentence_for_robots(
                        get_surface_tokens(item_tokens, nlu['tokens']),
                        dialogue_manager.rig_db.dynamic_mission_db)

                    surface_tokens = get_surface_tokens(
                        item_tokens, nlu['tokens'], remove_pron_lemmas=True)

                    if nlu_dict['robot'] != []:
                        # when the robot is mentioned
                        robot_names = rig_utils.parse_sentence_for_robot_name(
                            surface_tokens,
                            dialogue_manager.rig_db.dynamic_mission_db)
                        item = surface_tokens
                        for r in robot_names:
                            item = re.sub(r, '', item)
                        nlu_dict['item'] = item.strip()
                    else:
                        nlu_dict['item'] = surface_tokens

        if frame['frame'] == 'Being_located':
            if 'frame_elements' in frame:
                theme_tokens = []
                for elements in frame['frame_elements']:
                    theme_tokens += elements['tokens']

                if theme_tokens != []:
                    theme_tokens = utils.intersect_lists(
                        theme_tokens, d_act['tokens'])
                    robots = rig_utils.parse_sentence_for_robots(
                        get_surface_tokens(theme_tokens, nlu['tokens']),
                        dialogue_manager.rig_db.dynamic_mission_db)
                    if len(robots) > 0:
                        nlu_dict['robot'] = robots
                    else:
                        nlu_dict['location'] = get_location_name(
                            nlu, theme_tokens, dialogue_manager)

        if frame['frame'] == 'Telling':
            # "[can you] tell me the <robot_info>?" parser
            if 'frame_elements' in frame:
                message_tokens = []
                for elements in frame['frame_elements']:
                    message_tokens += elements['tokens']

                if message_tokens != []:
                    message_tokens = utils.intersect_lists(
                        message_tokens, d_act['tokens'])
                    nlu_dict['robot'] = rig_utils.parse_sentence_for_robots(
                        get_surface_tokens(message_tokens, nlu['tokens']),
                        dialogue_manager.situation_db)
                    surface_tokens = get_surface_tokens(
                        message_tokens, nlu['tokens'], remove_pron_lemmas=True)
                    if nlu_dict['robot'] != []:
                        # when the robot is mentioned
                        if DEBUG:
                            print(surface_tokens)
                        robot_names = rig_utils.parse_sentence_for_robot_name(
                            surface_tokens, dialogue_manager.situation_db)
                        item = surface_tokens
                        for r in robot_names:
                            item = re.sub(r, '', item)
                        nlu_dict['item'] = item.strip()
                    else:
                        nlu_dict['item'] = surface_tokens

    for d_act in nlu['dialogue_acts']:
        if d_act['dialogue_act'] in ['INFORM', 'INSTRUCTION'
                                     ] and nlu['frame_semantics'] == []:
            content = get_surface_tokens(d_act['tokens'], nlu['tokens'])
            nlu_dict['robot'] = rig_utils.parse_sentence_for_robots(
                content, dialogue_manager.rig_db.dynamic_mission_db)
            nlu_dict['item'] = parse_for_items(content)
        else:
            logging.debug(
                f"dialogue act {nlu_dict['d_act']} not in the handled list, or the frame semantics output is not empty"
            )
            # avoid printing the whole thing
            if DEBUG:
                utils.print_dict(nlu['dialogue_acts'])
                utils.print_dict(nlu['frame_semantics'])
                print(nlu['sentence'])

    if DEBUG:
        utils.print_dict(nlu_dict)

    return nlu_dict
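
The parser assumes an NLU payload shaped roughly like the following hand-written illustration (not actual system output); that frame-element token lists are indices into the top-level token list is inferred from how get_surface_tokens is called:

example_nlu = {
    'sentence': 'send husky one to the east tower',
    'tokens': ['send', 'husky', 'one', 'to', 'the', 'east', 'tower'],
    'dialogue_acts': [
        {'dialogue_act': 'INSTRUCTION', 'tokens': [0, 1, 2, 3, 4, 5, 6]},
    ],
    'frame_semantics': [
        {'frame': 'Sending',
         'lexical_unit': [0],
         'frame_elements': [
             {'frame_element': 'Theme', 'tokens': [1, 2]},
             {'frame_element': 'Goal', 'tokens': [3, 4, 5, 6]},
         ]},
    ],
}

# recovering the surface form of a frame element from its token indices
theme_indices = example_nlu['frame_semantics'][0]['frame_elements'][0]['tokens']
print(' '.join(example_nlu['tokens'][i] for i in theme_indices))   # husky one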
Example #15
def compute_scores(y_test,
                   x_test,
                   y_pred,
                   dialogue_setting,
                   action_set,
                   tst_files,
                   model_info,
                   mode='test',
                   log_dir=None,
                   action_mask_test=None,
                   lm_dacts=None):
    '''
    Computes scores for the orca data
    :param y_test:
    :param x_test:
    :param y_pred:
    :param dialogue_setting:
    :param action_set:
    :param tst_files:
    :param model_info:
    :param mode:
    :param log_dir:
    :param action_mask_test:
    :param lm_dacts:
    :return:
    '''

    test = np.argmax(y_test, axis=-1)
    pred = np.argmax(y_pred, axis=-1)

    correct_prediction = []
    turn_acc_updates_merged = []
    mission_success = []
    end_state_success = []
    turns_mission_success = []
    turns_mission_success_succ_d = []
    dialogue_success = []
    correct_outcome = []
    situated_da_success = []
    d_perplex = []
    preplex_diff = []
    collaborative_task_success = []

    actions_test = []
    actions_pred = []

    if model_info.generate_output and mode != 'train':
        generate_output = '{}_{}'.format('.'.join(model_info.feature_set),
                                         model_info.config)
    else:
        generate_output = None

    # loop over dialogues
    for d in range(test.shape[0]):
        # we are to produce entity replacement results
        dialogue_id = tst_files[d].split(os.sep)[-1].split('.')[0]
        dialogue_log = json.load(
            open(
                utils.get_turn_file(dialogue_id,
                                    log_dir=log_dir,
                                    extension=f'shcn.data.json'), 'r'))
        gt_dialogue_success = utils.get_dialogue_success(dialogue_id,
                                                         log_dir=log_dir)
        all_turns = dialogue_log['dialogue']

        if DEBUG:
            input(tst_files[d])

        if 'condition' in dialogue_log and dialogue_log['condition'] != 'mixed':
            condition = dialogue_log['condition']
        else:
            condition = 'user_plan'

        entity_replacement_success = []
        generated_subtask_sequence = set(['inspect'])

        if generate_output:
            generated_dialogue = {
                'condition': condition,
                'dialogue_id': dialogue_log['dialogue_id'],
                'model_info': model_info.__dict__,
                'turns': []
            }

            if not os.path.isdir(
                    os.path.join(log_dir, 'generated', generate_output)):
                os.makedirs(os.path.join(log_dir, 'generated',
                                         generate_output))
            output_file_generated_dialogue = os.path.join(
                log_dir, 'generated', generate_output,
                f"{dialogue_log['dialogue_id']}.json")

            if DEBUG:
                print(output_file_generated_dialogue)

        # getting the templates for each condition
        if condition == 'mixed':
            rig_db = dialogue_setting['user_plan'].rig_db
            interaction_manager = dialogue_setting['user_plan']
        else:
            rig_db = dialogue_setting[condition].rig_db
            interaction_manager = dialogue_setting[condition]

        #reinitialising dialogue
        interaction_manager.reset_dialogue()

        end_generated_dialogue = False  # controlling if dialogue has reached final state
        turns_to_complete_mission_pred = test.shape[
            1]  # maximum number of turns allowed
        turns_to_complete_mission_test = len(
            all_turns)  # end of the actual dialogue

        d_actions_pred = []
        d_actions_test = []

        # loop over turns
        for t in range(test.shape[1]):

            if isinstance(x_test, list):
                mask_vector_test = x_test[1][d][t]
            elif isinstance(x_test, dict):
                for f in x_test:
                    if f.find('action_mask') > -1:
                        break
                mask_vector_test = x_test[f][d][t]
            else:
                mask_vector_test = x_test[d, t, :]

            if not np.all(mask_vector_test == 1):

                if DEBUG:
                    if 'prompt' in all_turns[t]:
                        print(f"S: {all_turns[t]['prompt']}")
                    if 'user' in all_turns[t]:
                        print(f"U: {all_turns[t]['user']}")

                if 'user' in all_turns[t] and generate_output is not None:
                    # gets the input for the user and updates the rig database
                    interaction_manager.turn['user'] = {}
                    interaction_manager.turn['user']['utt'] = all_turns[t][
                        'user']
                    if 'nlu' in all_turns[t]['user']:
                        interaction_manager.get_nlu(
                            all_turns[t]['user']['nlu'])

                try:
                    # sanity checks
                    if action_set[test[d, t]] != all_turns[t]['current_state']:
                        logging.error(
                            'Dialogue in the test vector is {}, but in the turns file is {}'
                            .format(action_set[test[d, t]],
                                    all_turns[t]['current_state']))
                except (IndexError, KeyError):
                    print(t)

                if all_turns[t]['current_state'] == 'inform_emergency_solved':
                    turns_to_complete_mission_test = t

                actions_test.append(test[d, t])
                d_actions_test.append(test[d, t])
                actions_pred.append(pred[d, t])
                d_actions_pred.append(pred[d, t])

                predicted_action = action_set[pred[d, t]]
                true_action = action_set[test[d, t]]

                rig_db.dynamic_mission_db['time_left'] = (
                    rig_db.dynamic_mission_db['mission_time'] -
                    all_turns[t]['time'])

                if DEBUG:
                    print(f"P: {predicted_action} [{true_action}]")

                if 'prompt' in all_turns[t]:
                    plain_prompt = utils.remove_punctuation(
                        all_turns[t]['prompt']).lower(
                        )  # cleaning original prompt for comparison
                else:
                    plain_prompt = None

                gen_turn_dict = {
                    'd_act_pred': action_set[pred[d, t]],
                    'd_act_gt': action_set[test[d, t]],
                    'utt_gt': plain_prompt
                }

                if 'user' in all_turns[t]:
                    gen_turn_dict['user'] = all_turns[t]['user']

                # if the predicted action is different, the entity replacement will also fail
                if predicted_action != true_action:
                    entity_replacement_success.append(0)
                    situated_da_success.append(0)
                    gen_turn_dict['entity'] = False
                    gen_turn_dict['situated_da'] = False

                    if predicted_action in dialogue_setting[condition].nlg.status_updates and \
                        true_action in dialogue_setting[condition].nlg.status_updates:
                        turn_acc_updates_merged.append(1)
                    else:
                        turn_acc_updates_merged.append(0)

                    if action_set[pred[d, t]] in dialogue_setting[
                            'user_plan'].nlg.fixed_dialogue_acts:
                        if generate_output:
                            gen_turn_dict['utt_pred'] = action_set[pred[d, t]]
                    else:
                        candidate_utterances = nlg_utils.generate_utt_entities(
                            action_set[pred[d, t]], interaction_manager)

                        if generate_output:
                            if len(candidate_utterances) == 0:
                                if DEBUG:
                                    print(action_set[pred[d, t]])
                                    input(condition)
                                gen_turn_dict['utt_pred'] = None
                            else:
                                gen_turn_dict['utt_pred'] = random.choice(
                                    candidate_utterances)

                # if the predicted action is one of those that are not part of the templates
                else:
                    turn_acc_updates_merged.append(1)
                    if predicted_action in [
                            'user_turn', 'okay', 'activate.robot',
                            'deactivate.robot', 'holdon2seconds',
                            'gesture_emotion_neutral', 'actionperformed',
                            'repeat', 'BC', 'UNK', 'sorrycanyourepeatthat',
                            'yes', 'no'
                    ]:

                        if generate_output:
                            gen_turn_dict['utt_pred'] = predicted_action
                            gen_turn_dict['entity'] = True

                        entity_replacement_success.append(1)

                    # if the action is part of the templates, check whether the exact utterance can be reconstructed from them
                    else:

                        utt_in_template = False
                        candidate_utterances = nlg_utils.generate_utt_entities(
                            predicted_action, interaction_manager)

                        for ca in candidate_utterances:
                            try:
                                ca = utils.remove_punctuation(ca).lower()
                                if ca == plain_prompt:
                                    utt_in_template = True
                                    break
                            except KeyError:
                                logging.warning(f'No prompt found in turn {t}')
                                break

                        if generate_output and len(candidate_utterances) > 0:
                            gen_turn_dict['utt_pred'] = ca

                        if utt_in_template:
                            entity_replacement_success.append(1)
                            gen_turn_dict['entity'] = True
                        else:
                            entity_replacement_success.append(0)
                            gen_turn_dict['entity'] = False
                            if DEBUG:
                                logging.debug(
                                    'error in entity replacement for dialogue act {}'
                                    .format(action_set[pred[d, t]]))
                                utils.print_dict(candidate_utterances)
                                utils.print_dict(all_turns[t])
                                input()

                if t > 0:
                    if all_turns[t]['current_state'] in [
                            'inform_moving', 'inform_robot_eta',
                            'inform_arrival', 'inform_inspection',
                            'inform_damage_inspection_robot',
                            'inform_returning_to_base', 'inform_robot_battery',
                            'inform_robot_progress', 'inform_robot_velocity',
                            'inform_robot_status'
                    ]:
                        if pred[d, t] == test[d, t]:
                            situated_da_success.append(1)
                            gen_turn_dict['situated_da'] = True
                        else:
                            situated_da_success.append(0)
                            gen_turn_dict['situated_da'] = False
                if generate_output:
                    generated_dialogue['turns'].append(gen_turn_dict)

            if action_set[pred[d, t]] == 'inform_emergency_status':
                generated_subtask_sequence.add('extinguish')
                interaction_manager.subtask = 'extinguish'

            if action_set[pred[d, t]] == 'inform_emergency_solved':
                generated_subtask_sequence.add('assess_damage')
                interaction_manager.subtask = 'assess_damage'
                turns_to_complete_mission_pred = t
                if generate_output:
                    generated_dialogue['turns_to_complete'] = t

            #print(generated_subtask_sequence)
            # if one of the two ending states is reached
            if action_set[pred[d,t]] in ['mission_timeout', 'inform_mission_completed'] and \
                    not end_generated_dialogue:
                # check if the sequence of sub-dialogues fulfils the task needs
                end_generated_dialogue = True
                #end_state_success.append(1)
                if DEBUG:
                    utils.print_dict(list(generated_subtask_sequence))
                if set(generated_subtask_sequence) == set(
                    ['inspect', 'extinguish', 'assess_damage']):
                    mission_success.append(1)
                    if generate_output:
                        generated_dialogue['mission_success'] = True
                else:
                    mission_success.append(0)
                    if generate_output:
                        generated_dialogue['mission_success'] = False

                if action_set[pred[d,t]] == 'inform_mission_completed' and mission_success[-1] == 1 or\
                        action_set[pred[d,t]] == 'mission_timeout' and mission_success[-1] == 0:
                    end_state_success.append(1)
                    if generate_output:
                        generated_dialogue['end_state'] = True
                else:
                    end_state_success.append(0)
                    if generate_output:
                        generated_dialogue['end_state'] = False
                    #end_state_success.append(1)

        if not end_generated_dialogue:
            if set(generated_subtask_sequence) == set(
                ['inspect', 'extinguish', 'assess_damage']):
                mission_success.append(1)
                if generate_output:
                    generated_dialogue['mission_success'] = True
            else:
                mission_success.append(0)
                if generate_output:
                    generated_dialogue['mission_success'] = False
            # if we reach the end of the dialogue without reaching one of the end states
            # dialogues are considered non-successful
            if generate_output:
                generated_dialogue['end_state'] = False
            end_state_success.append(0)

        if generate_output:
            generated_dialogue[
                'relative_turns_on_task'] = turns_to_complete_mission_pred / turns_to_complete_mission_test

        if gt_dialogue_success is not None:
            if gt_dialogue_success == mission_success[-1]:
                correct_outcome.append(1)
                if generate_output:
                    generated_dialogue['correct_outcome'] = True
                    generated_dialogue[
                        'expected_outcome'] = gt_dialogue_success
                    generated_dialogue['achieved_outcome'] = mission_success[
                        -1]
            else:
                correct_outcome.append(0)
                if generate_output:
                    generated_dialogue['correct_outcome'] = False
                    generated_dialogue[
                        'expected_outcome'] = gt_dialogue_success
                    generated_dialogue['achieved_outcome'] = mission_success[
                        -1]

            if gt_dialogue_success:
                # if the gt dialogue is successful, check whether the model
                # output is also successful; otherwise score 0
                if gt_dialogue_success == mission_success[-1]:
                    collaborative_task_success.append(1)
                else:
                    collaborative_task_success.append(0)
                # do not add anything to collaborative_task_success otherwise
        else:
            correct_outcome.append(0)  #no meaning at all in this case

        turns_mission_success.append(turns_to_complete_mission_pred /
                                     turns_to_complete_mission_test)

        if mission_success[-1]:
            turns_mission_success_succ_d.append(
                turns_to_complete_mission_pred /
                turns_to_complete_mission_test)

        if sum(entity_replacement_success) == len(entity_replacement_success):
            if mission_success[-1] == 0:
                logging.error('All entity replacements were correct but the mission was not successful')
            dialogue_success.append(1)
            if generate_output:
                generated_dialogue['dialogue_entity_success'] = True
        else:
            dialogue_success.append(0)
            if generate_output:
                generated_dialogue['dialogue_entity_success'] = False

        correct_prediction += entity_replacement_success
        #input()

        if generate_output:
            generated_dialogue['turn_accuracy'] = accuracy_score(
                d_actions_test, d_actions_pred)
            generated_dialogue['turn_accuracy_merged_updates'] = sum(
                turn_acc_updates_merged) / len(turn_acc_updates_merged)

        if lm_dacts:
            d_prob_pred = lm_dacts.p(' '.join(
                [t['d_act_pred'] for t in generated_dialogue['turns']]))
            d_perplex.append(utils.perplexity(d_prob_pred))
            if generate_output:
                generated_dialogue['da_preplexity'] = utils.perplexity(
                    d_prob_pred)
            d_prob_test = lm_dacts.p(' '.join(
                [t['current_state'] for t in all_turns]))
            preplex_diff.append(
                utils.perplexity(d_prob_test) - utils.perplexity(d_prob_pred))
            if generate_output:
                generated_dialogue['da_preplexity_diff'] = (
                    utils.perplexity(d_prob_test) -
                    utils.perplexity(d_prob_pred))
        else:
            d_perplex.append(0.0)
            preplex_diff.append(0.0)

        if generate_output:
            with open(output_file_generated_dialogue, 'w') as dg_fp:
                json.dump(generated_dialogue, dg_fp, indent=2)

    if len(mission_success) != len(tst_files):
        utils.print_dict(all_turns)
        print(len(mission_success), len(end_state_success),
              len(correct_prediction))
        input()

    return accuracy_score(actions_test,actions_pred),\
           sum(correct_prediction)/len(correct_prediction),\
           sum(mission_success)/len(mission_success),\
           sum(dialogue_success)/len(dialogue_success),\
           sum(end_state_success)/len(end_state_success),\
           sum(correct_outcome)/len(correct_outcome),\
           sum(turns_mission_success)/len(turns_mission_success),\
           sum(situated_da_success)/len(situated_da_success),\
           sum(d_perplex)/len(d_perplex),\
           sum(preplex_diff)/len(preplex_diff),\
           sum(turn_acc_updates_merged)/len(turn_acc_updates_merged),\
           turns_mission_success_succ_d,\
           actions_pred,actions_test,\
           sum(collaborative_task_success) / len(collaborative_task_success) if len(collaborative_task_success) > 0 else np.nan
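
For reference, a minimal sketch of how the first returned metric (turn accuracy over one-hot action vectors) comes together, using toy arrays and skipping the action-mask filtering applied above:

import numpy as np
from sklearn.metrics import accuracy_score

# toy one-hot targets/predictions: 2 dialogues x 3 turns x 4 actions
y_test = np.eye(4)[[[0, 1, 2], [3, 0, 1]]]
y_pred = np.eye(4)[[[0, 1, 1], [3, 0, 1]]]

test = np.argmax(y_test, axis=-1)   # action indices, shape (2, 3)
pred = np.argmax(y_pred, axis=-1)

print(accuracy_score(test.flatten(), pred.flatten()))   # ~0.833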