Example #1
0
 def __init__(self,
              domain: Domain,
              turn_reward=-1,
              success_reward=20,
              logger: DiasysLogger = DiasysLogger()):
     """Store the reward configuration used to score dialogs.

     Args:
         domain: domain the evaluator works on
         turn_reward: reward added for every turn; must not be positive
         success_reward: reward of the final transition if the dialog goal
             was reached
         logger: logger kept on the instance
     """
     # a positive per-turn reward would make longer dialogs more attractive
     assert turn_reward <= 0.0, 'the turn reward should be negative'
     self.domain, self.logger = domain, logger
     self.turn_reward, self.success_reward = turn_reward, success_reward
Example #2
0
 def __init__(self, subgraphs, in_module, out_module, logger: DiasysLogger = DiasysLogger()):
     """Meta policy coordinating several per-domain subgraphs.

     Args:
         subgraphs: subgraphs to coordinate; each must expose
             ``domain.get_domain_name()``
         in_module: stored as ``input_module``
         out_module: stored as ``output_module``
         logger: forwarded to the base class
     """
     domain_names = []
     for subgraph in subgraphs:
         domain_names.append(subgraph.domain.get_domain_name())
     super(HandcraftedMetapolicy, self).__init__(domain_names, None, logger=logger)
     self.subgraphs = subgraphs
     # initially every known domain is active
     self.active_domains = domain_names
     self.input_module, self.output_module = in_module, out_module
     self.rewrite = False
     self.offers = {}  # offers made so far, tracked per active domain
Example #3
0
 def __init__(self,
              domain=None,
              subgraph=None,
              logger: DiasysLogger = DiasysLogger()):
     """Set up empty belief-tracking containers for ``domain``.

     Args:
         domain: domain whose primary key is cached. NOTE(review): despite the
             ``None`` default a domain object is required, because
             ``domain.get_primary_key()`` is called unconditionally below.
         subgraph: forwarded to the base class
         logger: forwarded to the base class
     """
     super(HandcraftedBST, self).__init__(domain, subgraph, logger=logger)
     # Informable and requestable slots are tracked with probabilities,
     # e.g. state['beliefs']['food']['italian'] = 0.0
     self.inform_scores, self.request_slots = {}, {}
     self.primary_key = domain.get_primary_key()
Example #4
0
    def __init__(self,
                 *modules,
                 domain=None,
                 logger: DiasysLogger = DiasysLogger()):
        """Hold a tuple of modules plus dialog/turn bookkeeping.

        Args:
            *modules: the modules, stored as a tuple in ``self.modules``
            domain: optional domain
            logger: logger kept on the instance
        """
        self.domain, self.logger, self.modules = domain, logger, modules

        # counters start at zero and the object is not in training mode
        self.num_dialogs = self.num_turns = 0
        self.is_training = False
Example #5
0
 def __init__(self,
              domain: Domain,
              turn_reward=-1,
              success_reward=20,
              logger: DiasysLogger = DiasysLogger()):
     """Configure the per-turn and success rewards.

     Args:
         domain: forwarded to the base class and stored on the instance
         turn_reward: reward for one turn; must be <= 0
         success_reward: reward of the final transition if the dialog goal
             was reached
         logger: logger kept on the instance
     """
     super(ObjectiveReachedEvaluator, self).__init__(domain)
     # turn rewards are meant to penalise dialog length, so reject positive ones
     assert turn_reward <= 0.0, 'the turn reward should be negative'
     self.domain = domain
     self.logger = logger
     self.success_reward = success_reward
     self.turn_reward = turn_reward
Example #6
0
 def __init__(self,
              domain: Domain,
              subgraph: dict = None,
              template_file: str = None,
              logger: DiasysLogger = DiasysLogger(),
              template_file_german: str = None,
              language: Language = None):
     """Record the template configuration of the handcrafted NLG.

     No templates are loaded here; ``template_filename`` and ``templates``
     start out as ``None``.

     Args:
         domain: forwarded to the base class and stored
         subgraph: forwarded to the base class
         template_file: template file used for English
         logger: forwarded to the base class
         template_file_german: template file used for German
         language: output language; any falsy value falls back to English
     """
     super(HandcraftedNLG, self).__init__(domain, subgraph, logger=logger)

     if not language:
         language = Language.ENGLISH
     self.language = language
     # one template file per supported language
     # TODO: at some point if we expand languages, maybe make kwargs? --LV
     self.template_english, self.template_german = template_file, template_file_german
     self.domain = domain
     self.template_filename = self.templates = None
