コード例 #1
0
    def __init__(self, domain: JSONLookupDomain, parameters=None):
        """
        Set up an empty user goal for the given domain.

        Caches the sorted informable slots (each with its sorted possible values) and the sorted
        requestable slots of ``domain``. The domain's primary key is removed from both lists: it
        is never used as a constraint, and it is added to the requests anyway.

        Args:
            domain (JSONLookupDomain): The domain the goal is instantiated for.
            parameters (dict): Optional goal-generation settings given as a key=value mapping
                (e.g. 'MinVenues', 'MinConstraints', 'MaxConstraints', 'MinRequests',
                'MaxRequests', 'Reachable'); defaults to an empty dict.
        """
        self.domain = domain
        self.parameters = parameters or {}

        primary_key = self.domain.get_primary_key()

        # sorted() always returns a new list, so the domain's own data is never aliased and
        # no explicit copy/slice is needed (this also accepts any iterable, not only lists)
        self.inf_slots = sorted(domain.get_informable_slots())
        # make sure that primary key is never a constraint
        if primary_key in self.inf_slots:
            self.inf_slots.remove(primary_key)

        # TODO sometimes ask for specific primary key with very small probability (instead of any other constraints?) # pylint: disable=line-too-long

        self.inf_slot_values = {
            slot: sorted(domain.get_possible_values(slot)) for slot in self.inf_slots
        }
        self.req_slots = sorted(domain.get_requestable_slots())
        # make sure that primary key is never a request as it is added anyway
        if primary_key in self.req_slots:
            self.req_slots.remove(primary_key)

        self.constraints = []
        self.requests = {}
        self.excluded_inf_slot_values = {}
        self.missing_informs = []
コード例 #2
0
def _create_inform_json(domain: JSONLookupDomain, template: RegexFile):
    """Build the inform-regex mapping ``{slot: {value: regex}}`` for the domain.

    Every possible value of every informable slot gets the regex that ``template`` produces
    for the user act ``Inform(slot=slot, value=value)``.
    """
    def _regex_for(slot, value):
        # render a single Inform act for this slot/value pair through the template
        act = UserAct(act_type=UserActionType.Inform, slot=slot, value=value)
        return template.create_regex(act)

    return {
        slot: {value: _regex_for(slot, value)
               for value in domain.get_possible_values(slot)}
        for slot in domain.get_informable_slots()
    }
コード例 #3
0
ファイル: goal.py プロジェクト: zhiyin121/adviser
    def __init__(self, domain: JSONLookupDomain, parameters=None):
        """
        The class representing a goal, therefore containing requests and constraints.

        The sorted informable slots (each with its sorted possible values) and the sorted
        requestable slots of the domain are cached. The domain's primary key is excluded from
        both: it is never a constraint, and it is added to the requests anyway.

        Args:
            domain (JSONLookupDomain): The domain for which the goal will be instantiated.
                It will only work within this domain.
            parameters (dict): The parameters for the goal defined by a key=value mapping: 'MinVenues'
                (int) allows to set a minimum number of venues which fulfill the constraints of the goal,
                'MinConstraints' (int) and 'MaxConstraints' (int) set the minimum and maximum amount of
                constraints respectively, 'MinRequests' (int) and 'MaxRequests' (int) set the minimum and
                maximum amount of requests respectively and 'Reachable' (float) allows to specify how many
                (in percent) of all generated goals are definitely fulfillable (i.e. there exists a venue
                for the current goal) or not (doesn't have to be fulfillable). Although the parameter
                'Reachable' equals 1.0 implicitly states that 'MinVenues' equals 1 or more, the
                implementation looks different, is more efficient and takes all goals into consideration
                (since 'Reachable' is a float (percentage of generated goals)). On the other hand, setting
                'MinVenues' to any number bigger than 0 forces every goal to be fulfillable.

        """
        self.domain = domain
        self.parameters = parameters or {}

        primary_key = self.domain.get_primary_key()

        # sorted() always returns a new list, so the domain's own data is never aliased and
        # no explicit copy/slice is needed (this also accepts any iterable, not only lists)
        self.inf_slots = sorted(domain.get_informable_slots())
        # make sure that primary key is never a constraint
        if primary_key in self.inf_slots:
            self.inf_slots.remove(primary_key)

        # TODO sometimes ask for specific primary key with very small probability (instead of any other constraints?) # pylint: disable=line-too-long

        self.inf_slot_values = {
            slot: sorted(domain.get_possible_values(slot)) for slot in self.inf_slots
        }
        self.req_slots = sorted(domain.get_requestable_slots())
        # make sure that primary key is never a request as it is added anyway
        if primary_key in self.req_slots:
            self.req_slots.remove(primary_key)

        self.constraints = []
        self.requests = {}
        self.excluded_inf_slot_values = {}
        self.missing_informs = []
