Пример #1
0
    def __init__(self,
                 model,
                 sel_model,
                 args,
                 name='Alice',
                 train=False,
                 diverse=False,
                 max_total_len=100):
        """Constructor of RNNRollout model.

        Args:
            model: the generation model; switched to eval mode here.
            sel_model: the selection model; switched to eval mode here.
            args: parsed options; ``args.domain`` selects the domain.
            name: display name of the agent.
            train: accepted for interface compatibility; not read by this
                constructor.
            diverse: accepted for interface compatibility; not read by this
                constructor.
            max_total_len: stored on the instance and used as the initial
                value of the private ``__current_len`` counter (presumably a
                budget on total dialogue length — confirm with callers).
        """
        self.model = model
        self.model.eval()
        self.args = args
        self.name = name
        self.human = False
        self.domain = domain.get_domain(args.domain)
        self.allow_no_agreement = True

        self.sel_model = sel_model
        self.sel_model.eval()

        # Rollout search settings (fixed constants in this constructor).
        self.ncandidate = 5
        self.nrollout = 3
        self.rollout_len = 100
        self.max_total_len = max_total_len

        # Double leading underscore: Python name-mangles this attribute to
        # _<ClassName>__current_len, so subclasses cannot access it directly.
        self.__current_len = self.max_total_len
Пример #2
0
    def __init__(self,
                 name,
                 args,
                 sel_args,
                 train=False,
                 diverse=False,
                 max_total_len=100):
        """Load the corpus and both models from disk, then initialise the parent agent.

        Paths: ``self.config_path`` is the ``configs`` directory next to this
        file; ``self.data_path`` is ``args.data`` under ``get_root_path()``.

        Args:
            name: agent name, forwarded to the parent constructor.
            args: language-model options (reads ``data``, ``domain``,
                ``unk_threshold``, ``sep_sel``, ``model_file``, ``visual``).
            sel_args: selection-model options (reads
                ``selection_model_file``).
            train, diverse, max_total_len: forwarded unchanged to the parent
                constructor.
        """
        here = os.path.dirname(os.path.realpath(__file__))
        self.config_path = os.path.join(here, 'configs')
        self.data_path = os.path.join(get_root_path(), args.data)

        corpus = RnnModel.corpus_ty(get_domain(args.domain),
                                    self.data_path,
                                    freq_cutoff=args.unk_threshold,
                                    verbose=True,
                                    sep_sel=args.sep_sel)
        words = corpus.word_dict
        items = corpus.item_dict_old
        contexts = corpus.context_dict
        counts = corpus.count_dict

        # Language model (RnnModel) with weights from configs/<model_file>.
        lang_model = RnnModel(words, items, contexts, counts, args)
        lang_model.load_state_dict(
            utils.load_model(os.path.join(self.config_path, args.model_file)))

        # Selection model with weights from configs/<selection_model_file>.
        selector = SelectionModel(words, items, contexts, counts, sel_args)
        selector.load_state_dict(
            utils.load_model(
                os.path.join(self.config_path, sel_args.selection_model_file)))

        super(DealornotAgent, self).__init__(lang_model, selector, args, name,
                                             train, diverse, max_total_len)
        self.vis = args.visual
Пример #3
0
 def __init__(self, agents, args):
     """Set up a dialogue between ``agents`` and register its metrics.

     Args:
         agents: sequence of exactly two agents.
         args: parsed options; ``args.domain`` selects the domain.
     """
     # Only dialogues between exactly two agents are supported for now.
     assert len(agents) == 2
     self.args = args
     self.agents = agents
     self.domain = domain.get_domain(args.domain)
     self.metrics = MetricsContainer()
     self._register_metrics()