Example #7
0
    def __init__(self,
                 domain: JSONLookupDomain = None,
                 logger: DiasysLogger = DiasysLogger()):
        """
        Set up the affective policy.

        Arguments:
            domain (JSONLookupDomain): domain the policy operates in
            logger (DiasysLogger): logger kept on the instance
        """
        # this flag is assigned before Service.__init__ runs
        self.first_turn = True
        Service.__init__(self, domain=domain)
        self.logger = logger
Example #8
0
    def __init__(self,
                 domain: Domain,
                 sub_topic_domains: Dict[str, str] = None,
                 template_file: str = None,
                 logger: DiasysLogger = DiasysLogger(),
                 template_file_german: str = None,
                 language: Language = None):
        """Handcrafted NLG extended with a table of backchannel phrases.

        All template, logging and language arguments are forwarded to
        ``HandcraftedNLG.__init__``.

        Args:
            domain: forwarded to the base NLG
            sub_topic_domains: forwarded to the base NLG; ``None`` means ``{}``
            template_file: forwarded to the base NLG
            logger: forwarded to the base NLG
            template_file_german: forwarded to the base NLG
            language: forwarded to the base NLG
        """
        # fresh dict per instance instead of one shared mutable default
        if sub_topic_domains is None:
            sub_topic_domains = {}
        # forward the caller's arguments instead of replacing them with
        # hard-coded values (which silently ignored them)
        HandcraftedNLG.__init__(self,
                                domain,
                                template_file=template_file,
                                logger=logger,
                                template_file_german=template_file_german,
                                language=language,
                                sub_topic_domains=sub_topic_domains)

        # backchannel class id -> candidate backchannel strings
        # class_int_mapping = {0: b'no_bc', 1: b'assessment', 2: b'continuer'}
        self.backchannels = {
            0: [''],
            1: ['Okay. ', 'Yeah. '],
            2: ['Um-hum. ', 'Uh-huh. ']
        }
Example #9
0
    def __init__(self, domain: Domain, sub_topic_domains=None, template_file: str = None,
                 logger: DiasysLogger = DiasysLogger(), template_file_german: str = None,
                 emotions: List[str] = None, debug_logger=None):
        """Template-based NLG that also keeps a list of emotions.

        Args:
            domain: forwarded to ``Service`` and stored
            sub_topic_domains: forwarded to ``Service``; ``None`` means ``{}``
            template_file: stored as ``template_filename``
            logger: logger kept on the instance
            template_file_german: accepted for interface compatibility; not
                used by this constructor
            emotions: emotion labels; ``None`` means an empty list
            debug_logger: forwarded to ``Service``
        """
        # fresh containers per instance instead of shared mutable defaults
        # (the emotions list is stored on the instance, so sharing would leak)
        if sub_topic_domains is None:
            sub_topic_domains = {}
        if emotions is None:
            emotions = []
        Service.__init__(self, domain=domain, sub_topic_domains=sub_topic_domains,
                         debug_logger=debug_logger)

        self.domain = domain
        self.template_filename = template_file
        self.templates = {}
        self.logger = logger
        self.emotions = emotions

        self._initialise_templates()
Example #10
0
 def __init__(self,
              domain: JSONLookupDomain,
              train_error_rate: float,
              test_error_rate: float,
              pdf: List[float] = None,
              logger: DiasysLogger = DiasysLogger()):
     """Configure the noise model.

     Args:
         domain: forwarded to the base class and stored
         train_error_rate: error rate stored for training
         test_error_rate: error rate stored for testing
         pdf: probability weights; ``None`` means ``[0.7, 0.3, 0.0]``.
             ``pdf[1:]`` is renormalised into ``pdf_wo_value``; if it is
             non-empty its sum must not be zero (ZeroDivisionError).
         logger: forwarded to the base class
     """
     super(Noise, self).__init__(domain, logger=logger)
     # fresh default list per instance instead of a shared mutable default
     if pdf is None:
         pdf = [0.7, 0.3, 0.0]
     self.domain = domain
     self.train_error_rate = train_error_rate
     self.test_error_rate = test_error_rate
     # the initial rates are kept as a tuple under ``backup``
     self.backup = train_error_rate, test_error_rate
     self.pdf = pdf
     # renormalise everything after index 0; the total is computed once
     tail_total = sum(self.pdf[1:])
     self.pdf_wo_value = [float(i) / tail_total for i in self.pdf[1:]]
Example #11
0
    def __init__(self, domain: JSONLookupDomain, logger: DiasysLogger = DiasysLogger(),
                 max_turns: int = 25):
        """
        Set up the policy state.

        Arguments:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
            logger {DiasysLogger} -- logger kept on the instance
            max_turns {int} -- stored as ``max_turns`` (default: {25})
        """
        # this flag is assigned before Service.__init__ runs
        self.first_turn = True
        Service.__init__(self, domain=domain)
        # current suggestions, and the index of the one the system is
        # currently recommending
        self.current_suggestions, self.s_index = [], 0
        self.domain_key = domain.get_primary_key()
        self.logger, self.max_turns = logger, max_turns
