Example #1
File: model.py Project: luweishuang/tatk
    def __init__(self,
                 name,
                 args,
                 sel_args,
                 train=False,
                 diverse=False,
                 max_total_len=100):
        self.config_path = os.path.join(
            os.path.split(os.path.realpath(__file__))[0], 'configs')
        self.data_path = os.path.join(get_root_path(), args.data)
        domain = get_domain(args.domain)
        corpus = RnnModel.corpus_ty(domain,
                                    self.data_path,
                                    freq_cutoff=args.unk_threshold,
                                    verbose=True,
                                    sep_sel=args.sep_sel)

        model = RnnModel(corpus.word_dict, corpus.item_dict_old,
                         corpus.context_dict, corpus.count_dict, args)
        state_dict = utils.load_model(
            os.path.join(self.config_path, args.model_file))  # RnnModel
        model.load_state_dict(state_dict)

        sel_model = SelectionModel(corpus.word_dict, corpus.item_dict_old,
                                   corpus.context_dict, corpus.count_dict,
                                   sel_args)
        sel_state_dict = utils.load_model(
            os.path.join(self.config_path, sel_args.selection_model_file))
        sel_model.load_state_dict(sel_state_dict)

        super(DealornotAgent, self).__init__(model, sel_model, args, name,
                                             train, diverse, max_total_len)
        self.vis = args.visual
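
For orientation, here is a hypothetical construction of this agent; the attribute names on the two namespaces mirror the ones the constructor reads above, but every value is a placeholder, and it assumes DealornotAgent and its dependencies are importable from the project:

from argparse import Namespace

# Placeholder hyperparameters; real values come from the project's configs.
args = Namespace(data='data/negotiate', domain='object_division',
                 unk_threshold=20, sep_sel=True,
                 model_file='rnn_model.th', visual=False)
sel_args = Namespace(selection_model_file='selection_model.th')

agent = DealornotAgent('Alice', args, sel_args)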
Example #2
File: model.py Project: zqwerty/tatk
    def __init__(
        self,
        name,
        args,
        sel_args,
        train=False,
        diverse=False,
        max_total_len=100,
        model_url='https://tatk-data.s3-ap-northeast-1.amazonaws.com/rnnrollout_dealornot.zip'
    ):
        self.config_path = os.path.join(
            os.path.dirname(os.path.abspath(__file__)), 'configs')

        self.file_url = model_url

        self.auto_download()

        if not os.path.exists(self.config_path):
            os.mkdir(self.config_path)
        _model_path = os.path.join(self.config_path, 'models')
        self.model_path = _model_path
        if not os.path.exists(_model_path):
            os.makedirs(_model_path)

        self.data_path = os.path.join(get_root_path(), args.data)
        domain = get_domain(args.domain)
        corpus = RnnModel.corpus_ty(domain,
                                    self.data_path,
                                    freq_cutoff=args.unk_threshold,
                                    verbose=True,
                                    sep_sel=args.sep_sel)

        model = RnnModel(corpus.word_dict, corpus.item_dict_old,
                         corpus.context_dict, corpus.count_dict, args)
        state_dict = utils.load_model(
            os.path.join(self.config_path, args.model_file))  # RnnModel
        model.load_state_dict(state_dict)

        sel_model = SelectionModel(corpus.word_dict, corpus.item_dict_old,
                                   corpus.context_dict, corpus.count_dict,
                                   sel_args)
        sel_state_dict = utils.load_model(
            os.path.join(self.config_path, sel_args.selection_model_file))
        sel_model.load_state_dict(sel_state_dict)

        super(DealornotAgent, self).__init__(model, sel_model, args, name,
                                             train, diverse, max_total_len)
        self.vis = args.visual
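
The exists-then-mkdir pattern above has a race window between the check and the creation, and it needs two separate branches for the nested directories. A minimal sketch of the equivalent race-free idiom from the standard library:

import os

config_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'configs')
model_path = os.path.join(config_path, 'models')

# makedirs with exist_ok=True creates any missing parents and is a no-op
# when the directory already exists, so no explicit existence check is needed.
os.makedirs(model_path, exist_ok=True)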
Example #3
    def __init__(self, model, args, verbose=False):
        super(LatentClusteringLanguageEngine,
              self).__init__(model, args, verbose)
        self.crit = nn.CrossEntropyLoss(reduction='sum')
        self.cluster_crit = nn.NLLLoss(reduction='sum')

        self.sel_model = utils.load_model(args.selection_model_file)
        self.sel_model.eval()
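
Both criteria are built with reduction='sum', so the reported losses grow with the number of scored tokens instead of being averaged over them. A self-contained illustration of the difference:

import torch
import torch.nn as nn

logits = torch.randn(8, 5)           # 8 predictions over a 5-word vocabulary
targets = torch.randint(0, 5, (8,))  # one gold index per prediction

sum_loss = nn.CrossEntropyLoss(reduction='sum')(logits, targets)
mean_loss = nn.CrossEntropyLoss(reduction='mean')(logits, targets)

# With no class weights, 'sum' is exactly 'mean' times the element count.
assert torch.allclose(sum_loss, mean_loss * 8)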
Example #4
    def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
        super(LatentClusteringPredictionModel, self).__init__()

        self.lang_model = utils.load_model(args.lang_model_file)
        self.lang_model.eval()

        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.count_dict = count_dict
        self.args = args

        self.ctx_encoder = MlpContextEncoder(len(self.context_dict),
                                             domain.input_length(),
                                             args.nembed_ctx, args.nhid_ctx,
                                             args.dropout, args.init_range,
                                             False)

        self.word_embed = nn.Embedding(len(self.word_dict), args.nembed_word)

        self.encoder = nn.GRU(input_size=args.nembed_word,
                              hidden_size=args.nhid_lang,
                              bias=True)

        self.embed2hid = nn.Sequential(
            nn.Linear(args.nhid_lang + args.nhid_lang + args.nhid_ctx,
                      self.args.nhid_lang), nn.Tanh())

        self.latent_bottleneck = ShardedLatentBottleneckModule(
            num_shards=len(count_dict),
            num_clusters=self.lang_model.cluster_model.args.num_clusters,
            input_size=args.nhid_lang,
            output_size=self.lang_model.cluster_model.args.nhid_cluster,
            args=args)

        # copy the latent variables from the pretrained cluster model
        self.latent_bottleneck.latent_vars.weight.data.copy_(
            self.lang_model.cluster_model.latent_bottleneck.latent_vars.weight.data)

        self.memory = RecurrentUnit(
            input_size=args.nhid_lang,
            hidden_size=self.lang_model.cluster_model.args.nhid_cluster,
            args=args)

        self.dropout = nn.Dropout(args.dropout)

        self.kldiv = nn.KLDivLoss(reduction='sum')

        # init
        self.word_embed.weight.data.uniform_(-args.init_range, args.init_range)
        init_rnn(self.encoder, args.init_range)
        init_cont(self.embed2hid, args.init_range)
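
The copy_ above initializes the new bottleneck from the pretrained cluster model without sharing storage: later updates to one tensor leave the other untouched. A minimal sketch of copying versus tying, using plain embeddings as stand-ins:

import torch
import torch.nn as nn

pretrained = nn.Embedding(10, 4)
fresh = nn.Embedding(10, 4)

# copy_: independent tensors that merely start out equal
fresh.weight.data.copy_(pretrained.weight.data)
pretrained.weight.data.add_(1.0)   # this change does not propagate
assert not torch.equal(fresh.weight, pretrained.weight)

# plain assignment would instead tie the parameters (shared storage)
fresh.weight = pretrained.weight
assert torch.equal(fresh.weight, pretrained.weight)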
Example #5
    def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
        super(LatentClusteringLanguageModel, self).__init__()

        self.cluster_model = utils.load_model(args.cluster_model_file)
        self.cluster_model.eval()

        domain = get_domain(args.domain)

        self.word_dict = word_dict
        self.item_dict = item_dict
        self.context_dict = context_dict
        self.count_dict = count_dict
        self.args = args

        self.word_embed = nn.Embedding(len(self.word_dict), args.nembed_word)

        self.encoder = nn.GRU(input_size=args.nembed_word,
                              hidden_size=args.nhid_lang,
                              bias=True)

        self.hid2output = nn.Sequential(
            nn.Linear(args.nhid_lang, args.nembed_word),
            nn.Dropout(args.dropout))

        self.cond2input = nn.Linear(
            args.nhid_lang + self.cluster_model.args.nhid_cluster,
            args.nembed_word)

        self.decoder_reader = nn.GRU(input_size=args.nembed_word,
                                     hidden_size=args.nhid_lang,
                                     bias=True)

        self.decoder_writer = nn.GRUCell(input_size=args.nembed_word,
                                         hidden_size=args.nhid_lang,
                                         bias=True)

        # tie the weights between reader and writer
        self.decoder_writer.weight_ih = self.decoder_reader.weight_ih_l0
        self.decoder_writer.weight_hh = self.decoder_reader.weight_hh_l0
        self.decoder_writer.bias_ih = self.decoder_reader.bias_ih_l0
        self.decoder_writer.bias_hh = self.decoder_reader.bias_hh_l0

        self.dropout = nn.Dropout(args.dropout)

        self.special_token_mask = make_mask(len(word_dict), [
            word_dict.get_idx(w) for w in ['<unk>', 'YOU:', 'THEM:', '<pad>']
        ])

        # init
        self.word_embed.weight.data.uniform_(-args.init_range, args.init_range)
        init_rnn(self.decoder_reader, args.init_range)
        init_linear(self.cond2input, args.init_range)
        init_cont(self.hid2output, args.init_range)
        init_rnn(self.encoder, args.init_range)
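
The four assignments above tie the writer cell to the reader GRU, so consuming a whole sequence at once and emitting it one step at a time use the same parameters. A self-contained check that the tied pair computes identical states (single-layer GRU parameters carry the _l0 suffix):

import torch
import torch.nn as nn

reader = nn.GRU(input_size=4, hidden_size=8)
writer = nn.GRUCell(input_size=4, hidden_size=8)

# tie exactly as in the model above
writer.weight_ih = reader.weight_ih_l0
writer.weight_hh = reader.weight_hh_l0
writer.bias_ih = reader.bias_ih_l0
writer.bias_hh = reader.bias_hh_l0

x = torch.randn(5, 1, 4)   # (seq_len, batch, input_size)
out, _ = reader(x)

h = torch.zeros(1, 8)
for t in range(5):
    h = writer(x[t], h)    # stepping the cell replays the GRU
assert torch.allclose(out[-1, 0], h[0], atol=1e-6)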
Example #6
    def __init__(self, model, args, verbose=False):
        super(LatentClusteringEngine, self).__init__(model, args, verbose)
        self.crit = nn.CrossEntropyLoss(reduction='sum')
        self.kldiv = nn.KLDivLoss(reduction='sum')
        self.cluster_crit = nn.NLLLoss(reduction='sum')
        self.sel_crit = Criterion(self.model.item_dict,
                                  bad_toks=['<disconnect>', '<disagree>'],
                                  reduction='mean' if args.sep_sel else 'none')

        self.sel_model = utils.load_model(args.selection_model_file)
        self.sel_model.eval()
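
Criterion is project-specific, but its bad_toks argument plays the same role as a per-class weight vector that silences tokens the model should never be penalized into predicting. A hedged sketch of that mechanism with plain PyTorch (the vocabulary here is illustrative):

import torch
import torch.nn as nn

vocab = ['book', 'hat', '<disconnect>', '<disagree>']
bad_toks = {'<disconnect>', '<disagree>'}

# weight[i] = 0 removes class i from the loss, mirroring bad_toks
weight = torch.tensor([0.0 if w in bad_toks else 1.0 for w in vocab])
crit = nn.CrossEntropyLoss(weight=weight)

logits = torch.randn(3, len(vocab))
targets = torch.tensor([0, 1, 2])   # the last target is a bad token
loss = crit(logits, targets)        # class 2 contributes nothing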
Example #7
    def __init__(self,
                 model,
                 args,
                 name='Alice',
                 allow_no_agreement=True,
                 train=False,
                 diverse=False,
                 max_dec_len=20):
        self.model = model
        self.model.eval()
        self.args = args
        self.name = name
        self.human = False
        self.domain = domain.get_domain(args.domain)
        self.allow_no_agreement = allow_no_agreement
        self.max_dec_len = max_dec_len

        self.sel_model = utils.load_model(args.selection_model_file)
        self.sel_model.eval()
        self.ncandidate = 5
        self.nrollout = 3
        self.rollout_len = 100
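
The ncandidate / nrollout / rollout_len fields suggest the usual rollout decoding loop: sample a handful of candidate utterances, estimate each one's expected score with a few simulated continuations, and keep the best. A hypothetical sketch of that control flow; sample_utterance and rollout_score are placeholder helpers, not project APIs:

def choose_utterance(agent, ncandidate=5, nrollout=3, rollout_len=100):
    """Pick the candidate whose simulated continuations score highest."""
    best, best_score = None, float('-inf')
    for _ in range(ncandidate):
        cand = sample_utterance(agent)  # placeholder: draw one candidate
        score = sum(rollout_score(agent, cand, rollout_len)  # placeholder
                    for _ in range(nrollout)) / nrollout
        if score > best_score:
            best, best_score = cand, score
    return best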
Example #8
File: reinforce.py Project: zyq0104/tatk
def main():
    parser = argparse.ArgumentParser(description='Reinforce')
    parser.add_argument('--alice_model_file', type=str,
        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str,
        help='Bob model file')
    parser.add_argument('--output_model_file', type=str,
        help='output model file')
    parser.add_argument('--context_file', type=str,
        help='context file')
    parser.add_argument('--temperature', type=float, default=1.0,
        help='temperature')
    parser.add_argument('--pred_temperature', type=float, default=1.0,
        help='prediction temperature')
    parser.add_argument('--cuda', action='store_true', default=False,
        help='use CUDA')
    parser.add_argument('--verbose', action='store_true', default=False,
        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1,
        help='random seed')
    parser.add_argument('--score_threshold', type=int, default=6,
        help='a successful dialog must score above this threshold')
    parser.add_argument('--log_file', type=str, default='',
        help='log successful dialogs to file for training')
    parser.add_argument('--smart_bob', action='store_true', default=False,
        help='make Bob smart again')
    parser.add_argument('--gamma', type=float, default=0.99,
        help='discount factor')
    parser.add_argument('--eps', type=float, default=0.5,
        help='eps greedy')
    parser.add_argument('--momentum', type=float, default=0.1,
        help='momentum for sgd')
    parser.add_argument('--lr', type=float, default=0.1,
        help='learning rate')
    parser.add_argument('--clip', type=float, default=0.1,
        help='gradient clip')
    parser.add_argument('--rl_lr', type=float, default=0.002,
        help='RL learning rate')
    parser.add_argument('--rl_clip', type=float, default=2.0,
        help='RL gradient clip')
    parser.add_argument('--ref_text', type=str,
        help='file with the reference text')
    parser.add_argument('--sv_train_freq', type=int, default=-1,
        help='supervision train frequency')
    parser.add_argument('--nepoch', type=int, default=1,
        help='number of epochs')
    parser.add_argument('--hierarchical', action='store_true', default=False,
        help='use hierarchical training')
    parser.add_argument('--visual', action='store_true', default=False,
        help='plot graphs')
    parser.add_argument('--domain', type=str, default='object_division',
        help='domain for the dialogue')
    parser.add_argument('--selection_model_file', type=str, default='',
        help='path to the pretrained selection model')
    parser.add_argument('--data', type=str, default='data/negotiate',
        help='location of the data corpus')
    parser.add_argument('--unk_threshold', type=int, default=20,
        help='minimum word frequency to be in dictionary')
    parser.add_argument('--bsz', type=int, default=16,
        help='batch size')
    parser.add_argument('--validate', action='store_true', default=False,
        help='run validation')
    parser.add_argument('--scratch', action='store_true', default=False,
        help='erase prediction weights')
    parser.add_argument('--sep_sel', action='store_true', default=False,
        help='use separate classifiers for selection')

    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)  # RnnModel
    alice_ty = get_agent_type(alice_model)  # RnnRolloutAgent
    alice = alice_ty(alice_model, args, name='Alice', train=True)
    alice.vis = args.visual

    bob_model = utils.load_model(args.bob_model_file)  # RnnModel
    bob_ty = get_agent_type(bob_model)  # RnnAgent
    bob = bob_ty(bob_model, args, name='Bob', train=False)

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    domain = get_domain(args.domain)
    corpus = alice_model.corpus_ty(domain, args.data, freq_cutoff=args.unk_threshold,
        verbose=True, sep_sel=args.sep_sel)
    engine = alice_model.engine_ty(alice_model, args)

    reinforce = Reinforce(dialog, ctx_gen, args, engine, corpus, logger)
    reinforce.run()

    utils.save_model(alice.model, args.output_model_file)
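
Among the flags above, --gamma is the standard REINFORCE discount factor. For reference, a minimal sketch of how discounted returns are typically computed from a per-step reward sequence before being used to scale the policy gradient:

def discounted_returns(rewards, gamma=0.99):
    """G_t = r_t + gamma * G_{t+1}, accumulated right to left."""
    returns, g = [], 0.0
    for r in reversed(rewards):
        g = r + gamma * g
        returns.append(g)
    return list(reversed(returns))

# e.g. a dialogue whose only reward arrives at the end:
discounted_returns([0.0, 0.0, 6.0])  # -> [5.8806, 5.94, 6.0]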
Example #9
def main():
    parser = argparse.ArgumentParser(description='chat utility')
    parser.add_argument('--model_file', type=str, help='model file')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--context_file',
                        type=str,
                        default='',
                        help='context file')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--num_types',
                        type=int,
                        default=3,
                        help='number of object types')
    parser.add_argument('--num_objects',
                        type=int,
                        default=6,
                        help='total number of objects')
    parser.add_argument('--max_score',
                        type=int,
                        default=10,
                        help='max score per object')
    parser.add_argument(
        '--score_threshold',
        type=int,
        default=6,
        help='a successful dialog must score above this threshold')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--smart_ai',
                        action='store_true',
                        default=False,
                        help='make AI smart again')
    parser.add_argument('--ai_starts',
                        action='store_true',
                        default=False,
                        help='allow AI to start the dialog')
    parser.add_argument('--ref_text',
                        type=str,
                        help='file with the reference text')
    parser.add_argument('--cuda',
                        action='store_true',
                        default=False,
                        help='use CUDA')
    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    human = HumanAgent(domain.get_domain(args.domain))

    alice_ty = RnnRolloutAgent if args.smart_ai else HierarchicalAgent
    ai = alice_ty(utils.load_model(args.model_file), args)

    agents = [ai, human] if args.ai_starts else [human, ai]

    dialog = Dialog(agents, args)
    logger = DialogLogger(verbose=True)
    if args.context_file == '':
        ctx_gen = ManualContextGenerator(args.num_types, args.num_objects,
                                         args.max_score)
    else:
        ctx_gen = ContextGenerator(args.context_file)

    chat = Chat(dialog, ctx_gen, logger)
    chat.run()
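
Dialog alternates between the two agents in the order chosen above, which is why --ai_starts simply reorders the list. A hypothetical sketch of the core turn-taking loop; write() and the '<selection>' end marker are assumptions modeled on the read() calls visible in these examples:

def run_dialog(agents, max_turns=20):
    # alternate turns: one agent speaks, the other listens
    writer, reader = agents
    for _ in range(max_turns):
        utterance = writer.write()        # assumed agent API
        if utterance == ['<selection>']:  # assumed end-of-dialog marker
            break
        reader.read(utterance)
        writer, reader = reader, writer   # swap roles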
Example #10
File: avg_rank.py Project: zz-jacob/tatk
def main():
    parser = argparse.ArgumentParser(description='Negotiator')
    parser.add_argument('--dataset',
                        type=str,
                        default='./data/negotiate/val.txt',
                        help='location of the dataset')
    parser.add_argument('--model_file', type=str, help='model file')
    parser.add_argument('--smart_ai',
                        action='store_true',
                        default=False,
                        help='to use rollouts')
    parser.add_argument('--seed', type=int, default=1, help='random seed')
    parser.add_argument('--temperature',
                        type=float,
                        default=1.0,
                        help='temperature')
    parser.add_argument('--domain',
                        type=str,
                        default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--log_file', type=str, default='', help='log file')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    model = utils.load_model(args.model_file)
    ai = LstmAgent(model, args)
    logger = DialogLogger(verbose=True, log_file=args.log_file)
    domain = get_domain(args.domain)

    score_func = rollout if args.smart_ai else likelihood

    dataset, sents = read_dataset(args.dataset)
    ranks, n, k = 0, 0, 0
    for ctx, dialog in dataset:
        start_time = time.time()
        # start new conversation
        ai.feed_context(ctx)
        for sent, you in dialog:
            if you:
                # if it is our turn to speak, compute the rank of the ground-truth sentence
                rank = compute_rank(sent, sents, ai, domain, args.temperature,
                                    score_func)
                # compute lang_h for the groundtruth sentence
                enc = ai._encode(sent, ai.model.word_dict)
                _, ai.lang_h, lang_hs = ai.model.score_sent(
                    enc, ai.lang_h, ai.ctx_h, args.temperature)
                # save hidden states and the utterance
                ai.lang_hs.append(lang_hs)
                ai.words.append(ai.model.word2var('YOU:'))
                ai.words.append(Variable(enc))
                ranks += rank
                n += 1
            else:
                ai.read(sent)
        k += 1
        time_elapsed = time.time() - start_time
        logger.dump('dialogue %d | avg rank %.3f | raw %d/%d | time %.3f' %
                    (k, 1. * ranks / n, ranks, n, time_elapsed))

    logger.dump('final avg rank %.3f' % (1. * ranks / n))
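
compute_rank is not shown here, but the "avg rank" it feeds is the usual retrieval metric: the position of the ground-truth sentence when all candidate sentences are ordered by model score. A hypothetical sketch under that assumption; score_sentence stands in for whichever scorer (likelihood or rollout) was selected above:

def compute_rank(gold, candidates, score_sentence):
    # 1-based rank of the gold sentence, best score first
    gold_score = score_sentence(gold)
    better = sum(1 for c in candidates if score_sentence(c) > gold_score)
    return better + 1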