Пример #4
0
    def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
        """Build the latent-cluster prediction model.

        A pretrained language model is loaded from ``args.lang_model_file``
        and switched to eval mode. The latent bottleneck and the recurrent
        memory take their sizes from that model's ``cluster_model.args``. The
        bottleneck's latent variables are initialised by copying the cluster
        model's ``latent_vars`` weights.

        Module construction order matters: nn layers register parameters in
        the order they are assigned and draw from the global RNG when built.

        Args:
            word_dict: word vocabulary; its length sizes the word embedding.
            item_dict: stored on the instance; not otherwise used here.
            context_dict: context vocabulary; its length is passed to the
                context encoder.
            count_dict: its length is the number of latent-bottleneck shards.
            args: hyper-parameters (reads ``lang_model_file``, ``domain``,
                ``nembed_ctx``, ``nhid_ctx``, ``nembed_word``, ``nhid_lang``,
                ``dropout``, ``init_range``).
        """
        super(LatentClusteringPredictionModel, self).__init__()

        # Pretrained language model; only switched to eval mode here (its
        # parameters are not explicitly frozen in this constructor).
        self.lang_model = utils.load_model(args.lang_model_file)
        self.lang_model.eval()

        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.count_dict = count_dict
        self.args = args

        # NOTE(review): the last positional argument is hard-coded to False;
        # other model constructors pass args.skip_values in this slot —
        # confirm this is intended.
        self.ctx_encoder = MlpContextEncoder(len(self.context_dict),
                                             domain.input_length(),
                                             args.nembed_ctx, args.nhid_ctx,
                                             args.dropout, args.init_range,
                                             False)

        self.word_embed = nn.Embedding(len(self.word_dict), args.nembed_word)

        self.encoder = nn.GRU(input_size=args.nembed_word,
                              hidden_size=args.nhid_lang,
                              bias=True)

        # Input width is 2 * nhid_lang + nhid_ctx (presumably two
        # language-side states plus the context encoding — confirm in
        # forward); output width is nhid_lang.
        self.embed2hid = nn.Sequential(
            nn.Linear(args.nhid_lang + args.nhid_lang + args.nhid_ctx,
                      self.args.nhid_lang), nn.Tanh())

        # Number of clusters and output size come from the pretrained
        # cluster model's args.
        self.latent_bottleneck = ShardedLatentBottleneckModule(
            num_shards=len(count_dict),
            num_clusters=self.lang_model.cluster_model.args.num_clusters,
            input_size=args.nhid_lang,
            output_size=self.lang_model.cluster_model.args.nhid_cluster,
            args=args)

        # copy lat vars from the cluster model (in-place copy into this
        # module's latent_vars weight)
        self.latent_bottleneck.latent_vars.weight.data.copy_(
            self.lang_model.cluster_model.latent_bottleneck.latent_vars.weight.
            data)

        # Hidden size matches the cluster model's nhid_cluster.
        self.memory = RecurrentUnit(
            input_size=args.nhid_lang,
            hidden_size=self.lang_model.cluster_model.args.nhid_cluster,
            args=args)

        self.dropout = nn.Dropout(args.dropout)

        # KL-divergence loss summed over all elements.
        self.kldiv = nn.KLDivLoss(reduction='sum')

        # init: only word_embed, encoder and embed2hid are (re)initialised
        # here; ctx_encoder, latent_bottleneck and memory are not touched.
        self.word_embed.weight.data.uniform_(-args.init_range, args.init_range)
        init_rnn(self.encoder, args.init_range)
        init_cont(self.embed2hid, args.init_range)
Пример #5
0
    def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
        """Build the latent-cluster-conditioned language model.

        A pretrained cluster model is loaded from ``args.cluster_model_file``
        and switched to eval mode. Its ``args.nhid_cluster`` contributes to
        the input width of ``cond2input``.

        Module construction order matters: nn layers register parameters in
        the order they are assigned and draw from the global RNG when built.

        Args:
            word_dict: word vocabulary; sizes the embedding and the special
                token mask and supplies token indices via ``get_idx``.
            item_dict: stored on the instance; not otherwise used here.
            context_dict: stored on the instance; not otherwise used here.
            count_dict: stored on the instance; not otherwise used here.
            args: hyper-parameters (reads ``cluster_model_file``, ``domain``,
                ``nembed_word``, ``nhid_lang``, ``dropout``, ``init_range``).
        """
        super(LatentClusteringLanguageModel, self).__init__()

        # Pretrained cluster model; only switched to eval mode here.
        self.cluster_model = utils.load_model(args.cluster_model_file)
        self.cluster_model.eval()

        # NOTE(review): `domain` is not used below; the call may only serve to
        # validate args.domain — confirm before removing it.
        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.count_dict = count_dict
        self.args = args

        self.word_embed = nn.Embedding(len(self.word_dict), args.nembed_word)

        self.encoder = nn.GRU(input_size=args.nembed_word,
                              hidden_size=args.nhid_lang,
                              bias=True)

        # Maps nhid_lang -> nembed_word, followed by dropout.
        self.hid2output = nn.Sequential(
            nn.Linear(args.nhid_lang, args.nembed_word),
            nn.Dropout(args.dropout))

        # Input width is nhid_lang + the cluster model's nhid_cluster.
        self.cond2input = nn.Linear(
            args.nhid_lang + self.cluster_model.args.nhid_cluster,
            args.nembed_word)

        # Sequence-level GRU reader and single-step GRUCell writer with the
        # same input/hidden sizes, so their parameters can be tied below.
        self.decoder_reader = nn.GRU(input_size=args.nembed_word,
                                     hidden_size=args.nhid_lang,
                                     bias=True)

        self.decoder_writer = nn.GRUCell(input_size=args.nembed_word,
                                         hidden_size=args.nhid_lang,
                                         bias=True)

        # tie the weights between reader and writer: the writer's parameters
        # become the same Parameter objects as the reader's layer-0 weights.
        self.decoder_writer.weight_ih = self.decoder_reader.weight_ih_l0
        self.decoder_writer.weight_hh = self.decoder_reader.weight_hh_l0
        self.decoder_writer.bias_ih = self.decoder_reader.bias_ih_l0
        self.decoder_writer.bias_hh = self.decoder_reader.bias_hh_l0

        self.dropout = nn.Dropout(args.dropout)

        # Mask over '<unk>', 'YOU:', 'THEM:' and '<pad>' (make_mask is defined
        # elsewhere; presumably used to suppress these tokens when generating).
        self.special_token_mask = make_mask(len(word_dict), [
            word_dict.get_idx(w) for w in ['<unk>', 'YOU:', 'THEM:', '<pad>']
        ])

        # init
        # decoder_writer shares parameters with decoder_reader (tied above), so
        # initialising the reader also covers the writer, assuming init_rnn
        # modifies the parameters in place.
        self.word_embed.weight.data.uniform_(-args.init_range, args.init_range)
        init_rnn(self.decoder_reader, args.init_range)
        init_linear(self.cond2input, args.init_range)
        init_cont(self.hid2output, args.init_range)
        init_rnn(self.encoder, args.init_range)
