def __init__(self, name, args, sel_args, train=False, diverse=False,
             max_total_len=100):
    # build the dialogue model (RnnModel) and the selection model from
    # state dicts saved under configs/, then initialize the base agent
    self.config_path = os.path.join(
        os.path.split(os.path.realpath(__file__))[0], 'configs')
    self.data_path = os.path.join(get_root_path(), args.data)
    domain = get_domain(args.domain)
    corpus = RnnModel.corpus_ty(domain, self.data_path,
                                freq_cutoff=args.unk_threshold,
                                verbose=True, sep_sel=args.sep_sel)
    model = RnnModel(corpus.word_dict, corpus.item_dict_old,
                     corpus.context_dict, corpus.count_dict, args)
    state_dict = utils.load_model(
        os.path.join(self.config_path, args.model_file))  # RnnModel weights
    model.load_state_dict(state_dict)
    sel_model = SelectionModel(corpus.word_dict, corpus.item_dict_old,
                               corpus.context_dict, corpus.count_dict,
                               sel_args)
    sel_state_dict = utils.load_model(
        os.path.join(self.config_path, sel_args.selection_model_file))
    sel_model.load_state_dict(sel_state_dict)
    super(DealornotAgent, self).__init__(model, sel_model, args, name,
                                         train, diverse, max_total_len)
    self.vis = args.visual
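# A minimal usage sketch (hypothetical values): `args` and `sel_args` are
# argparse.Namespace objects that must carry the fields read above, plus any
# hyperparameters RnnModel and SelectionModel expect; names are illustrative.
#
#   from argparse import Namespace
#   args = Namespace(data='data/negotiate', domain='object_division',
#                    unk_threshold=20, sep_sel=True, visual=False,
#                    model_file='rnnmodel_state_dict.th')  # plus model hparams
#   sel_args = Namespace(selection_model_file='selection_state_dict.th')
#   agent = DealornotAgent('Alice', args, sel_args)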
def __init__(self, name, args, sel_args, train=False, diverse=False,
             max_total_len=100,
             model_url='https://tatk-data.s3-ap-northeast-1.amazonaws.com/rnnrollout_dealornot.zip'):
    # variant that first downloads pretrained weights from model_url into
    # configs/, then builds the dialogue and selection models as above
    self.config_path = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'configs')
    self.file_url = model_url
    self.auto_download()

    if not os.path.exists(self.config_path):
        os.mkdir(self.config_path)
    _model_path = os.path.join(self.config_path, 'models')
    self.model_path = _model_path
    if not os.path.exists(_model_path):
        os.makedirs(_model_path)

    self.data_path = os.path.join(get_root_path(), args.data)
    domain = get_domain(args.domain)
    corpus = RnnModel.corpus_ty(domain, self.data_path,
                                freq_cutoff=args.unk_threshold,
                                verbose=True, sep_sel=args.sep_sel)
    model = RnnModel(corpus.word_dict, corpus.item_dict_old,
                     corpus.context_dict, corpus.count_dict, args)
    state_dict = utils.load_model(
        os.path.join(self.config_path, args.model_file))  # RnnModel weights
    model.load_state_dict(state_dict)
    sel_model = SelectionModel(corpus.word_dict, corpus.item_dict_old,
                               corpus.context_dict, corpus.count_dict,
                               sel_args)
    sel_state_dict = utils.load_model(
        os.path.join(self.config_path, sel_args.selection_model_file))
    sel_model.load_state_dict(sel_state_dict)
    super(DealornotAgent, self).__init__(model, sel_model, args, name,
                                         train, diverse, max_total_len)
    self.vis = args.visual
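# Assuming `auto_download` fetches and unpacks `self.file_url` into
# `config_path` (and skips the download when the archive is already cached),
# a mirror can be substituted via the keyword (the URL below is hypothetical):
#
#   agent = DealornotAgent('Alice', args, sel_args,
#                          model_url='https://example.com/rnnrollout_dealornot.zip')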
def __init__(self, model, args, verbose=False):
    # engine for the latent clustering language model; the pretrained
    # selection model is loaded frozen (eval mode)
    super(LatentClusteringLanguageEngine, self).__init__(model, args, verbose)
    self.crit = nn.CrossEntropyLoss(reduction='sum')
    self.cluster_crit = nn.NLLLoss(reduction='sum')
    self.sel_model = utils.load_model(args.selection_model_file)
    self.sel_model.eval()
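# Both criteria use reduction='sum', so losses accumulate per token and any
# averaging (e.g. per word) is left to the caller. A self-contained
# illustration of the difference (shapes are arbitrary):
#
#   import torch
#   import torch.nn as nn
#   logits = torch.randn(7, 50)            # 7 tokens, vocabulary of 50
#   targets = torch.randint(0, 50, (7,))
#   summed = nn.CrossEntropyLoss(reduction='sum')(logits, targets)
#   mean = nn.CrossEntropyLoss(reduction='mean')(logits, targets)
#   assert torch.allclose(summed / 7, mean)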
def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
    super(LatentClusteringPredictionModel, self).__init__()
    self.lang_model = utils.load_model(args.lang_model_file)
    self.lang_model.eval()

    domain = get_domain(args.domain)

    self.word_dict = word_dict
    self.item_dict = item_dict
    self.context_dict = context_dict
    self.count_dict = count_dict
    self.args = args

    self.ctx_encoder = MlpContextEncoder(len(self.context_dict),
                                         domain.input_length(),
                                         args.nembed_ctx, args.nhid_ctx,
                                         args.dropout, args.init_range, False)

    self.word_embed = nn.Embedding(len(self.word_dict), args.nembed_word)
    self.encoder = nn.GRU(input_size=args.nembed_word,
                          hidden_size=args.nhid_lang,
                          bias=True)

    self.embed2hid = nn.Sequential(
        nn.Linear(args.nhid_lang + args.nhid_lang + args.nhid_ctx,
                  self.args.nhid_lang),
        nn.Tanh())

    self.latent_bottleneck = ShardedLatentBottleneckModule(
        num_shards=len(count_dict),
        num_clusters=self.lang_model.cluster_model.args.num_clusters,
        input_size=args.nhid_lang,
        output_size=self.lang_model.cluster_model.args.nhid_cluster,
        args=args)

    # copy latent variables from the pretrained cluster model
    self.latent_bottleneck.latent_vars.weight.data.copy_(
        self.lang_model.cluster_model.latent_bottleneck.latent_vars.weight.data)

    self.memory = RecurrentUnit(
        input_size=args.nhid_lang,
        hidden_size=self.lang_model.cluster_model.args.nhid_cluster,
        args=args)

    self.dropout = nn.Dropout(args.dropout)
    self.kldiv = nn.KLDivLoss(reduction='sum')

    # init
    self.word_embed.weight.data.uniform_(-args.init_range, args.init_range)
    init_rnn(self.encoder, args.init_range)
    init_cont(self.embed2hid, args.init_range)
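# The copy_ above seeds the bottleneck's latent variables with the values
# learned by the pretrained cluster model, without linking the two modules
# in autograd. The same idiom on stand-in modules (sizes are made up):
#
#   import torch.nn as nn
#   src = nn.Embedding(10, 8)   # stands in for the pretrained latent vars
#   dst = nn.Embedding(10, 8)   # freshly initialized target
#   dst.weight.data.copy_(src.weight.data)  # in-place copy of the values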
def __init__(self, word_dict, item_dict, context_dict, count_dict, args):
    super(LatentClusteringLanguageModel, self).__init__()
    self.cluster_model = utils.load_model(args.cluster_model_file)
    self.cluster_model.eval()

    domain = get_domain(args.domain)

    self.word_dict = word_dict
    self.item_dict = item_dict
    self.context_dict = context_dict
    self.count_dict = count_dict
    self.args = args

    self.word_embed = nn.Embedding(len(self.word_dict), args.nembed_word)
    self.encoder = nn.GRU(input_size=args.nembed_word,
                          hidden_size=args.nhid_lang,
                          bias=True)

    self.hid2output = nn.Sequential(
        nn.Linear(args.nhid_lang, args.nembed_word),
        nn.Dropout(args.dropout))

    self.cond2input = nn.Linear(
        args.nhid_lang + self.cluster_model.args.nhid_cluster,
        args.nembed_word)

    self.decoder_reader = nn.GRU(input_size=args.nembed_word,
                                 hidden_size=args.nhid_lang,
                                 bias=True)
    self.decoder_writer = nn.GRUCell(input_size=args.nembed_word,
                                     hidden_size=args.nhid_lang,
                                     bias=True)

    # tie the weights between reader and writer
    self.decoder_writer.weight_ih = self.decoder_reader.weight_ih_l0
    self.decoder_writer.weight_hh = self.decoder_reader.weight_hh_l0
    self.decoder_writer.bias_ih = self.decoder_reader.bias_ih_l0
    self.decoder_writer.bias_hh = self.decoder_reader.bias_hh_l0

    self.dropout = nn.Dropout(args.dropout)

    self.special_token_mask = make_mask(len(word_dict), [
        word_dict.get_idx(w) for w in ['<unk>', 'YOU:', 'THEM:', '<pad>']
    ])

    # init
    self.word_embed.weight.data.uniform_(-args.init_range, args.init_range)
    init_rnn(self.decoder_reader, args.init_range)
    init_linear(self.cond2input, args.init_range)
    init_cont(self.hid2output, args.init_range)
    init_rnn(self.encoder, args.init_range)
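# The decoder "reader" (nn.GRU, consumes whole sequences) and "writer"
# (nn.GRUCell, generates one step at a time) share one set of parameters, so
# reading and writing use the same recurrence. A self-contained check of the
# tying idiom with arbitrary sizes:
#
#   import torch
#   import torch.nn as nn
#   reader = nn.GRU(input_size=8, hidden_size=16)
#   writer = nn.GRUCell(input_size=8, hidden_size=16)
#   writer.weight_ih = reader.weight_ih_l0
#   writer.weight_hh = reader.weight_hh_l0
#   writer.bias_ih = reader.bias_ih_l0
#   writer.bias_hh = reader.bias_hh_l0
#   x, h0 = torch.randn(1, 2, 8), torch.zeros(1, 2, 16)
#   out, _ = reader(x, h0)                      # one step through the GRU
#   assert torch.allclose(out[0], writer(x[0], h0[0]), atol=1e-6)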
def __init__(self, model, args, verbose=False):
    super(LatentClusteringEngine, self).__init__(model, args, verbose)
    self.crit = nn.CrossEntropyLoss(reduction='sum')
    self.kldiv = nn.KLDivLoss(reduction='sum')
    self.cluster_crit = nn.NLLLoss(reduction='sum')
    self.sel_crit = Criterion(
        self.model.item_dict,
        bad_toks=['<disconnect>', '<disagree>'],
        reduction='mean' if args.sep_sel else 'none')
    self.sel_model = utils.load_model(args.selection_model_file)
    self.sel_model.eval()
def __init__(self, model, args, name='Alice', allow_no_agreement=True,
             train=False, diverse=False, max_dec_len=20):
    self.model = model
    self.model.eval()
    self.args = args
    self.name = name
    self.human = False
    self.domain = domain.get_domain(args.domain)
    self.allow_no_agreement = allow_no_agreement
    self.max_dec_len = max_dec_len

    self.sel_model = utils.load_model(args.selection_model_file)
    self.sel_model.eval()

    # rollout search settings: candidate utterances to consider, rollouts
    # per candidate, and the maximum length of each simulated continuation
    self.ncandidate = 5
    self.nrollout = 3
    self.rollout_len = 100
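# A schematic of how such settings are typically used in rollout-based
# agents (the helpers `generate_candidate` and `estimate_reward` are
# hypothetical, not methods of this class):
#
#   def choose_utterance(agent):
#       best, best_score = None, float('-inf')
#       for _ in range(agent.ncandidate):
#           cand = generate_candidate(agent)        # sample one utterance
#           score = sum(estimate_reward(agent, cand, max_len=agent.rollout_len)
#                       for _ in range(agent.nrollout)) / agent.nrollout
#           if score > best_score:
#               best, best_score = cand, score
#       return best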
def main():
    parser = argparse.ArgumentParser(description='Reinforce')
    parser.add_argument('--alice_model_file', type=str,
                        help='Alice model file')
    parser.add_argument('--bob_model_file', type=str,
                        help='Bob model file')
    parser.add_argument('--output_model_file', type=str,
                        help='output model file')
    parser.add_argument('--context_file', type=str,
                        help='context file')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='temperature')
    parser.add_argument('--pred_temperature', type=float, default=1.0,
                        help='prediction temperature')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA')
    parser.add_argument('--verbose', action='store_true', default=False,
                        help='print out conversations')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed')
    parser.add_argument('--score_threshold', type=int, default=6,
                        help='successful dialog should have more than score_threshold in score')
    parser.add_argument('--log_file', type=str, default='',
                        help='log successful dialogs to file for training')
    parser.add_argument('--smart_bob', action='store_true', default=False,
                        help='make Bob smart again')
    parser.add_argument('--gamma', type=float, default=0.99,
                        help='discount factor')
    parser.add_argument('--eps', type=float, default=0.5,
                        help='eps greedy')
    parser.add_argument('--momentum', type=float, default=0.1,
                        help='momentum for sgd')
    parser.add_argument('--lr', type=float, default=0.1,
                        help='learning rate')
    parser.add_argument('--clip', type=float, default=0.1,
                        help='gradient clip')
    parser.add_argument('--rl_lr', type=float, default=0.002,
                        help='RL learning rate')
    parser.add_argument('--rl_clip', type=float, default=2.0,
                        help='RL gradient clip')
    parser.add_argument('--ref_text', type=str,
                        help='file with the reference text')
    parser.add_argument('--sv_train_freq', type=int, default=-1,
                        help='supervision train frequency')
    parser.add_argument('--nepoch', type=int, default=1,
                        help='number of epochs')
    parser.add_argument('--hierarchical', action='store_true', default=False,
                        help='use hierarchical training')
    parser.add_argument('--visual', action='store_true', default=False,
                        help='plot graphs')
    parser.add_argument('--domain', type=str, default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--selection_model_file', type=str, default='',
                        help='path to the pretrained selection model')
    parser.add_argument('--data', type=str, default='data/negotiate',
                        help='location of the data corpus')
    parser.add_argument('--unk_threshold', type=int, default=20,
                        help='minimum word frequency to be in dictionary')
    parser.add_argument('--bsz', type=int, default=16,
                        help='batch size')
    parser.add_argument('--validate', action='store_true', default=False,
                        help='run validation')
    parser.add_argument('--scratch', action='store_true', default=False,
                        help='erase prediction weights')
    parser.add_argument('--sep_sel', action='store_true', default=False,
                        help='use separate classifiers for selection')
    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    alice_model = utils.load_model(args.alice_model_file)  # RnnModel
    alice_ty = get_agent_type(alice_model)  # RnnRolloutAgent
    alice = alice_ty(alice_model, args, name='Alice', train=True)
    alice.vis = args.visual

    bob_model = utils.load_model(args.bob_model_file)  # RnnModel
    bob_ty = get_agent_type(bob_model)  # RnnAgent
    bob = bob_ty(bob_model, args, name='Bob', train=False)

    dialog = Dialog([alice, bob], args)
    logger = DialogLogger(verbose=args.verbose, log_file=args.log_file)
    ctx_gen = ContextGenerator(args.context_file)

    domain = get_domain(args.domain)
    corpus = alice_model.corpus_ty(domain, args.data,
                                   freq_cutoff=args.unk_threshold,
                                   verbose=True, sep_sel=args.sep_sel)
    engine = alice_model.engine_ty(alice_model, args)

    reinforce = Reinforce(dialog, ctx_gen, args, engine, corpus, logger)
    reinforce.run()

    utils.save_model(alice.model, args.output_model_file)
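# Example invocation, assuming this script is reinforce.py (file names are
# illustrative): Alice is trained with RL against a fixed Bob, and the
# updated model is saved to --output_model_file.
#
#   python reinforce.py \
#       --alice_model_file sv_model.th \
#       --bob_model_file sv_model.th \
#       --output_model_file rl_model.th \
#       --context_file data/negotiate/selfplay.txt \
#       --selection_model_file selection_model.th \
#       --cuda --verbose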
def main():
    parser = argparse.ArgumentParser(description='chat utility')
    parser.add_argument('--model_file', type=str,
                        help='model file')
    parser.add_argument('--domain', type=str, default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--context_file', type=str, default='',
                        help='context file')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='temperature')
    parser.add_argument('--num_types', type=int, default=3,
                        help='number of object types')
    parser.add_argument('--num_objects', type=int, default=6,
                        help='total number of objects')
    parser.add_argument('--max_score', type=int, default=10,
                        help='max score per object')
    parser.add_argument('--score_threshold', type=int, default=6,
                        help='successful dialog should have more than score_threshold in score')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed')
    parser.add_argument('--smart_ai', action='store_true', default=False,
                        help='make AI smart again')
    parser.add_argument('--ai_starts', action='store_true', default=False,
                        help='allow AI to start the dialog')
    parser.add_argument('--ref_text', type=str,
                        help='file with the reference text')
    parser.add_argument('--cuda', action='store_true', default=False,
                        help='use CUDA')
    args = parser.parse_args()

    utils.use_cuda(args.cuda)
    utils.set_seed(args.seed)

    human = HumanAgent(domain.get_domain(args.domain))

    alice_ty = RnnRolloutAgent if args.smart_ai else HierarchicalAgent
    ai = alice_ty(utils.load_model(args.model_file), args)

    agents = [ai, human] if args.ai_starts else [human, ai]

    dialog = Dialog(agents, args)
    logger = DialogLogger(verbose=True)
    if args.context_file == '':
        ctx_gen = ManualContextGenerator(args.num_types, args.num_objects,
                                         args.max_score)
    else:
        ctx_gen = ContextGenerator(args.context_file)

    chat = Chat(dialog, ctx_gen, logger)
    chat.run()
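# Example invocation, assuming this script is chat.py (model file name is
# illustrative). Without --context_file, contexts are entered by hand via
# ManualContextGenerator; --smart_ai swaps in the rollout agent:
#
#   python chat.py --model_file sv_model.th --smart_ai --ai_starts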
def main():
    parser = argparse.ArgumentParser(description='Negotiator')
    parser.add_argument('--dataset', type=str, default='./data/negotiate/val.txt',
                        help='location of the dataset')
    parser.add_argument('--model_file', type=str,
                        help='model file')
    parser.add_argument('--smart_ai', action='store_true', default=False,
                        help='to use rollouts')
    parser.add_argument('--seed', type=int, default=1,
                        help='random seed')
    parser.add_argument('--temperature', type=float, default=1.0,
                        help='temperature')
    parser.add_argument('--domain', type=str, default='object_division',
                        help='domain for the dialogue')
    parser.add_argument('--log_file', type=str, default='',
                        help='log file')
    args = parser.parse_args()

    utils.set_seed(args.seed)

    model = utils.load_model(args.model_file)
    ai = LstmAgent(model, args)
    logger = DialogLogger(verbose=True, log_file=args.log_file)
    domain = get_domain(args.domain)

    score_func = rollout if args.smart_ai else likelihood

    dataset, sents = read_dataset(args.dataset)
    ranks, n, k = 0, 0, 0
    for ctx, dialog in dataset:
        start_time = time.time()
        # start a new conversation
        ai.feed_context(ctx)
        for sent, you in dialog:
            if you:
                # on our turn, compute the rank of the ground-truth sentence
                # among the candidate sentences
                rank = compute_rank(sent, sents, ai, domain,
                                    args.temperature, score_func)
                # compute lang_h for the ground-truth sentence
                enc = ai._encode(sent, ai.model.word_dict)
                _, ai.lang_h, lang_hs = ai.model.score_sent(
                    enc, ai.lang_h, ai.ctx_h, args.temperature)
                # save hidden states and the utterance
                ai.lang_hs.append(lang_hs)
                ai.words.append(ai.model.word2var('YOU:'))
                ai.words.append(Variable(enc))
                ranks += rank
                n += 1
            else:
                ai.read(sent)
        k += 1
        time_elapsed = time.time() - start_time
        logger.dump('dialogue %d | avg rank %.3f | raw %d/%d | time %.3f' % (
            k, 1. * ranks / n, ranks, n, time_elapsed))

    logger.dump('final avg rank %.3f' % (1. * ranks / n))
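# The reported number is the average rank of each ground-truth utterance
# among the candidate sentences read from the dataset, under the chosen
# scorer (likelihood by default, rollout with --smart_ai); lower is better.
# Example run, assuming this script is rank.py (model file name illustrative):
#
#   python rank.py --model_file sv_model.th --dataset ./data/negotiate/val.txt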