Example #12
0
    def __init__(self,
                 domain: Domain,
                 subgraph: dict = None,
                 use_tensorboard=False,
                 experiment_name: str = '',
                 turn_reward=-1,
                 success_reward=20,
                 logger: DiasysLogger = DiasysLogger(),
                 summary_writer=None):
        """
        Keyword Arguments:
            subgraph {dict} -- [not read by this constructor] (default: {None})
            use_tensorboard {bool} -- [not read by this constructor; the
                                       writer comes from ``summary_writer``]
                                       (default: {False})
            experiment_name {str} -- [not read by this constructor]
                                      (default: {''})
            turn_reward {float} -- [Reward for one turn - usually negative to
                                    penalize dialog length] (default: {-1})
            success_reward {float} -- [Reward of the final transition if the
                                       dialog goal was reached] (default: {20})
            logger {DiasysLogger} -- [kept on the instance and shared with the
                                      internal evaluator]
            summary_writer -- [stored as ``self.writer``] (default: {None})
        """
        super(PolicyEvaluator, self).__init__(domain)
        self.logger = logger
        self.epoch = 0
        # reward computation is delegated to an ObjectiveReachedEvaluator
        self.evaluator = ObjectiveReachedEvaluator(domain,
                                                   turn_reward=turn_reward,
                                                   success_reward=success_reward,
                                                   logger=logger)
        self.writer = summary_writer

        # total and per-epoch dialog counters
        self.total_train_dialogs = self.total_eval_dialogs = 0
        self.epoch_train_dialogs = self.epoch_eval_dialogs = 0

        # reward / success / turn statistics for training and evaluation
        self.train_rewards, self.eval_rewards = [], []
        self.train_success, self.eval_success = [], []
        self.train_turns, self.eval_turns = [], []
        self.is_training = False
Example #13
0
    def __init__(self,
                 domain: JSONLookupDomain,
                 subgraph=None,
                 logger: DiasysLogger = DiasysLogger()):
        """
        Set up the handcrafted policy.

        Arguments:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
            subgraph -- accepted but not forwarded; the base class always
                        receives ``subgraph=None``
            logger {DiasysLogger} -- forwarded to the base class
        """
        # NOTE(review): the ``subgraph`` argument is ignored here; confirm
        # this is intentional.
        super(HandcraftedPolicy, self).__init__(domain, subgraph=None, logger=logger)
        self.turn, self.last_action = 0, None
        # current suggestions, and the index of the one the system is
        # currently recommending
        self.current_suggestions, self.s_index = [], 0
        self.domain_key = domain.get_primary_key()
Example #14
0
    def __init__(self, domain: LookupDomain,
                 logger: DiasysLogger = DiasysLogger(), device: str = 'cpu'):
        """Load the semantic-parsing networks and their supporting resources.

        Args:
            domain: the QA domain
            logger: passed to ``Service`` as ``debug_logger``
            device: PyTorch device name
        """
        Service.__init__(self, domain=domain, debug_logger=logger)
        # the device is set before the loaders run, which may depend on it
        self.device = torch.device(device)

        # one network each for relation, entity and direction prediction
        self.nn_relation = self._load_relation_model()
        self.nn_entity = self._load_entity_model()
        self.nn_direction = self._load_direction_model()
        self.tags = self._load_tag_set()

        seq_len = 40
        self.max_seq_len = seq_len
        self.embedding_creator = BertEmbedding(max_seq_length=seq_len)
Example #15
0
    def __init__(self,
                 domain: JSONLookupDomain = None,
                 subgraph=None,
                 logger: DiasysLogger = DiasysLogger()):
        """Load the DSTC2 data mappings and one tracker model per slot.

        Args:
            domain: required despite the ``None`` default; its informable and
                requestable slots decide which models are loaded
            subgraph: forwarded to the base class
            logger: forwarded to the base class
        """
        super(MLBST, self).__init__(domain, subgraph, logger=logger)
        # resolved relative to the current working directory, not this file
        self.path_to_data_folder = os.path.join(os.path.realpath(os.curdir),
                                                "modules", "bst", "ml")
        # NOTE(review): this stores the domain *name* under ``primary_key``,
        # possibly overriding a value set by the base class; confirm intent.
        self.primary_key = domain.get_domain_name()

        self.data_mappings = DSTC2Data(path_to_data_folder=self.path_to_data_folder,
                                       preprocess=False,
                                       load_train_data=False)

        # presumably filled by the _load_*_model helpers below
        self.inf_trackers, self.req_trackers = {}, {}
        for slot in domain.get_informable_slots():
            self._load_inf_model(slot)
        for slot in domain.get_requestable_slots():
            self._load_req_model(slot)