Пример #6
0
    def __init__(
        self,
        name,
        args,
        sel_args,
        train=False,
        diverse=False,
        max_total_len=100,
        model_url='https://tatk-data.s3-ap-northeast-1.amazonaws.com/rnnrollout_dealornot.zip'
    ):
        """Download assets if needed, load both models and initialise the parent agent.

        Creates ``configs/`` and ``configs/models/`` next to this file if they
        are missing. It then builds the corpus from ``args.data`` and loads the
        language and selection model weights from ``configs/``.

        Args:
            name: agent name, forwarded to the parent constructor.
            args: language-model options (reads ``data``, ``domain``,
                ``unk_threshold``, ``sep_sel``, ``model_file``, ``visual``).
            sel_args: selection-model options (reads
                ``selection_model_file``).
            train, diverse, max_total_len: forwarded unchanged to the parent
                constructor.
            model_url: archive URL stored as ``self.file_url`` before
                ``auto_download()`` is called.
        """
        self.config_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'configs')

        self.file_url = model_url

        # NOTE(review): auto_download() runs before the directories below are
        # created — confirm it does not depend on them.
        self.auto_download()

        # One call creates configs/ and configs/models/ as needed; exist_ok
        # avoids the race between an os.path.exists() check and the mkdir.
        self.model_path = os.path.join(self.config_path, 'models')
        os.makedirs(self.model_path, exist_ok=True)

        self.data_path = os.path.join(get_root_path(), args.data)
        domain = get_domain(args.domain)
        corpus = RnnModel.corpus_ty(domain,
                                    self.data_path,
                                    freq_cutoff=args.unk_threshold,
                                    verbose=True,
                                    sep_sel=args.sep_sel)

        model = RnnModel(corpus.word_dict, corpus.item_dict_old,
                         corpus.context_dict, corpus.count_dict, args)
        state_dict = utils.load_model(
            os.path.join(self.config_path, args.model_file))  # RnnModel
        model.load_state_dict(state_dict)

        sel_model = SelectionModel(corpus.word_dict, corpus.item_dict_old,
                                   corpus.context_dict, corpus.count_dict,
                                   sel_args)
        sel_state_dict = utils.load_model(
            os.path.join(self.config_path, sel_args.selection_model_file))
        sel_model.load_state_dict(sel_state_dict)

        super(DealornotAgent, self).__init__(model, sel_model, args, name,
                                             train, diverse, max_total_len)
        self.vis = args.visual
Пример #7
0
def main():
    """Report Pareto optimality, agreement rate and average scores for a log.

    Command-line options:
        --log_file: dialogue log to parse with ``parse_log``.
        --domain: dialogue domain (default ``object_division``).

    A dialogue counts as agreed when neither side has a ``-1`` pick and, for
    every item type, the two picks add up to the available count. An agreed
    deal can be improved when some split from ``gen_choices`` is at least as
    good for both players and strictly better for one of them. Ratios whose
    denominator is zero (empty log, no agreed dialogues) print as ``nan``
    instead of raising ZeroDivisionError.
    """
    parser = argparse.ArgumentParser(
        description='A script to compute Pareto efficiency')
    parser.add_argument('--log_file', type=str, default='',
        help='location of the log file')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')

    args = parser.parse_args()
    domain = get_domain(args.domain)

    dataset = parse_log(args.log_file, domain)

    def ratio(num, den):
        # Float division (also under Python 2) that yields nan instead of
        # raising when the denominator is zero.
        return float(num) / den if den else float('nan')

    avg_agree, avg_can_improve = 0, 0
    avg_score1, avg_score2 = 0, 0
    for cnts, vals1, picks1, vals2, picks2 in dataset:
        # Skip dialogues where either side has a -1 pick (presumably no
        # valid selection was made).
        if np.min(picks1) == -1 or np.min(picks2) == -1:
            continue
        # Agreement: for every item type the two picks sum to the count.
        if not all(p1 + p2 == n for p1, p2, n in zip(picks1, picks2, cnts)):
            continue

        avg_agree += 1
        score1 = compute_score(vals1, picks1)
        score2 = compute_score(vals2, picks2)
        can_improve = False
        for cand1, cand2 in gen_choices(cnts):
            cand_score1 = compute_score(vals1, cand1)
            cand_score2 = compute_score(vals2, cand2)
            if ((cand_score1 > score1 and cand_score2 >= score2) or
                    (cand_score1 >= score1 and cand_score2 > score2)):
                can_improve = True
                break  # one dominating split is enough

        avg_score1 += score1
        avg_score2 += score2
        avg_can_improve += int(can_improve)

    print('pareto opt (%%)\t:\t%.2f' % (100. * (1 - ratio(avg_can_improve, avg_agree))))
    print('agree (%%)\t:\t%.2f' % (100. * ratio(avg_agree, len(dataset))))
    print('score (all)\t:\t%.2f vs. %.2f' % (
        ratio(avg_score1, len(dataset)), ratio(avg_score2, len(dataset))))
    print('score (agreed)\t:\t%.2f vs. %.2f' % (
        ratio(avg_score1, avg_agree), ratio(avg_score2, avg_agree)))