コード例 #4
0
def create_json_from_template(domain: JSONLookupDomain,
                              template_filename: str):
    """Write the request and inform regex rule files for ``domain``.

    The regex template at ``template_filename`` is loaded for the domain and used to produce
    ``<domain name>RequestRules.json`` and then ``<domain name>InformRules.json``.
    """
    template = RegexFile(template_filename, domain)
    prefix = domain.get_domain_name()
    rule_builders = (
        (_create_request_json, 'RequestRules.json'),
        (_create_inform_json, 'InformRules.json'),
    )
    for build_rules, suffix in rule_builders:
        _write_dict_to_file(build_rules(domain, template), f'{prefix}{suffix}')
コード例 #5
0
def load_superhero_domain():
    """ Load the superhero domain shared by all following tests in this file.

    Returns:
        JSONLookupDomain: the 'superhero' domain, checked to have both its database and its
        JSON ontology loaded.
    """
    domain = JSONLookupDomain('superhero')

    # sanity checks: fail early if either domain resource could not be loaded
    assert domain.db is not None
    assert domain.ontology_json is not None

    return domain
コード例 #6
0
ファイル: run_chat.py プロジェクト: wendywtchang/adviser
def test_domain(domain_name: str, policy_type: str, gui: bool, logger: DiasysLogger,
                language: Language):
    """ Start an interactive chat with the dialog system for a single domain.

    Args:
        domain_name (str): name of the domain (according to the names in resources/databases)
        policy_type (str): 'hdc' selects the handcrafted policy; any other value selects the
                           dqn (reinforcement learning) policy, whose model is loaded from disk
        gui (bool): if true, a QT GUI session is started, otherwise the console is used
                    for interaction
        logger (DiasysLogger): logger shared by all modules
        language (Language): language handed to the NLU, the NLG and the console input

    .. note::

        When using dqn, make sure you have a trained model. You can train a model for the specified
        domain by executing

        .. code:: bash

            python modules/policy/rl/train_dqnpolicy.py -d domain_name

    """
    domain = JSONLookupDomain(name=domain_name)

    # understanding and belief tracking
    nlu = HandcraftedNLU(domain=domain, logger=logger, language=language)
    bst = HandcraftedBST(domain=domain, logger=logger)

    # policy: handcrafted for 'hdc', otherwise a previously trained DQN
    if policy_type != 'hdc':
        policy = DQNPolicy(domain=domain, train_dialogs=1, logger=logger)
        policy.load()
    else:
        policy = HandcraftedPolicy(domain=domain, logger=logger)

    nlg = HandcraftedNLG(domain=domain, logger=logger, language=language)

    # front end: plain console or QT GUI
    if not gui:
        input_module = ConsoleInput(domain, logger=logger, language=language)
        output_module = ConsoleOutput(domain, logger=logger)
    else:
        input_module = GuiInput(domain, logger=logger)
        output_module = GuiOutput(domain, logger=logger)

    pipeline = (input_module, nlu, bst, policy, nlg, output_module)
    dialog_system = DialogSystem(*pipeline, logger=logger)

    # switch the system to eval mode, then chat
    dialog_system.eval()
    dialog_system.run_dialog()
コード例 #7
0
def _create_general_json(domain: JSONLookupDomain, template: RegexFile):
    """Build a ``{act name: regex}`` mapping for the domain's general (discourse) user acts.

    The 'none' and 'silence' acts get no regex and are skipped.
    """
    skipped_acts = ('none', 'silence')
    return {
        act_name: template.create_regex(UserAct(act_type=UserActionType(act_name)))
        for act_name in domain.get_discourse_acts()
        if act_name not in skipped_acts
    }
コード例 #8
0
ファイル: run_chat.py プロジェクト: wendywtchang/adviser
def _build_ims_dialog_system(domain_name: str, logger: DiasysLogger, language: Language):
    """Build a handcrafted NLU -> BST -> policy -> NLG dialog system for one IMS domain.

    The domain is read from ``resources/databases/<domain_name>-rules.json`` (ontology) and
    ``resources/databases/<domain_name>-dbase.db`` (SQLite database).
    """
    db_dir = os.path.join('resources', 'databases')
    domain = JSONLookupDomain(
        domain_name,
        json_ontology_file=os.path.join(db_dir, f'{domain_name}-rules.json'),
        sqllite_db_file=os.path.join(db_dir, f'{domain_name}-dbase.db'))
    return DialogSystem(
        HandcraftedNLU(domain=domain, logger=logger, language=language),
        HandcraftedBST(domain=domain, logger=logger),
        HandcraftedPolicy(domain=domain, logger=logger),
        HandcraftedNLG(domain=domain, logger=logger, language=language),
        domain=domain,
        logger=logger)