Example #16
0
    def __init__(self, domain: JSONLookupDomain, logger: DiasysLogger = DiasysLogger(),
                 language: Language = None):
        """
        Prepare the regex-based NLU for ``domain``.

        Caches the domain name and primary key, the informable and requestable
        slot lists and the folder holding the NLU regexes, clears the record of
        the previous system act (signalling the first turn) and finally calls
        ``self._initialize()``.

        Args:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
            logger {DiasysLogger} -- logger kept on the instance
            language {Language} -- currently ignored: English is always used
        """
        Service.__init__(self, domain=domain)
        self.logger = logger

        # domain information
        self.domain_name = domain.get_domain_name()
        self.domain_key = domain.get_primary_key()

        # slots the user may inform about / ask for
        self.USER_INFORMABLE = domain.get_informable_slots()
        self.USER_REQUESTABLE = domain.get_requestable_slots()

        # folder containing the regex files
        self.base_folder = os.path.join(get_root_dir(), 'resources', 'nlu_regexes')

        # no system act seen yet: this signals the first turn
        self.sys_act_info = {'last_act': None,
                             'lastInformedPrimKeyVal': None,
                             'lastRequestSlot': None}

        # NOTE(review): the ``language`` argument is overridden unconditionally;
        # confirm English-only is intended before honouring the parameter.
        self.language = Language.ENGLISH
        self._initialize()
Example #17
0
    def __init__(self,
                 domain: Domain,
                 template_file: str = None,
                 sub_topic_domains: Dict[str, str] = None,
                 logger: DiasysLogger = DiasysLogger(),
                 template_file_german: str = None,
                 language: Language = None):
        """Store template settings and load the templates for the active language.

        Args:
            domain: forwarded to ``Service`` and stored
            template_file: template file used for English
            sub_topic_domains: forwarded to ``Service``; ``None`` means ``{}``
            logger: logger kept on the instance
            template_file_german: template file used for German
            language: currently ignored; English is always selected
        """
        # fresh dict per instance instead of one shared mutable default
        if sub_topic_domains is None:
            sub_topic_domains = {}
        Service.__init__(self,
                         domain=domain,
                         sub_topic_domains=sub_topic_domains)

        self.template_english = template_file
        # TODO: at some point if we expand languages, maybe make kwargs? --LV
        self.template_german = template_file_german
        self.domain = domain
        self.template_filename = None
        self.templates = None
        self.logger = logger

        # NOTE(review): the ``language`` argument is overridden unconditionally;
        # confirm English-only is intended before honouring the parameter.
        self.language = Language.ENGLISH
        self._initialise_language(self.language)
Example #18
0
    def __init__(self,
                 domain: JSONLookupDomain,
                 subgraph=None,
                 logger: DiasysLogger = DiasysLogger(),
                 language: Language = None):
        """
        Prepare the handcrafted NLU for ``domain``.

        Caches the domain name and primary key, the informable and requestable
        slot lists and the folder containing the regexes, and clears the
        previous system act so the first turn can be detected.

        Args:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain
            subgraph -- accepted but not forwarded; the base class always
                        receives ``None``
            logger {DiasysLogger} -- forwarded to the base class
            language {Language} -- input language; falsy values mean English
        """
        # NOTE(review): ``subgraph`` is dropped here; confirm this is intended.
        super(HandcraftedNLU, self).__init__(domain, None, logger=logger)

        self.language = Language.ENGLISH if not language else language

        # domain information
        self.domain_name = domain.get_domain_name()
        self.domain_key = domain.get_primary_key()

        # slots the user may inform about / ask for
        self.USER_INFORMABLE = domain.get_informable_slots()
        self.USER_REQUESTABLE = domain.get_requestable_slots()

        # folder containing the regex files
        self.base_folder = os.path.join(get_root_dir(), 'resources', 'regexes')

        # None signals that no system act has happened yet (first turn)
        self.prev_sys_act = None
Example #19
0
 def __init__(self, domain: LookupDomain, logger: DiasysLogger = DiasysLogger()):
     """Delegate to ``Service.__init__``, passing ``logger`` as ``debug_logger``."""
     Service.__init__(self, debug_logger=logger, domain=domain)
Example #20
0
 def __init__(self, domain: Domain, subgraph: dict = None,
              logger: DiasysLogger = DiasysLogger()):
     """Store domain, subgraph and logger; the object starts outside training mode.

     Args:
         domain: the domain this object works on
         subgraph: optional subgraph description
         logger: logger kept on the instance
     """
     self.domain, self.subgraph, self.logger = domain, subgraph, logger
     self.is_training = False