Пример #8
0
    def __init__(self,
                 model,
                 args,
                 name='Alice',
                 allow_no_agreement=True,
                 train=False,
                 diverse=False,
                 max_dec_len=20):
        """Create an agent that loads its own selection model from disk.

        Args:
            model: the generation model; switched to eval mode here.
            args: parsed options; reads ``args.domain`` and
                ``args.selection_model_file``.
            name: display name of the agent.
            allow_no_agreement: stored on the instance.
            train: accepted for interface compatibility; not read here.
            diverse: accepted for interface compatibility; not read here.
            max_dec_len: stored on the instance (presumably a per-utterance
                decoding limit — confirm with callers).
        """
        self.args, self.name = args, name
        self.human = False
        self.model = model
        self.model.eval()
        self.domain = domain.get_domain(args.domain)
        self.allow_no_agreement = allow_no_agreement
        self.max_dec_len = max_dec_len

        # Unlike the model, the selection model is loaded from a file path.
        selector = utils.load_model(args.selection_model_file)
        selector.eval()
        self.sel_model = selector

        # Rollout search settings (fixed constants).
        self.ncandidate, self.nrollout, self.rollout_len = 5, 3, 100
Пример #9
0
    def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
        """Build the attention-based selection model.

        Components: word, position and speaker embeddings, an MLP context
        encoder, and a multi-head ``SelectionModule`` whose value size is the
        sum of the three embedding widths.

        Args:
            word_dict: word vocabulary; its length sizes the word and speaker
                embeddings.
            item_dict: item vocabulary; its length is the selection output size.
            context_dict: context vocabulary passed to the context encoder.
            count_dict: stored on the instance; not otherwise used here.
            args: hyper-parameters (reads ``domain``, ``nembed_word``,
                ``nembed_ctx``, ``nhid_ctx``, ``nhid_attn``, ``nhid_sel``,
                ``dropout``, ``init_range``, ``skip_values``).
        """
        super(SelectionModel, self).__init__()

        # Fixed embedding widths and the number of position indices.
        self.nhid_pos = 32
        self.nhid_speaker = 32
        self.len_cutoff = 10

        dom = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.count_dict = count_dict
        self.args = args

        # Layers are created in a fixed order: it determines parameter
        # registration order and RNG consumption under a fixed seed.
        vocab_size = len(self.word_dict)
        self.word_encoder = nn.Embedding(vocab_size, args.nembed_word)
        self.pos_encoder = nn.Embedding(self.len_cutoff, self.nhid_pos)
        self.speaker_encoder = nn.Embedding(vocab_size, self.nhid_speaker)
        self.ctx_encoder = MlpContextEncoder(
            len(self.context_dict), dom.input_length(), args.nembed_ctx,
            args.nhid_ctx, args.dropout, args.init_range, args.skip_values)

        value_width = args.nembed_word + self.nhid_pos + self.nhid_speaker
        self.sel_head = SelectionModule(
            query_size=args.nhid_ctx,
            value_size=value_width,
            hidden_size=args.nhid_attn,
            selection_size=args.nhid_sel,
            num_heads=6,
            output_size=len(item_dict),
            args=args)

        self.dropout = nn.Dropout(args.dropout)

        # Uniform init of the three embedding tables, word -> pos -> speaker.
        bound = self.args.init_range
        for table in (self.word_encoder, self.pos_encoder,
                      self.speaker_encoder):
            table.weight.data.uniform_(-bound, bound)