def test_multi(logger: DiasysLogger, language: Language):
    """Chat on the console with a meta-policy that combines the IMS courses and lecturers systems.

    Args:
        logger (DiasysLogger): logger shared by all modules.
        language (Language): language for the NLU, the NLG and the console input.
    """
    # build order (lecturers first) is kept from the original implementation
    lecturers = _build_ims_dialog_system('ImsLecturers', logger, language)
    courses = _build_ims_dialog_system('ImsCourses', logger, language)

    multi = HandcraftedMetapolicy(
        subgraphs=[courses, lecturers],
        in_module=ConsoleInput(None, logger=logger, language=language),
        out_module=ConsoleOutput(None, logger=logger),
        logger=logger)
    multi.run_dialog()
コード例 #9
0
ファイル: run_chat.py プロジェクト: zerov4mp/adviser
def load_lecturers_domain(backchannel: bool = False):
    """Create the IMS lecturers domain together with its service pipeline.

    Args:
        backchannel (bool): forwarded to ``load_nlg`` to select the NLG variant.

    Returns:
        tuple: ``(domain, [nlu, bst, policy, nlg])`` for the 'ImsLecturers' domain.
    """
    from utils.domain.jsonlookupdomain import JSONLookupDomain
    from services.nlu.nlu import HandcraftedNLU
    from services.nlg.nlg import HandcraftedNLG
    from services.policy import HandcraftedPolicy
    domain = JSONLookupDomain('ImsLecturers', display_name="Lecturers")
    services = [
        HandcraftedNLU(domain=domain),
        HandcraftedBST(domain=domain),
        HandcraftedPolicy(domain=domain),
        load_nlg(backchannel=backchannel, domain=domain),
    ]
    return domain, services
コード例 #10
0
def test_setup_of_user_informables_lecturers():
    """
    Check that the expected user-informable slots of the lecturers domain are known to the NLU.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsLecturers'))

    expected_informables = ('name', 'department', 'position')
    for slot in expected_informables:
        assert slot in nlu.USER_INFORMABLE
コード例 #11
0
    def __init__(self,
                 domain: JSONLookupDomain = None,
                 subgraph=None,
                 logger: DiasysLogger = DiasysLogger()):
        """
        Set up the machine-learning belief state tracker.

        Loads the DSTC2 data mappings (no preprocessing, no training data) from
        ``<current working directory>/modules/bst/ml`` and one tracker model for every
        informable and every requestable slot of ``domain``.

        Args:
            domain (JSONLookupDomain): domain whose slots get trackers. Although it defaults to
                None, it is dereferenced unconditionally, so it must be provided.
            subgraph: forwarded to the base class constructor.
            logger (DiasysLogger): forwarded to the base class constructor. NOTE(review): the
                default instance is created once at class-definition time and shared.
        """
        super(MLBST, self).__init__(domain, subgraph, logger=logger)
        # resolved against the current working directory, not against this file's location
        data_folder = os.path.join(os.path.realpath(os.curdir), "modules", "bst", "ml")
        self.path_to_data_folder = data_folder
        # NOTE(review): stores the domain *name* in `primary_key`; confirm whether
        # domain.get_primary_key() was intended here
        self.primary_key = domain.get_domain_name()

        self.data_mappings = DSTC2Data(path_to_data_folder=self.path_to_data_folder,
                                       preprocess=False,
                                       load_train_data=False)

        self.inf_trackers = {}
        self.req_trackers = {}

        # one model per slot; presumably the _load_*_model helpers fill the dicts above
        for slot in domain.get_informable_slots():
            self._load_inf_model(slot)
        for slot in domain.get_requestable_slots():
            self._load_req_model(slot)
コード例 #12
0
ファイル: nlu.py プロジェクト: zhiyin121/adviser
    def __init__(self, domain: JSONLookupDomain, logger: DiasysLogger = DiasysLogger(),
                 language: Language = None):
        """
        Loads
            - domain key
            - informable slots
            - requestable slots
            - domain-independent regular expressions
            - domain-specific regular expressions
        (the regular expressions presumably via ``_initialize``).

        It resets the information about the previous system act to signal the first turn.

        Args:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
            logger {DiasysLogger} -- logger of this service (NOTE: the default instance is
                created once, when the class is defined, and is shared)
            language {Language} -- language of the user input; Language.ENGLISH if not given
        """
        Service.__init__(self, domain=domain)
        self.logger = logger

        # The requested language must survive until _initialize() runs below. The old code
        # reset it to English just before that call, so this argument had no effect.
        self.language = language if language else Language.ENGLISH

        # Getting domain information
        self.domain_name = domain.get_domain_name()
        self.domain_key = domain.get_primary_key()

        # Getting lists of informable and requestable slots
        self.USER_INFORMABLE = domain.get_informable_slots()
        self.USER_REQUESTABLE = domain.get_requestable_slots()

        # Getting the relative path where regexes are stored
        self.base_folder = os.path.join(get_root_dir(), 'resources', 'nlu_regexes')

        # No previous system act yet: this signals the first turn
        self.sys_act_info = {
            'last_act': None, 'lastInformedPrimKeyVal': None, 'lastRequestSlot': None}

        self._initialize()
コード例 #13
0
def test_setup_of_user_requestables_lecturers():
    """
    Check that the expected user-requestable slots of the lecturers domain are known to the NLU.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsLecturers'))

    expected_requestables = ('name', 'department', 'office_hours', 'mail', 'phone', 'room',
                             'position')
    for slot in expected_requestables:
        assert slot in nlu.USER_REQUESTABLE