from services.service import DialogSystem
from utils.domain.jsonlookupdomain import JSONLookupDomain
from utils.logger import DiasysLogger, LogLevel
from services.simulator.emotion_simulator import EmotionSimulator
from utils.userstate import EmotionType

# load domains
lecturers = JSONLookupDomain(name='ImsLecturers', display_name="Lecturers")
weather = WeatherDomain()
mensa = MensaDomain()

# dialogs are logged to the console; file logging is disabled (LogLevel.NONE)
conversation_log_dir = "./conversation_logs"
os.makedirs(conversation_log_dir, exist_ok=True)
logger = DiasysLogger(file_log_lvl=LogLevel.NONE,
                      console_log_lvl=LogLevel.DIALOGS,
                      logfile_basename="full_log")

# input / output modules
user_in = ConsoleInput(conversation_log_dir=conversation_log_dir)
user_out = ConsoleOutput()
recorder = SpeechRecorder(conversation_log_dir=conversation_log_dir)
# a RemoteService(identifier="asr") could be used here instead
speech_in_decoder = SpeechInputDecoder(
    conversation_log_dir=conversation_log_dir,
    use_cuda=False)

# domain tracking and feature extraction
d_tracker = DomainTracker(domains=[lecturers, weather, mensa],
                          greet_on_first_turn=True)
speech_in_feature_extractor = SpeechInputFeatureExtractor()
speach_feats = SpeechFeatureExtractor()
Example #22
0
    def __init__(self,
                 domain: JSONLookupDomain,
                 buffer_cls=UniformBuffer,
                 buffer_size=6000,
                 batch_size=64,
                 discount_gamma=0.99,
                 max_turns: int = 25,
                 include_confreq=False,
                 logger: DiasysLogger = DiasysLogger(),
                 include_select: bool = False,
                 device=torch.device('cpu'),
                 obj_evaluator: ObjectiveReachedEvaluator = None):
        """
        Creates state- and action spaces, initializes experience replay
        buffers.

        Arguments:
            domain {domain.jsonlookupdomain.JSONLookupDomain} -- Domain

        Keyword Arguments:
            buffer_cls {services.policy.rl.experience_buffer.Buffer}
            -- [Experience replay buffer *class*, **not** an instance - will be
                initialized by this constructor!] (default: {UniformBuffer})
            buffer_size {int} -- [see services.policy.rl.experience_buffer.
                                  Buffer] (default: {6000})
            batch_size {int} -- [see services.policy.rl.experience_buffer.
                                  Buffer] (default: {64})
            discount_gamma {float} -- [Discount factor] (default: {0.99})
            max_turns {int} -- [stored as ``self.max_turns``] (default: {25})
            include_confreq {bool} -- [Use confirm_request actions]
                                       (default: {False})
            logger {DiasysLogger} -- [logger kept on the instance; reports the
                                      state and action space sizes]
            include_select {bool} -- [Use select actions] (default: {False})
            device {torch.device} -- [stored as ``self.device`` and passed to
                                      the replay buffer] (default: {cpu})
            obj_evaluator {ObjectiveReachedEvaluator} -- [evaluator stored as
                                      ``self.evaluator``; may be None]
                                      (default: {None})
        """

        self.device = device
        # NOTE: this dict exists while the state dimension is computed below
        # (beliefstate_dict_to_vector may read it - not verified) and is
        # replaced by an empty dict at the end of __init__.
        self.sys_state = {
            "lastInformedPrimKeyVal": None,
            "lastActionInformNone": False,
            "offerHappened": False,
            'informedPrimKeyValsSinceNone': []
        }

        self.max_turns = max_turns
        self.logger = logger
        self.domain = domain
        # setup evaluator for training
        self.evaluator = obj_evaluator

        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.discount_gamma = discount_gamma

        self.writer = None

        # state size = length of the vector encoding an initial belief state
        self.state_dim = self.beliefstate_dict_to_vector(
            BeliefState(domain)._init_beliefstate()).size(1)
        self.logger.info("state space dim: " + str(self.state_dim))

        # get system action list
        self.actions = [
            "inform_byname",  # TODO rename to 'bykey'
            "inform_alternatives",
            "reqmore"
        ]
        # TODO badaction
        # query the slot list once; both loops below iterate over it
        sys_req_slots = list(self.domain.get_system_requestable_slots())
        for req_slot in sys_req_slots:
            self.actions.append('request#' + req_slot)
            self.actions.append('confirm#' + req_slot)
            if include_select:
                self.actions.append('select#' + req_slot)
            if include_confreq:
                for conf_slot in sys_req_slots:
                    # skip case where confirm slot = request slot
                    if conf_slot != req_slot:
                        self.actions.append('confreq#' + conf_slot + '#' + req_slot)
        self.action_dim = len(self.actions)
        # closingmsg is appended after action_dim is fixed, so it is not learnable
        self.actions.append('closingmsg')
        self.logger.info("action space dim: " + str(self.action_dim))

        self.primary_key = self.domain.get_primary_key()

        # init replay memory
        self.buffer = buffer_cls(buffer_size,
                                 batch_size,
                                 self.state_dim,
                                 discount_gamma=discount_gamma,
                                 device=device)
        # NOTE(review): this discards the dict assigned at the top of __init__;
        # confirm the empty dict is intended here.
        self.sys_state = {}

        self.last_sys_act = None