Пример #10
0
    def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
        """Build the latent clustering model.

        Components: MLP context encoder, word embedding, GRU encoder, a GRU
        decoder reader tied to a GRUCell writer, a sharded latent bottleneck,
        a GRUCell memory and a separate-selection head.

        Module construction order matters: nn layers register parameters in
        the order they are assigned and draw from the global RNG when built.

        Args:
            word_dict: word vocabulary; its length sizes the word embedding.
            item_dict: item vocabulary; its length is the selection output size.
            context_dict: context vocabulary passed to the context encoder.
            count_dict: its length is the number of latent-bottleneck shards.
            args: hyper-parameters (reads ``domain``, ``nembed_ctx``,
                ``nhid_ctx``, ``nembed_word``, ``nhid_lang``, ``nhid_cluster``,
                ``num_clusters``, ``nhid_sel``, ``dropout``, ``init_range``,
                ``skip_values``).
        """
        super(LatentClusteringModel, self).__init__()

        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.count_dict = count_dict
        self.args = args

        self.ctx_encoder = MlpContextEncoder(len(self.context_dict),
                                             domain.input_length(),
                                             args.nembed_ctx, args.nhid_ctx,
                                             args.dropout, args.init_range,
                                             args.skip_values)

        self.word_embed = nn.Embedding(len(self.word_dict), args.nembed_word)

        # Maps nhid_lang -> nembed_word, followed by dropout.
        self.hid2output = nn.Sequential(
            nn.Linear(args.nhid_lang, args.nembed_word),
            nn.Dropout(args.dropout))

        # Maps a nhid_lang-sized memory state to word-embedding width.
        self.mem2input = nn.Linear(args.nhid_lang, args.nembed_word)

        self.encoder = nn.GRU(input_size=args.nembed_word,
                              hidden_size=args.nhid_lang,
                              bias=True)

        # Input width is 2 * nhid_lang + nhid_ctx (presumably two
        # language-side states plus the context encoding — confirm in
        # forward); output width is nhid_cluster.
        self.embed2hid = nn.Sequential(
            nn.Linear(args.nhid_lang + args.nhid_lang + args.nhid_ctx,
                      args.nhid_cluster), nn.Tanh())

        self.decoder_reader = nn.GRU(input_size=args.nembed_word,
                                     hidden_size=args.nhid_lang,
                                     bias=True)

        self.decoder_writer = nn.GRUCell(input_size=args.nembed_word,
                                         hidden_size=args.nhid_lang,
                                         bias=True)

        # tie the weights between reader and writer: the writer's parameters
        # become the same Parameter objects as the reader's layer-0 weights.
        self.decoder_writer.weight_ih = self.decoder_reader.weight_ih_l0
        self.decoder_writer.weight_hh = self.decoder_reader.weight_hh_l0
        self.decoder_writer.bias_ih = self.decoder_reader.bias_ih_l0
        self.decoder_writer.bias_hh = self.decoder_reader.bias_hh_l0

        # One shard per entry of count_dict; outputs are nhid_cluster wide.
        self.latent_bottleneck = ShardedLatentBottleneckModule(
            num_shards=len(count_dict),
            num_clusters=args.num_clusters,
            input_size=args.nhid_lang,
            output_size=args.nhid_cluster,
            args=args)

        # Takes nhid_cluster-wide inputs, keeps an nhid_lang-wide state.
        self.memory = nn.GRUCell(input_size=args.nhid_cluster,
                                 hidden_size=args.nhid_lang,
                                 bias=True)

        self.dropout = nn.Dropout(args.dropout)

        self.selection = SimpleSeparateSelectionModule(
            input_size=args.nhid_cluster,
            hidden_size=args.nhid_sel,
            output_size=len(item_dict),
            args=args)

        # init
        # decoder_writer is covered by init_rnn(decoder_reader) through the
        # tied parameters, assuming init_rnn modifies them in place.
        self.word_embed.weight.data.uniform_(-args.init_range, args.init_range)
        init_rnn(self.encoder, args.init_range)
        init_rnn(self.decoder_reader, args.init_range)
        init_rnn_cell(self.memory, args.init_range)
        init_linear(self.mem2input, args.init_range)
        init_cont(self.hid2output, args.init_range)
        init_cont(self.embed2hid, args.init_range)