コード例 #14
0
ファイル: nlu.py プロジェクト: wendywtchang/adviser
    def __init__(self,
                 domain: JSONLookupDomain,
                 subgraph=None,
                 logger: DiasysLogger = DiasysLogger(),
                 language: Language = None):
        """
        Set up the handcrafted NLU for ``domain``.

        Caches the domain name, its primary key and its informable/requestable slots, points
        ``base_folder`` at ``<root>/resources/regexes`` and sets the previous system act to
        None to mark the first turn.

        Args:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
            subgraph -- accepted for interface compatibility. NOTE(review): it is not
                forwarded; the base class always receives None here. Confirm this is intended.
            logger {DiasysLogger} -- forwarded to the base class
            language {Language} -- language of the user input (default: Language.ENGLISH)
        """
        super(HandcraftedNLU, self).__init__(domain, None, logger=logger)

        self.language = language or Language.ENGLISH

        # cached domain information
        self.domain_name, self.domain_key = domain.get_domain_name(), domain.get_primary_key()

        # slots the user can inform about / ask for
        self.USER_INFORMABLE = domain.get_informable_slots()
        self.USER_REQUESTABLE = domain.get_requestable_slots()

        # directory holding the regex files
        self.base_folder = os.path.join(get_root_dir(), 'resources', 'regexes')

        # no system act yet, which marks the first turn
        self.prev_sys_act = None