Example #23
0
    def __init__(self, domain: JSONLookupDomain,
                 architecture: NetArchitecture = NetArchitecture.DUELING,
                 hidden_layer_sizes: List[int] = None,  # vanilla architecture
                 shared_layer_sizes: List[int] = None, value_layer_sizes: List[int] = None,
                 advantage_layer_sizes: List[int] = None,  # dueling architecture
                 lr: float = 0.0001, discount_gamma: float = 0.99,
                 target_update_rate: int = 3,
                 replay_buffer_size: int = 8192, batch_size: int = 64,
                 buffer_cls: Type[Buffer] = NaivePrioritizedBuffer,
                 eps_start: float = 0.3, eps_end: float = 0.0,
                 l2_regularisation: float = 0.0, gradient_clipping: float = 5.0,
                 p_dropout: float = 0.0, training_frequency: int = 2, train_dialogs: int = 1000,
                 include_confreq: bool = False, logger: DiasysLogger = DiasysLogger()):
        """
        Build a vanilla or dueling DQN policy on top of the RL base policy.

        Args:
            domain: forwarded to the base policy
            architecture: ``NetArchitecture.VANILLA`` builds a ``DQN``; any
                other value builds a ``DuelingDQN``
            hidden_layer_sizes: hidden layers of the vanilla network;
                ``None`` means ``[256, 700, 700]``
            shared_layer_sizes: shared layers of the dueling network;
                ``None`` means ``[256]``
            value_layer_sizes: value-stream layers of the dueling network;
                ``None`` means ``[300, 300]``
            advantage_layer_sizes: advantage-stream layers of the dueling
                network; ``None`` means ``[400, 400]``
            lr: Adam learning rate
            discount_gamma: forwarded to the base policy
            target_update_rate: if 1, vanilla dqn update
                                if > 1, double dqn with specified target update
                                rate
            replay_buffer_size: forwarded to the base policy as ``buffer_size``
            batch_size: forwarded to the base policy
            buffer_cls: replay buffer class forwarded to the base policy
            eps_start: stored as ``epsilon_start``; ``epsilon`` starts here
            eps_end: stored as ``epsilon_end``
            l2_regularisation: Adam ``weight_decay``
            gradient_clipping: stored; logged when > 0
            p_dropout: dropout rate passed to the network
            training_frequency: stored as ``training_frequency``
            train_dialogs: stored as ``train_dialogs``
            include_confreq: forwarded to the base policy
            logger: forwarded to the base policy
        """
        # fresh lists per instance instead of shared mutable defaults
        if hidden_layer_sizes is None:
            hidden_layer_sizes = [256, 700, 700]
        if shared_layer_sizes is None:
            shared_layer_sizes = [256]
        if value_layer_sizes is None:
            value_layer_sizes = [300, 300]
        if advantage_layer_sizes is None:
            advantage_layer_sizes = [400, 400]

        super(DQNPolicy, self).__init__(domain, buffer_cls=buffer_cls,
                 buffer_size=replay_buffer_size, batch_size=batch_size,
                 discount_gamma=discount_gamma, include_confreq=include_confreq, logger=logger)

        self.training_frequency = training_frequency
        self.train_dialogs = train_dialogs
        self.lr = lr
        self.gradient_clipping = gradient_clipping
        if gradient_clipping > 0.0:
            self.logger.info("Gradient Clipping: " + str(gradient_clipping))
        self.target_update_rate = target_update_rate

        self.epsilon_start = eps_start
        self.epsilon_end = eps_end

        # Select network architecture
        if architecture == NetArchitecture.VANILLA:
            self.logger.info("Architecture: Vanilla")
            self.model = DQN(self.state_dim, self.action_dim,
                            hidden_layer_sizes=hidden_layer_sizes,
                            dropout_rate=p_dropout)
        else:
            self.logger.info("Architecture: Dueling")
            self.model = DuelingDQN(self.state_dim, self.action_dim,
                                    shared_layer_sizes=shared_layer_sizes,
                                    value_layer_sizes=value_layer_sizes,
                                    advantage_layer_sizes=advantage_layer_sizes,
                                    dropout_rate=p_dropout)
        # Select network update
        self.target_model = None
        if target_update_rate > 1:
            self.logger.info("Update: Double")
            # NOTE(review): a target network is only created for the vanilla
            # architecture; with the (default) dueling network target_model
            # stays None although "Update: Double" is logged. Confirm intent.
            if architecture == NetArchitecture.VANILLA:
                self.target_model = copy.deepcopy(self.model)
        else:
            self.logger.info("Update: Vanilla")

        self.optim = optim.Adam(self.model.parameters(), lr=lr, weight_decay=l2_regularisation)
        self.loss_fun = nn.SmoothL1Loss(reduction='none')

        self.train_call_count = 0
        self.total_train_dialogs = 0
        self.epsilon = self.epsilon_start