Пример #11
0
    def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
        """Build the baseline clustering model.

        Components: MLP context encoder, word embedding, GRU encoder, a
        sharded latent bottleneck, a GRU decoder reader tied to a GRUCell
        writer, a conditioning projection, an output projection, a recurrent
        memory and a special-token mask.

        Module construction order matters: nn layers register parameters in
        the order they are assigned and draw from the global RNG when built.

        Args:
            word_dict: word vocabulary; sizes the embedding and the special
                token mask and supplies token indices via ``get_idx``.
            item_dict: stored on the instance; not otherwise used here.
            context_dict: context vocabulary passed to the context encoder.
            count_dict: its length is the number of latent-bottleneck shards.
            args: hyper-parameters (reads ``domain``, ``nembed_ctx``,
                ``nembed_word``, ``nhid_lang``, ``nhid_cluster``,
                ``num_clusters``, ``dropout``, ``init_range``).
        """
        super(BaselineClusteringModel, self).__init__()

        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.count_dict = count_dict
        self.args = args

        # The context encoder's hidden size is args.nhid_lang here (not
        # nhid_ctx), and its last positional argument is hard-coded to False.
        self.ctx_encoder = MlpContextEncoder(len(self.context_dict),
                                             domain.input_length(),
                                             args.nembed_ctx, args.nhid_lang,
                                             args.dropout, args.init_range,
                                             False)

        self.word_embed = nn.Embedding(len(self.word_dict), args.nembed_word)

        self.encoder = nn.GRU(input_size=args.nembed_word,
                              hidden_size=args.nhid_lang,
                              bias=True)

        self.latent_bottleneck = ShardedLatentBottleneckModule(
            num_shards=len(count_dict),
            num_clusters=self.args.num_clusters,
            input_size=args.nhid_lang,
            output_size=self.args.nhid_cluster,
            args=args)

        self.dropout = nn.Dropout(args.dropout)

        self.decoder_reader = nn.GRU(input_size=args.nembed_word,
                                     hidden_size=args.nhid_lang,
                                     bias=True)

        self.decoder_writer = nn.GRUCell(input_size=args.nembed_word,
                                         hidden_size=args.nhid_lang,
                                         bias=True)

        # Maps an nhid_cluster-wide vector to word-embedding width.
        self.cond2input = nn.Linear(args.nhid_cluster, args.nembed_word)

        # Maps nhid_lang -> nembed_word, followed by dropout.
        self.hid2output = nn.Sequential(
            nn.Linear(args.nhid_lang, args.nembed_word),
            nn.Dropout(args.dropout))

        self.memory = RecurrentUnit(input_size=args.nhid_lang,
                                    hidden_size=args.nhid_lang,
                                    args=args)

        # tie the weights between reader and writer: the writer's parameters
        # become the same Parameter objects as the reader's layer-0 weights.
        self.decoder_writer.weight_ih = self.decoder_reader.weight_ih_l0
        self.decoder_writer.weight_hh = self.decoder_reader.weight_hh_l0
        self.decoder_writer.bias_ih = self.decoder_reader.bias_ih_l0
        self.decoder_writer.bias_hh = self.decoder_reader.bias_hh_l0

        # Mask over '<unk>', 'YOU:', 'THEM:' and '<pad>' (make_mask is defined
        # elsewhere; presumably used to suppress these tokens when generating).
        self.special_token_mask = make_mask(len(word_dict), [
            word_dict.get_idx(w) for w in ['<unk>', 'YOU:', 'THEM:', '<pad>']
        ])

        # init
        # decoder_writer is covered by init_rnn(decoder_reader) through the
        # tied parameters, assuming init_rnn modifies them in place; ctx
        # encoder, bottleneck and memory are not touched here.
        self.word_embed.weight.data.uniform_(-args.init_range, args.init_range)
        init_rnn(self.encoder, args.init_range)
        init_rnn(self.decoder_reader, args.init_range)
        init_linear(self.cond2input, args.init_range)
        init_cont(self.hid2output, args.init_range)
Пример #12
0
    def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
        """Build the RNN negotiation model.

        Components:

        * word embedding and MLP context encoder;
        * a GRU reader and a GRUCell writer with tied weights;
        * a word decoder projection;
        * a bidirectional selection GRU with attention;
        * one linear selection decoder per ``domain.selection_length()`` slot;
        * a mask over the special tokens.

        Args:
            word_dict: word vocabulary; sizes the embedding and the special
                token mask and supplies token indices via ``get_idx``.
            item_dict: item vocabulary; its length is the output size of every
                selection decoder.
            context_dict: context vocabulary passed to the context encoder.
            count_dict: stored on the instance; not otherwise used here.
            args: hyper-parameters (reads ``domain``, ``nembed_word``,
                ``nembed_ctx``, ``nhid_ctx``, ``nhid_lang``, ``nhid_attn``,
                ``nhid_sel``, ``dropout``, ``init_range``).
        """
        super(RnnModel, self).__init__()

        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.count_dict = count_dict
        self.args = args

        self.word_encoder = nn.Embedding(len(self.word_dict), args.nembed_word)
        self.word_encoder_dropout = nn.Dropout(args.dropout)

        self.ctx_encoder = nn.Sequential(
            MlpContextEncoder(len(self.context_dict), domain.input_length(),
                              args.nembed_ctx, args.nhid_ctx, args.dropout,
                              args.init_range), nn.Dropout(args.dropout))

        # Per-step input width is nhid_ctx + nembed_word (presumably the
        # context encoding concatenated with a word embedding).
        self.reader = nn.GRU(args.nhid_ctx + args.nembed_word,
                             args.nhid_lang,
                             bias=True)
        self.reader_dropout = nn.Dropout(args.dropout)

        self.decoder = nn.Sequential(
            nn.Linear(args.nhid_lang, args.nembed_word),
            nn.Dropout(args.dropout))

        self.writer = nn.GRUCell(input_size=args.nhid_ctx + args.nembed_word,
                                 hidden_size=args.nhid_lang,
                                 bias=True)

        # Tie the weights of reader and writer: the writer's parameters become
        # the same Parameter objects as the reader's layer-0 weights.
        self.writer.weight_ih = self.reader.weight_ih_l0
        self.writer.weight_hh = self.reader.weight_hh_l0
        self.writer.bias_ih = self.reader.bias_ih_l0
        self.writer.bias_hh = self.reader.bias_hh_l0

        self.sel_rnn = nn.GRU(input_size=args.nhid_lang + args.nembed_word,
                              hidden_size=args.nhid_attn,
                              bias=True,
                              bidirectional=True)
        self.sel_dropout = nn.Dropout(args.dropout)

        # Mask for disabling special tokens when generating sentences. Built
        # once from the vocabulary, so no uninitialised placeholder tensor is
        # ever stored on the model.
        self.special_token_mask = make_mask(len(word_dict), [
            word_dict.get_idx(w) for w in ['<unk>', 'YOU:', 'THEM:', '<pad>']
        ])

        # The bidirectional selection GRU produces 2 * nhid_attn features.
        self.sel_encoder = nn.Sequential(
            nn.Linear(2 * args.nhid_attn + args.nhid_ctx, args.nhid_sel),
            nn.Tanh(), nn.Dropout(args.dropout))
        self.attn = nn.Sequential(
            nn.Linear(2 * args.nhid_attn, args.nhid_attn), nn.Tanh(),
            nn.Linear(args.nhid_attn, 1))
        self.sel_decoders = nn.ModuleList()
        for _ in range(domain.selection_length()):
            self.sel_decoders.append(
                nn.Linear(args.nhid_sel, len(self.item_dict)))

        self.init_weights()