コード例 #15
0
def train(domain_name: str, log_to_file: bool, seed: int, train_epochs: int, train_dialogs: int,
          eval_dialogs: int, max_turns: int, train_error_rate: float, test_error_rate: float,
          lr: float, eps_start: float, grad_clipping: float, buffer_classname: str,
          buffer_size: int, use_tensorboard: bool):
    """Train a DQN policy against the handcrafted user simulator.

    Each of the ``train_epochs`` epochs runs ``train_dialogs`` training dialogs (then saves the
    policy) followed by ``eval_dialogs`` evaluation dialogs. Every dialog is limited to
    ``max_turns`` turns.

    Args:
        domain_name: name of the JSON lookup domain; also used as the experiment name.
        log_to_file: if true, dialogs are also logged to file.
        seed: seed passed to ``common.init_random``.
        train_epochs: number of train + eval epochs.
        train_dialogs: dialogs per training epoch (also passed to DQNPolicy).
        eval_dialogs: dialogs per evaluation epoch.
        max_turns: maximum number of turns per dialog.
        train_error_rate, test_error_rate: passed to SimpleNoise.
        lr, eps_start, grad_clipping: passed to DQNPolicy as lr / eps_start /
            gradient_clipping.
        buffer_classname: replay buffer type, either "prioritized" or "uniform".
        buffer_size: passed to DQNPolicy as replay_buffer_size.
        use_tensorboard: passed to PolicyEvaluator.

    Raises:
        ValueError: if ``buffer_classname`` is neither "prioritized" nor "uniform".
    """
    common.init_random(seed=seed)

    file_log_lvl = LogLevel.DIALOGS if log_to_file else LogLevel.NONE
    logger = DiasysLogger(console_log_lvl=LogLevel.RESULTS, file_log_lvl=file_log_lvl)
    if buffer_classname == "prioritized":
        buffer_cls = NaivePrioritizedBuffer
    elif buffer_classname == "uniform":
        buffer_cls = UniformBuffer
    else:
        # without this branch buffer_cls stays unbound and DQNPolicy(...) below fails with an
        # obscure UnboundLocalError
        raise ValueError(f"unknown buffer_classname {buffer_classname!r}, "
                         "expected 'prioritized' or 'uniform'")

    domain = JSONLookupDomain(name=domain_name)
    bst = HandcraftedBST(domain=domain, logger=logger)
    user = HandcraftedUserSimulator(domain, logger=logger)
    noise = SimpleNoise(domain=domain, train_error_rate=train_error_rate,
                        test_error_rate=test_error_rate, logger=logger)
    policy = DQNPolicy(domain=domain, lr=lr, eps_start=eps_start,
                       gradient_clipping=grad_clipping, buffer_cls=buffer_cls,
                       replay_buffer_size=buffer_size, train_dialogs=train_dialogs,
                       logger=logger)
    evaluator = PolicyEvaluator(domain=domain, use_tensorboard=use_tensorboard,
                                experiment_name=domain_name, logger=logger)
    ds = DialogSystem(policy,
                      user,
                      noise,
                      bst,
                      evaluator)

    for _ in range(train_epochs):
        # train one epoch
        ds.train()
        evaluator.start_epoch()
        for _ in range(train_dialogs):
            ds.run_dialog(max_length=max_turns)
        evaluator.end_epoch()   # important for statistics!
        ds.num_dialogs = 0  # IMPORTANT for epsilon scheduler in dqnpolicy
        policy.save()       # save model

        # evaluate one epoch
        ds.eval()
        evaluator.start_epoch()
        for _ in range(eval_dialogs):
            ds.run_dialog(max_length=max_turns)
        evaluator.end_epoch()   # important for statistics!
        ds.num_dialogs = 0  # IMPORTANT for epsilon scheduler in dqnpolicy
コード例 #16
0
    def __init__(self, domain: JSONLookupDomain, logger: DiasysLogger = DiasysLogger(),
                 max_turns: int = 25):
        """
        Initializes the policy

        Arguments:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
            logger {DiasysLogger} -- logger stored on the policy
            max_turns {int} -- stored as ``self.max_turns`` (default: 25)
        """
        # set before the Service base initializer runs
        self.first_turn = True
        Service.__init__(self, domain=domain)
        # list of current suggestions, and the index of the one currently being recommended
        self.current_suggestions, self.s_index = [], 0
        self.domain_key = domain.get_primary_key()
        self.logger, self.max_turns = logger, max_turns
コード例 #17
0
ファイル: courses_regexes_test.py プロジェクト: kdbz/adviser
def test_request_time_slot():
    """
    Check that the utterance 'at what time' becomes a Request act for the 'time_slot' slot.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsCourses'))

    expected = UserAct()
    expected.type = UserActionType.Request
    expected.slot = "time_slot"

    result = nlu.extract_user_acts(nlu, user_utterance='at what time')
    assert 'user_acts' in result
    assert result['user_acts'][0] == expected
コード例 #18
0
def test_inform_lecturers_department():
    """
    Check that the utterance 'informatics' becomes an Inform act for the 'department' slot.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsLecturers'))

    # NOTE(review): the expected value 'external' for 'informatics' depends on the lecturers
    # regex/ontology data; confirm it is the intended mapping
    expected = UserAct()
    expected.type = UserActionType.Inform
    expected.slot = "department"
    expected.value = "external"

    result = nlu.extract_user_acts(nlu, user_utterance='informatics')
    assert 'user_acts' in result
    assert result['user_acts'][0] == expected
コード例 #19
0
def test_inform_lecturers_name():
    """
    Check that a partial lecturer name ('agatha christie') becomes an Inform act with the
    full name value.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsLecturers'))

    expected = UserAct()
    expected.type = UserActionType.Inform
    expected.slot = "name"
    expected.value = "apl. prof. dr. agatha christie"

    result = nlu.extract_user_acts(nlu, user_utterance='agatha christie')
    assert 'user_acts' in result
    assert result['user_acts'][0] == expected
コード例 #20
0
def test_request_lecturers_phone():
    """
    Check that asking for 'the phone number' becomes a Request act for the 'phone' slot.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsLecturers'))

    expected = UserAct()
    expected.type = UserActionType.Request
    expected.slot = "phone"

    utterance = 'can you tell me the phone number'
    result = nlu.extract_user_acts(nlu, user_utterance=utterance)
    assert 'user_acts' in result
    assert result['user_acts'][0] == expected