Example #24
0
    def __init__(self, domain: Domain, logger: DiasysLogger = DiasysLogger()):
        """Set up the agenda-based user simulator for ``domain``.

        Reads ``usermodel.cfg`` from this module's directory.
        - Every entry of the ``[goal]`` section is parsed as a float.
        - In ``[usermodel]``, ``patience`` is a bracketed list of ints.
        - Other bracketed lists are passed to ``common.numpy.random.uniform``
          once, here.
        - Plain values in ``[usermodel]`` are parsed as floats.

        Args:
            domain: domain the simulated user acts in
            logger: logger kept on the instance

        Raises:
            FileNotFoundError: if ``usermodel.cfg`` could not be read
        """
        super(HandcraftedUserSimulator, self).__init__(domain)

        # possible system actions
        self.receive_options = {
            SysActionType.Welcome: self._receive_welcome,
            SysActionType.InformByName: self._receive_informbyname,
            SysActionType.InformByAlternatives:
            self._receive_informbyalternatives,
            SysActionType.Request: self._receive_request,
            SysActionType.Confirm: self._receive_confirm,
            SysActionType.Select: self._receive_select,
            SysActionType.RequestMore: self._receive_requestmore,
            SysActionType.Bad: self._receive_bad,
            SysActionType.ConfirmRequest: self._receive_confirmrequest
        }

        # parse config file
        self.logger = logger
        self.config = configparser.ConfigParser(inline_comment_prefixes=('#', ';'))
        # keep option names case-sensitive
        self.config.optionxform = str
        config_path = os.path.join(os.path.abspath(os.path.dirname(__file__)),
                                   'usermodel.cfg')
        # ConfigParser.read silently skips unreadable files, which would
        # otherwise surface later as an opaque KeyError('goal')
        if not self.config.read(config_path):
            raise FileNotFoundError(
                "user simulator config could not be read: " + config_path)

        def split_list(raw):
            """Split a bracketed, comma-separated value such as '[1, 2]'."""
            return raw.replace(' ', '').strip('[]').split(',')

        self.parameters = {}
        # goal
        self.parameters['goal'] = {}
        for key in self.config["goal"]:
            self.parameters['goal'][key] = float(self.config.get("goal", key))

        # usermodel
        self.parameters['usermodel'] = {}
        for key in self.config["usermodel"]:
            val = self.config.get("usermodel", key)
            if key == 'patience':
                # patience will be sampled on begin of each dialog
                self.parameters['usermodel'][key] = [int(x) for x in split_list(val)]
            elif val.startswith("[") and val.endswith("]"):
                # value is a list to sample the probability from
                self.parameters['usermodel'][key] = common.numpy.random.uniform(
                    *[float(x) for x in split_list(val)])
            else:
                # value is the probability
                self.parameters['usermodel'][key] = float(val)

        # member declarations
        self.turn = 0
        self.domain = domain
        self.dialog_patience = None
        self.patience = None
        self.last_user_actions = None
        self.last_system_action = None
        self.excluded_venues = []

        # member definitions
        self.goal = Goal(domain, self.parameters['goal'])
        self.agenda = Agenda()
        self.num_actions_next_turn = -1