Пример #13
0
def main():
    """Run REINFORCE self-play between Alice and Bob and save Alice's model.

    Alice is loaded from ``--alice_model_file`` and built with ``train=True``.
    Bob is loaded from ``--bob_model_file`` and built with ``train=False``.
    Each agent class comes from ``get_agent_type`` on its model. After
    ``Reinforce.run()``, Alice's model is saved to ``--output_model_file``.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    # NOTE(review): the help texts of --pred_temperature, --validate and
    # --selection_model_file look copy-pasted — confirm their meaning.
    option_specs = [
        ('--alice_model_file', dict(type=str, help='Alice model file')),
        ('--bob_model_file', dict(type=str, help='Bob model file')),
        ('--output_model_file', dict(type=str, help='output model file')),
        ('--context_file', dict(type=str, help='context file')),
        ('--temperature', dict(type=float, default=1.0, help='temperature')),
        ('--pred_temperature',
         dict(type=float, default=1.0, help='temperature')),
        ('--cuda', dict(action='store_true', default=False, help='use CUDA')),
        ('--verbose', dict(action='store_true', default=False,
                           help='print out converations')),
        ('--seed', dict(type=int, default=1, help='random seed')),
        ('--score_threshold', dict(
            type=int, default=6,
            help='successful dialog should have more than score_threshold in score')),
        ('--log_file', dict(type=str, default='',
                            help='log successful dialogs to file for training')),
        ('--smart_bob', dict(action='store_true', default=False,
                             help='make Bob smart again')),
        ('--gamma', dict(type=float, default=0.99, help='discount factor')),
        ('--eps', dict(type=float, default=0.5, help='eps greedy')),
        ('--momentum', dict(type=float, default=0.1, help='momentum for sgd')),
        ('--lr', dict(type=float, default=0.1, help='learning rate')),
        ('--clip', dict(type=float, default=0.1, help='gradient clip')),
        ('--rl_lr', dict(type=float, default=0.002, help='RL learning rate')),
        ('--rl_clip', dict(type=float, default=2.0, help='RL gradient clip')),
        ('--ref_text', dict(type=str, help='file with the reference text')),
        ('--sv_train_freq', dict(type=int, default=-1,
                                 help='supervision train frequency')),
        ('--nepoch', dict(type=int, default=1, help='number of epochs')),
        ('--hierarchical', dict(action='store_true', default=False,
                                help='use hierarchical training')),
        ('--visual', dict(action='store_true', default=False,
                          help='plot graphs')),
        ('--domain', dict(type=str, default='object_division',
                          help='domain for the dialogue')),
        ('--selection_model_file', dict(type=str, default='',
                                        help='path to save the final model')),
        ('--data', dict(type=str, default='data/negotiate',
                        help='location of the data corpus')),
        ('--unk_threshold', dict(type=int, default=20,
                                 help='minimum word frequency to be in dictionary')),
        ('--bsz', dict(type=int, default=16, help='batch size')),
        ('--validate', dict(action='store_true', default=False,
                            help='plot graphs')),
        ('--scratch', dict(action='store_true', default=False,
                           help='erase prediciton weights')),
        ('--sep_sel', dict(action='store_true', default=False,
                           help='use separate classifiers for selection')),
    ]

    parser = argparse.ArgumentParser(description='Reinforce')
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    # Alice: the agent being trained (original note: RnnModel / RnnRolloutAgent).
    alice_model = utils.load_model(args.alice_model_file)
    alice = get_agent_type(alice_model)(alice_model, args, name='Alice',
                                        train=True)
    alice.vis = args.visual

    # Bob: the fixed opponent (original note: RnnModel / RnnAgent).
    bob_model = utils.load_model(args.bob_model_file)
    bob = get_agent_type(bob_model)(bob_model, args, name='Bob', train=False)

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    corpus = alice_model.corpus_ty(get_domain(args.domain), args.data,
                                   freq_cutoff=args.unk_threshold,
                                   verbose=True, sep_sel=args.sep_sel)
    engine = alice_model.engine_ty(alice_model, args)

    Reinforce(dialog, ctx_gen, args, engine, corpus, logger).run()

    utils.save_model(alice.model, args.output_model_file)
Пример #14
0
def main():
    """Run an interactive chat between a human and an AI negotiation agent.

    The AI is a ``RnnRolloutAgent`` when ``--smart_ai`` is given, otherwise a
    ``HierarchicalAgent``; its model is loaded from ``--model_file``. With
    ``--ai_starts`` the AI speaks first. Contexts come from ``--context_file``
    if given, otherwise from a ``ManualContextGenerator`` built from
    ``--num_types``, ``--num_objects`` and ``--max_score``.
    """
    parser = argparse.ArgumentParser(description='chat utility')
    parser.add_argument('--model_file', type=str, help='model file')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--context_file',
                        type=str,
                        default='',
                        help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--num_types',
                        type=int,
                        default=3,
                        help='number of object types')
    parser.add_argument('--num_objects',
                        type=int,
                        default=6,
                        help='total number of objects')
    parser.add_argument('--max_score',
                        type=int,
                        default=10,
                        help='max score per object')
    parser.add_argument(
        '--score_threshold',
        type=int,
        default=6,
        help='successful dialog should have more than score_threshold in score'
    )
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--smart_ai',
                        action='store_true',
                        default=False,
                        help='make AI smart again')
    parser.add_argument('--ai_starts',
                        action='store_true',
                        default=False,
                        help='allow AI to start the dialog')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='use CUDA')
    # The rollout agent's constructor (used with --smart_ai) reads
    # args.selection_model_file. Without this option, that attribute would
    # be missing from args and construction would fail with AttributeError.
    parser.add_argument('--selection_model_file',
                        type=str,
                        default='',
                        help='selection model file')
    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    human = HumanAgent(domain.get_domain(args.domain))

    alice_ty = RnnRolloutAgent if args.smart_ai else HierarchicalAgent
    ai = alice_ty(utils.load_model(args.model_file), args)

    agents = [ai, human] if args.ai_starts else [human, ai]

    dialog = Dialog(agents, args)
    logger = DialogLogger(verbose=True)
    if args.context_file == '':
        ctx_gen = ManualContextGenerator(args.num_types, args.num_objects,
                                         args.max_score)
    else:
        ctx_gen = ContextGenerator(args.context_file)

    chat = Chat(dialog, ctx_gen, logger)
    chat.run()
Пример #15
0
def main():
    """Evaluate how highly the model ranks ground-truth utterances.

    For each dialogue in ``--dataset``, the agent is fed the context. On every
    turn where ``you`` is true, ``compute_rank`` ranks the ground-truth
    sentence against ``sents``, using ``rollout`` (with ``--smart_ai``) or
    ``likelihood`` as the scoring function. The ground-truth sentence is then
    fed into the agent's state. Other turns are consumed with ``ai.read``.
    A running average rank is logged after every dialogue, followed by a
    final average. Averages with no scored turns are logged as ``nan``
    instead of raising ZeroDivisionError.
    """
    parser = argparse.ArgumentParser(description='Negotiator')
    parser.add_argument('--dataset',
                        type=str,
                        default='./data/negotiate/val.txt',
                        help='location of the dataset')
    parser.add_argument('--model_file', type=str, help='model file')
    parser.add_argument('--smart_ai',
                        action='store_true',
                        default=False,
                        help='to use rollouts')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--log_file', type=str, default='', help='log file')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    model = utils.load_model(args.model_file)
    ai = LstmAgent(model, args)
    logger = DialogLogger(verbose=True, log_file=args.log_file)
    domain = get_domain(args.domain)

    score_func = rollout if args.smart_ai else likelihood

    def avg_rank(total, count):
        # nan instead of ZeroDivisionError when no turn has been scored yet
        # (e.g. an empty dataset or dialogues without 'you' turns).
        return 1. * total / count if count else float('nan')

    dataset, sents = read_dataset(args.dataset)
    ranks, n, k = 0, 0, 0
    for ctx, dialog in dataset:
        start_time = time.time()
        # start new conversation
        ai.feed_context(ctx)
        for sent, you in dialog:
            if you:
                # if it is your turn to say, take the target word and compute its rank
                rank = compute_rank(sent, sents, ai, domain, args.temperature,
                                    score_func)
                # compute lang_h for the groundtruth sentence
                enc = ai._encode(sent, ai.model.word_dict)
                _, ai.lang_h, lang_hs = ai.model.score_sent(
                    enc, ai.lang_h, ai.ctx_h, args.temperature)
                # save hidden states and the utterance
                ai.lang_hs.append(lang_hs)
                ai.words.append(ai.model.word2var('YOU:'))
                ai.words.append(Variable(enc))
                ranks += rank
                n += 1
            else:
                ai.read(sent)
        k += 1
        time_elapsed = time.time() - start_time
        logger.dump('dialogue %d | avg rank %.3f | raw %d/%d | time %.3f' %
                    (k, avg_rank(ranks, n), ranks, n, time_elapsed))

    logger.dump('final avg rank %.3f' % avg_rank(ranks, n))