コード例 #21
0
ファイル: courses_regexes_test.py プロジェクト: kdbz/adviser
def test_inform_courses_language():
    """
    Check that the utterance 'german' becomes an Inform act lang='de'.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsCourses'))

    expected = UserAct()
    expected.type = UserActionType.Inform
    expected.slot = "lang"
    expected.value = "de"

    result = nlu.extract_user_acts(nlu, user_utterance='german')
    assert 'user_acts' in result
    assert result['user_acts'][0] == expected
コード例 #22
0
def test_request_lecturers_office_hours():
    """
    Check that asking for 'consultation hours' becomes a Request act for 'office_hours'.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsLecturers'))

    expected = UserAct()
    expected.type = UserActionType.Request
    expected.slot = "office_hours"

    utterance = 'when are the consultation hours'
    result = nlu.extract_user_acts(nlu, user_utterance=utterance)
    assert 'user_acts' in result
    assert result['user_acts'][0] == expected
コード例 #23
0
ファイル: courses_regexes_test.py プロジェクト: kdbz/adviser
def test_setup_of_user_requestables_courses():
    """
    Check that every expected user-requestable slot of the courses domain is known to the NLU.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsCourses'))

    expected_requestables = (
        'applied_nlp', 'bachelor', 'cognitive_science', 'course_type',
        'deep_learning', 'ects', 'elective', 'extendable', 'lecturer',
        'lang', 'linguistics', 'machine_learning', 'master', 'module_id',
        'module_name', 'name', 'obligatory_attendance', 'oral_exam',
        'participation_limit', 'presentation', 'programming', 'project',
        'report', 'semantics', 'speech', 'statistics', 'sws', 'syntax',
        'turn', 'written_exam',
    )
    for slot in expected_requestables:
        assert slot in nlu.USER_REQUESTABLE
コード例 #24
0
    def __init__(self,
                 domain: JSONLookupDomain,
                 subgraph=None,
                 logger: DiasysLogger = DiasysLogger()):
        """
        Initializes the handcrafted policy for ``domain``.

        Arguments:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
            subgraph -- accepted for interface compatibility. NOTE(review): it is not
                forwarded; the base class always receives ``subgraph=None``. Confirm this is
                intended.
            logger {DiasysLogger} -- forwarded to the base class
        """
        super(HandcraftedPolicy, self).__init__(domain, subgraph=None, logger=logger)
        # initial values for the turn counter and the last action
        self.turn, self.last_action = 0, None
        # list of current suggestions, and the index of the one currently being recommended
        self.current_suggestions, self.s_index = [], 0
        self.domain_key = domain.get_primary_key()
コード例 #25
0
ファイル: courses_regexes_test.py プロジェクト: kdbz/adviser
def test_multiple_user_acts_courses():
    """
    Check that a greeting plus a constraint in one sentence yields a Hello act followed by the
    matching Inform act.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsCourses'))

    result = nlu.extract_user_acts(
        nlu,
        user_utterance='Hi, I want a course that is related to linguistics')
    assert 'user_acts' in result
    acts = result['user_acts']

    greeting = UserAct()
    greeting.type = UserActionType.Hello
    assert acts[0] == greeting

    inform = UserAct()
    inform.type = UserActionType.Inform
    inform.slot = "linguistics"
    inform.value = "true"
    assert acts[1] == inform
コード例 #26
0
def test_multiple_user_acts_lecturers():
    """
    Check that a greeting plus a constraint in one sentence yields a Hello act followed by the
    matching Inform act for the lecturers domain.
    """
    nlu = HandcraftedNLU(JSONLookupDomain('ImsLecturers'))

    utterance = 'Hi, I want a lecturer who is responsible for gender issues'
    result = nlu.extract_user_acts(nlu, user_utterance=utterance)
    assert 'user_acts' in result
    acts = result['user_acts']

    greeting = UserAct()
    greeting.type = UserActionType.Hello
    assert acts[0] == greeting

    inform = UserAct()
    inform.type = UserActionType.Inform
    inform.slot = "position"
    inform.value = "gender"
    assert acts[1] == inform