Example #25
0
    conversation_log_dir = './conversation_logs'
    speech_log_dir = None
    if file_log_lvl == LogLevel.DIALOGS:
        # log user audio, system audio and complete conversation
        import time
        from math import floor

        print("This Adviser call will log all your interactions to files.\n")
        # create the parent folder if missing (no exists()/mkdir() race)
        os.makedirs(f"./{conversation_log_dir}/", exist_ok=True)
        # one sub-folder per run, named after the current Unix time in seconds
        conversation_log_dir = "./" + conversation_log_dir + "/{}/".format(
            floor(time.time()))
        # os.mkdir raises if the folder already exists, so two runs never
        # write into the same folder
        os.mkdir(conversation_log_dir)
        speech_log_dir = conversation_log_dir
    logger = DiasysLogger(file_log_lvl=file_log_lvl,
                          console_log_lvl=log_lvl,
                          logfile_folder=conversation_log_dir,
                          logfile_basename="full_log")

    # load domain specific services
    if 'lecturers' in args.domains:
        l_domain, l_services = load_lecturers_domain(backchannel=args.bc)
        domains.append(l_domain)
        services.extend(l_services)
    if 'weather' in args.domains:
        w_domain, w_services = load_weather_domain()
        domains.append(w_domain)
        services.extend(w_services)
    if 'mensa' in args.domains:
        m_domain, m_services = load_mensa_domain(backchannel=args.bc)
        domains.append(m_domain)
        services.extend(m_services)
Example #26
0
 def __init__(self, domain, logger=DiasysLogger()):
     """Initialise the base class, passing ``logger`` as ``debug_logger``."""
     super(VVSNLU, self).__init__(domain, debug_logger=logger)
Example #27
0
    def __init__(self, domain: Domain, logger: DiasysLogger = DiasysLogger()):
        """Initialise the simulator with no goal and no turn counter set yet.

        Args:
            domain: forwarded to the base class and stored
            logger: forwarded to the base class
        """
        super(UserSimulator, self).__init__(domain, logger=logger)

        self.domain = domain
        self.goal = self.turn = None
Example #28
0
import torch
import torch.optim as optim
import torch.nn as nn
from tensorboardX import SummaryWriter

from modules.policy.policy_rl import RLPolicy
from modules.policy.rl.common import DEVICE
from modules.policy.rl.experience_buffer import UniformBuffer, NaivePrioritizedBuffer
from modules.policy.rl.dqn import DQN, DuelingDQN, NetArchitecture
from utils.useract import UserActionType, UserAct
from utils.sysact import SysAct, SysActionType
from utils.beliefstate import BeliefState
from utils.logger import DiasysLogger
from utils import Goal, common
# module-level default logger instance
logger = DiasysLogger()

# turn limit constant; presumably caps dialog length (not used in the visible part of this module)
MAX_TURNS = 25


class DQNPolicy(RLPolicy):
    def __init__(
            self,
            domain,
            architecture: NetArchitecture = NetArchitecture.DUELING,
            hidden_layer_sizes=[300, 700, 700],  # vanilla architecture
            shared_layer_sizes=[300],
            value_layer_sizes=[300, 300],
            advantage_layer_sizes=[400, 400],  # dueling architecture
            lr=0.0001,
            discount_gamma=0.99,
Example #29
0
 def __init__(self,
              domain: LookupDomain,
              logger: DiasysLogger = DiasysLogger()):
     """Delegate entirely to ``HandcraftedPolicy.__init__``."""
     HandcraftedPolicy.__init__(self, logger=logger, domain=domain)
Example #30
0
from tensorboardX import SummaryWriter
from services.policy.rl.experience_buffer import NaivePrioritizedBuffer
from services.simulator import HandcraftedUserSimulator
from services.policy import DQNPolicy
from services.stats.evaluation import PolicyEvaluator

# single domain shared by the policies constructed in this script
super_domain = JSONLookupDomain(name="ImsLecturers")

# NOTE(review): ``policy`` is rebound to a DQNPolicy further down, so this
# HandcraftedPolicy instance is not referenced afterwards in the visible code.
policy = HandcraftedPolicy(domain=super_domain)

# Allows you to track training progress using tensorboard
summary_writer = SummaryWriter(os.path.join('logs', "tutorial"))

# logs summary statistics for each train/test epoch
logger = DiasysLogger(console_log_lvl=LogLevel.RESULTS,
                      file_log_lvl=LogLevel.DIALOGS)
# separate logger: only errors on the console, dialogs to files named
# 'history' inside the 'dialogue_history' folder
dialogue_logger = DiasysLogger(name='dialogue_logger',
                               console_log_lvl=LogLevel.ERRORS,
                               file_log_lvl=LogLevel.DIALOGS,
                               logfile_folder='dialogue_history',
                               logfile_basename='history')

# Create RL policy instance with parameters used in ADVISER paper
policy = DQNPolicy(domain=super_domain,
                   lr=0.0001,
                   eps_start=0.3,
                   gradient_clipping=5.0,
                   buffer_cls=NaivePrioritizedBuffer,
                   replay_buffer_size=8192,
                   shared_layer_sizes=[256],
                   train_dialogs=1000,