コード例 #27
0
def test_hdc_usersim(domain_name: str, logger: DiasysLogger, eval_epochs: int,
                     eval_dialogs: int, max_turns: int, test_error: float,
                     use_tensorboard: bool):
    """Evaluate the handcrafted policy against the handcrafted user simulator.

    Runs ``eval_epochs`` epochs of ``eval_dialogs`` dialogs each, with at most ``max_turns``
    turns per dialog. A SimpleNoise module is configured with train error rate 0 and test
    error rate ``test_error``. Statistics are collected by a PolicyEvaluator under the
    experiment name 'hdc_eval'.
    """
    domain = JSONLookupDomain(domain_name)
    bst = HandcraftedBST(domain=domain, logger=logger)
    user_simulator = HandcraftedUserSimulator(domain, logger=logger)
    noise = SimpleNoise(domain=domain, train_error_rate=0., test_error_rate=test_error,
                        logger=logger)
    policy = HandcraftedPolicy(domain=domain, logger=logger)
    evaluator = PolicyEvaluator(domain=domain, use_tensorboard=use_tensorboard,
                                experiment_name='hdc_eval', logger=logger)

    dialog_system = DialogSystem(policy, user_simulator, noise, bst, evaluator)
    dialog_system.eval()

    for _ in range(eval_epochs):
        evaluator.start_epoch()
        for _ in range(eval_dialogs):
            dialog_system.run_dialog(max_length=max_turns)
        evaluator.end_epoch()
コード例 #28
0
    # Experiment settings.
    # NOTE(review): TRAIN_EPISODES, EVAL_EPISODES and MAX_TURNS are not used in the visible
    # part of this excerpt; MAX_TURNS = -1 presumably means "no turn limit" -- confirm.
    TRAIN_EPISODES = 1000
    NUM_TEST_SEEDS = 10
    EVAL_EPISODES = 500
    MAX_TURNS = -1
    TRAIN_EPOCHS = 10

    # get #num_test_seeds random seeds
    # (each drawn uniformly from the unsigned 32-bit range [0, 2**32 - 1])
    random_seeds = []
    for i in range(NUM_TEST_SEEDS):
        random_seeds.append(random.randint(0, 2**32-1))

    # NOTE(review): `results` is not filled in the visible part of this excerpt
    results = {}
    # one run per seed, each with freshly constructed domain, services and evaluator
    for seed in random_seeds:
        # NOTE(review): presumably clears a run-once guard in `common` so that init_random
        # re-seeds for every run -- confirm against the common module
        common.init_once = False
        common.init_random(seed=seed)    
        domain = JSONLookupDomain('ImsCourses')
        bst = HandcraftedBST(domain=domain)
        user_simulator = HandcraftedUserSimulator(domain=domain)
        # dead store: immediately replaced by the DQN policy on the next statement
        policy = None

        policy= DQNPolicy(domain=domain)
        # the experiment name includes the seed, so each run is logged separately
        evaluator = PolicyEvaluator(domain=domain, use_tensorboard=True, 
                                experiment_name='eval_rl_imscourses' + str(seed))
        ds = DialogSystem(policy,
                        user_simulator,
                        bst,
                        evaluator
                            )
            
        # NOTE: this excerpt ends here; the rest of the training loop is not shown
        for i in range(TRAIN_EPOCHS):
            ds.train()
コード例 #29
0
from services.ust.ust import HandcraftedUST
from services.nlg import HandcraftedNLG
from services.backchannel import AcousticBackchanneller
from services.nlg import BackchannelHandcraftedNLG
from services.nlg import HandcraftedEmotionNLG
from services.nlu import HandcraftedNLU
from services.policy import HandcraftedPolicy
from services.policy.affective_policy import EmotionPolicy
from services.service import DialogSystem
from utils.domain.jsonlookupdomain import JSONLookupDomain
from utils.logger import DiasysLogger, LogLevel
from services.simulator.emotion_simulator import EmotionSimulator
from utils.userstate import EmotionType

# load domains
lecturers = JSONLookupDomain(name='ImsLecturers', display_name="Lecturers")
weather = WeatherDomain()
mensa = MensaDomain()

# Logging: dialog-level output on the console only; file logging is disabled (LogLevel.NONE).
# The conversation log directory is created if missing and handed to the input modules below.
# (f"./{conversation_log_dir}/" expands to "././conversation_logs/", i.e. the same directory.)
conversation_log_dir = "./conversation_logs"
os.makedirs(f"./{conversation_log_dir}/", exist_ok=True)
logger = DiasysLogger(file_log_lvl=LogLevel.NONE,
                      console_log_lvl=LogLevel.DIALOGS,
                      logfile_basename="full_log")

# input/output modules: console text I/O and a speech recorder
user_in = ConsoleInput(conversation_log_dir=conversation_log_dir)
user_out = ConsoleOutput()
recorder = SpeechRecorder(conversation_log_dir=conversation_log_dir)
speech_in_decoder = SpeechInputDecoder(
コード例 #30
0
ファイル: train_dqnpolicy.py プロジェクト: kdbz/adviser
def train(domain_name: str, log_to_file: bool, seed: int, train_epochs: int, train_dialogs: int,
          eval_dialogs: int, max_turns: int, train_error_rate: float, test_error_rate: float,
          lr: float, eps_start: float, grad_clipping: float, buffer_classname: str,
          buffer_size: int, use_tensorboard: bool):

    """
        Training loop for the RL policy, for information on the parameters, look at the descriptions
        of commandline arguments in the "if main" below

        Each of the ``train_epochs`` epochs runs ``train_dialogs`` training dialogs (then saves
        the policy) followed by ``eval_dialogs`` evaluation dialogs. A ``seed`` of -1 means
        "no fixed seed". The dialog system is shut down, and the tensorboard writer (if any) is
        closed, even when training raises.

        NOTE(review): ``max_turns``, ``train_error_rate`` and ``test_error_rate`` are accepted
        but not used by this body (the noise module is disabled).

        Raises:
            ValueError: if ``buffer_classname`` is neither "prioritized" nor "uniform".
    """
    seed = seed if seed != -1 else None
    common.init_random(seed=seed)

    file_log_lvl = LogLevel.DIALOGS if log_to_file else LogLevel.NONE
    logger = DiasysLogger(console_log_lvl=LogLevel.RESULTS, file_log_lvl=file_log_lvl)

    # validate before creating the summary writer so nothing is left open on a bad argument
    if buffer_classname == "prioritized":
        buffer_cls = NaivePrioritizedBuffer
    elif buffer_classname == "uniform":
        buffer_cls = UniformBuffer
    else:
        # without this branch buffer_cls stays unbound and DQNPolicy(...) fails with an
        # obscure UnboundLocalError
        raise ValueError(f"unknown buffer_classname {buffer_classname!r}, "
                         "expected 'prioritized' or 'uniform'")

    summary_writer = SummaryWriter(log_dir='logs') if use_tensorboard else None

    domain = JSONLookupDomain(name=domain_name)
    # topic that starts every dialog; a fresh start-signal dict is built per dialog below
    start_topic = f'user_acts/{domain.get_domain_name()}'
    new_dialog_banner = "\n\n!!!!!!!!!!!!!!!! NEW DIALOG !!!!!!!!!!!!!!!!!!!!!!!!!!!!\n\n"

    bst = HandcraftedBST(domain=domain, logger=logger)
    user = HandcraftedUserSimulator(domain, logger=logger)
    policy = DQNPolicy(domain=domain, lr=lr, eps_start=eps_start,
                       gradient_clipping=grad_clipping, buffer_cls=buffer_cls,
                       replay_buffer_size=buffer_size, train_dialogs=train_dialogs,
                       logger=logger, summary_writer=summary_writer)
    evaluator = PolicyEvaluator(domain=domain, use_tensorboard=use_tensorboard,
                                experiment_name=domain_name, logger=logger,
                                summary_writer=summary_writer)
    ds = DialogSystem(services=[user, bst, policy, evaluator], protocol='tcp')

    try:
        if not ds.is_error_free_messaging_pipeline():
            ds.print_inconsistencies()

        for _ in range(train_epochs):
            # START TRAIN EPOCH
            evaluator.train()
            policy.train()
            evaluator.start_epoch()
            for episode in range(train_dialogs):
                if episode % 100 == 0:
                    print("DIALOG", episode)
                logger.dialog_turn(new_dialog_banner)
                ds.run_dialog(start_signals={start_topic: []})
            evaluator.end_epoch()
            policy.save()

            # START EVAL EPOCH
            evaluator.eval()
            policy.eval()
            evaluator.start_epoch()
            for _ in range(eval_dialogs):
                logger.dialog_turn(new_dialog_banner)
                ds.run_dialog(start_signals={start_topic: []})
            evaluator.end_epoch()
    finally:
        # always shut the dialog system down and flush tensorboard data, even if training raises
        ds.shutdown()
        if summary_writer is not None:
            summary_writer.close()