class TransformerAgent(Agent):
    @staticmethod
    def add_cmdline_args(argparser):
        agent_args = argparser.add_argument_group('Agent parameters')
        agent_args.add_argument('-gpu', '--gpu', type=int, default=-1,
                                help='which GPU to use')
        agent_args.add_argument('--no-cuda', type=bool, default=False,
                                help='disable GPUs even if available. otherwise, will use GPUs if '
                                     'available on the device.')
        agent_args.add_argument('--rank_candidates', type=bool, default=False,
                                help='Whether the model should parse candidates for ranking.')
        agent_args.add_argument('--sample', type=bool, default=False,
                                help='Sampling of beam from beam search')
        agent_args.add_argument('--wild_mode', type=bool, default=False, help='')
        agent_args.add_argument('--replace_repeat', type=bool, default=True, help='')
        agent_args.add_argument('--replace_ngram', type=bool, default=True, help='')
        agent_args.add_argument('--detokenize', type=bool, default=True, help='')
        agent_args.add_argument('--emoji_prob', type=float, default=0.5, help='')
        agent_args.add_argument('--ngram_size', type=int, default=3, help='')
        agent_args.add_argument('--add_questions', type=float, default=0.3, help='')
        agent_args.add_argument('--clean_emoji', type=bool, default=True, help='')
        agent_args.add_argument('--check_grammar', type=bool, default=True, help='')
        agent_args.add_argument('--correct_generative', type=bool, default=True, help='')
        agent_args.add_argument('--split_into_sentences', type=bool, default=True, help='')
        agent_args.add_argument('--max_seq_len', type=int, default=128, help='')
        agent_args.add_argument('--beam_size', type=int, default=1, help='')
        agent_args.add_argument('--diversity_coef', type=float, default=0, help='')
        agent_args.add_argument('--diversity_groups', type=int, default=1, help='')
        agent_args.add_argument('--annealing_topk', type=float, default=None, help='')
        agent_args.add_argument('--annealing', type=float, default=0.0, help='')
        agent_args.add_argument('--length_penalty', type=float, default=0.6, help='')
        return argparser

    def __init__(self, opt, shared=None):
        super(TransformerAgent, self).__init__(opt, shared)

        self.use_cuda = not self.opt.get('no_cuda') and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(self.opt['gpu'])

        torch.set_grad_enabled(False)

        model_config = get_model_config()
        self.vocab = BPEVocab.from_files(model_config.bpe_vocab_path, model_config.bpe_codes_path)
        self.reply_checker = ReplyChecker(correct_generative=self.opt['correct_generative'],
                                          split_into_sentences=self.opt['split_into_sentences'])

        self.replace_repeat = self.opt['replace_repeat']
        self.replace_ngram = self.opt['replace_ngram']
        self.ngram_size = self.opt['ngram_size']
        self.detokenize = self.opt['detokenize']
        self.emoji_prob = self.opt['emoji_prob']
        self.add_questions = self.opt['add_questions']
        self.beam_size = self.opt['beam_size']

        self.clean_emoji = self.opt['clean_emoji']
        self.check_grammar = self.opt['check_grammar']

        # Default decoding parameters:
        # 'max_seq_len': 128,
        # 'beam_size': 1,
        # 'diversity_coef': 0,
        # 'diversity_groups': 1,
        # 'annealing_topk': None,
        # 'annealing': 0,
        # 'length_penalty': 0.6,

        if self.opt['annealing_topk'] is not None:
            assert self.opt['annealing_topk'] >= self.opt['beam_size']

        assert self.opt['diversity_coef'] >= 0
        assert self.opt['beam_size'] % self.opt['diversity_groups'] == 0

        if shared is None:
            self.model = TransformerModel(n_layers=model_config.n_layers,
                                          n_embeddings=len(self.vocab),
                                          n_pos_embeddings=model_config.n_pos_embeddings,
                                          embeddings_size=model_config.embeddings_size,
                                          padding_idx=self.vocab.pad_id,
                                          n_heads=model_config.n_heads,
                                          dropout=model_config.dropout,
                                          embed_dropout=model_config.embed_dropout,
                                          attn_dropout=model_config.attn_dropout,
                                          ff_dropout=model_config.ff_dropout,
                                          bos_id=self.vocab.bos_id,
                                          eos_id=self.vocab.eos_id,
                                          max_seq_len=self.opt['max_seq_len'],
                                          beam_size=self.opt['beam_size'],
                                          length_penalty=self.opt['length_penalty'],
                                          n_segments=model_config.n_segments,
                                          sample=self.opt['sample'],
                                          annealing_topk=self.opt['annealing_topk'],
                                          annealing=self.opt['annealing'],
                                          diversity_coef=self.opt['diversity_coef'],
                                          diversity_groups=self.opt['diversity_groups'])
            self.retrieval_bot = RetrievalBot()

            state_dict = torch.load(model_config.checkpoint_path, map_location=lambda storage, loc: storage)
            if 'model' in state_dict:
                state_dict = state_dict['model']

            self.model.load_state_dict(state_dict)
            print('Weights loaded from {}'.format(model_config.checkpoint_path))

            if self.use_cuda:
                self.model = self.model.cuda()

            self.model.eval()
        else:
            self.model = shared['model']
            self.retrieval_bot = shared['retrieval']

        self.reset()

    def _preprocess_text(self, text):
        if self.clean_emoji:
            text = clean_emoji(text)
        if self.check_grammar:
            text = syntax_fix(text).lower()
        return text

    def _parse(self, text):
        # todo: fix grammar mistakes?
        persona_info = []
        dialog = []
        for subtext in text.split('\n'):
            subtext = subtext.strip()
            if self.opt['wild_mode'] and len(self.history['info']) == 0 and len(self.history['dialog']) == 0:
                subtext = 'your persona: ' + subtext
            if subtext.startswith('your persona:'):
                subtext = subtext.replace('your persona:', '').strip()
                subtext = self._preprocess_text(subtext).strip()
                persona_info.append(subtext)
            else:
                subtext = self._preprocess_text(subtext).strip()
                dialog.append(subtext)
        return persona_info, dialog

    def observe(self, observation):
        if self.episode_done:
            self.reset()

        if 'text' in observation:
            text = observation['text']
            info, dialog = self._parse(text)
            if info:
                self.history['str_info'] = ' '.join(info)
            self.history['str_dialog'].extend(dialog)

            info = sum([self.vocab.string2ids(i) for i in info], [])
            self.history['info'].extend(info)

            for i, d in enumerate(dialog, 1):
                d = self.vocab.string2ids(d)
                if i % 2 == 1:
                    d = [self.vocab.talker1_bos_id] + d + [self.vocab.talker1_eos_id]
                else:
                    d = [self.vocab.talker2_bos_id] + d + [self.vocab.talker2_eos_id]
                self.history['dialog'].extend(d)

        observation['agent'] = self
        self.episode_done = observation['episode_done']
        self.observation = observation
        return observation

    def act(self):
        return self.batch_act([self.observation])[0]

    def _postprocess_text(self, reply, agent):
        str_reply = self.vocab.ids2string(reply)

        if self.replace_repeat:
            str_reply = agent.reply_checker.check_reply(str_reply,
                                                        agent.history['str_dialog'][-1],
                                                        agent.history['str_info'])

        if self.beam_size > 1 and random.uniform(0, 1) < self.add_questions and '?' not in str_reply:
            question = self.retrieval_bot.generate_question(list(agent.history['str_dialog']),
                                                            agent.history['str_info'])
            if question is not None and question not in str_reply:
                str_reply = ' '.join([str_reply, question])

        if self.replace_ngram:
            str_reply = ngram_replaser(agent.history['str_info'], str_reply, n=self.ngram_size)

        reply = self.vocab.string2ids(str_reply)

        if self.detokenize:
            str_reply = detokenize(str_reply)

        if random.uniform(0, 1) < self.emoji_prob:
            str_reply = ' '.join([str_reply, pick_emoji(str_reply)])

        return str_reply, reply

    def batch_act(self, observations):
        def is_valid_history(history):
            return len(history['dialog'])

        def to_tensor(string):
            ids = [self.vocab.bos_id] + self.vocab.string2ids(string) + [self.vocab.eos_id]
            return torch.tensor(ids, dtype=torch.long)

        batch_reply = [{'id': self.getID(), 'text': '', 'text_candidates': []} for _ in range(len(observations))]
        valid_ids = [i for i, obs in enumerate(observations) if is_valid_history(obs['agent'].history)]
        batch_size = len(valid_ids)

        if batch_size == 0:
            return batch_reply

        try:
            valid_observations = [observations[i] for i in valid_ids]

            infos = [obs['agent'].history['info'][:self.model.n_pos_embeddings - 3] for obs in valid_observations]
            infos = [([self.vocab.info_bos_id] + ifo + [self.vocab.info_eos_id] if len(ifo) else ifo) for ifo in infos]
            dialogs = [list(obs['agent'].history['dialog'])[-self.model.n_pos_embeddings + 1:] for obs in valid_observations]

            contexts = []

            if max(map(len, infos)) > 0:
                infos = [torch.tensor(i, dtype=torch.long) for i in infos]
                infos = pad_sequence(infos, batch_first=True, padding_value=self.model.padding_idx)
                if self.use_cuda:
                    infos = infos.cuda()
                contexts.append(infos)

            if max(map(len, dialogs)) > 0:
                dialogs = [torch.tensor(d, dtype=torch.long) for d in dialogs]
                dialogs = pad_sequence(dialogs, batch_first=True, padding_value=self.model.padding_idx)
                if self.use_cuda:
                    dialogs = dialogs.cuda()
                contexts.append(dialogs)

            enc_contexts = [self.model.encode(c) for c in contexts]
            pred_texts = self.model.beam_search(enc_contexts)

            for i in range(batch_size):
                pred_text_str, pred_text = self._postprocess_text(pred_texts[i], valid_observations[i]['agent'])
                valid_observations[i]['agent'].history['dialog'].extend([self.vocab.talker2_bos_id] +
                                                                        pred_text +
                                                                        [self.vocab.talker2_eos_id])
                batch_reply[valid_ids[i]]['text'] = pred_text_str
                batch_reply[valid_ids[i]]['episode_done'] = valid_observations[i]['agent'].episode_done

            if self.opt['rank_candidates']:
                candidates = [list(obs.get('label_candidates', [])) for obs in valid_observations]
                lens_candidates = [len(c) for c in candidates]

                if max(lens_candidates) > 0:
                    candidates = [c + ['' for _ in range(max(lens_candidates) - len(c))] for c in candidates]
                    scores = [[] for _ in range(len(candidates))]

                    for i in range(max(lens_candidates)):
                        current_cands = [to_tensor(c[i])[:self.model.n_pos_embeddings - 1] for c in candidates]
                        current_cands = pad_sequence(current_cands, batch_first=True, padding_value=self.model.padding_idx)
                        if self.use_cuda:
                            current_cands = current_cands.cuda()

                        logits = self.model.decode(current_cands[:, :-1], enc_contexts)
                        log_probas = F.log_softmax(logits, dim=-1)
                        log_probas = torch.gather(log_probas, -1, current_cands[:, 1:].unsqueeze(-1)).squeeze(-1)

                        log_probas.masked_fill_(current_cands[:, 1:].eq(self.model.padding_idx), 0)
                        current_lens = current_cands[:, 1:].ne(self.model.padding_idx).float().sum(dim=-1)
                        current_scores = log_probas.sum(dim=-1) / current_lens

                        for k, s in enumerate(current_scores):
                            if i < lens_candidates[k]:
                                scores[k].append(s.item())

                    ranked_ids = [sorted(range(len(s)), key=lambda k: s[k], reverse=True) for s in scores]
                    ranked_strings = [[c[i] for i in ids] for ids, c in zip(ranked_ids, candidates)]

                    for i in range(batch_size):
                        batch_reply[valid_ids[i]]['text_candidates'] = ranked_strings[i]

        except Exception as e:
            # raise e
            print(e)

        return batch_reply

    def share(self):
        shared = super(TransformerAgent, self).share()
        shared['opt'] = self.opt
        shared['model'] = self.model
        shared['retrieval'] = self.retrieval_bot
        return shared

    def reset(self):
        self.history = {'str_info': None,
                        'str_dialog': deque(DIALOG_SIZE * ['None'], maxlen=DIALOG_SIZE),
                        'info': [],
                        'dialog': deque(maxlen=self.model.n_pos_embeddings - 1)}
        self.episode_done = True
        self.observation = None
        self.reply_checker.clean()
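# Usage sketch (not from the original source): how this agent is typically
# driven through ParlAI's observe/act cycle. The observation format below
# ('text' carrying optional 'your persona:' lines, plus 'episode_done')
# matches what observe() above parses; the parser setup assumes ParlAI's
# ParlaiParser and may differ across ParlAI versions.
from parlai.core.params import ParlaiParser

parser = ParlaiParser()
TransformerAgent.add_cmdline_args(parser)
opt = parser.parse_args()

agent = TransformerAgent(opt)
observation = {'text': 'your persona: i like dogs.\nhello, how are you?',
               'episode_done': False}
agent.observe(observation)
reply = agent.act()  # delegates to batch_act([self.observation])
print(reply['text'])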
def main():
    model_config = get_model_config_dialog()
    test_config = get_test_config_dialog()

    set_seed(test_config.seed)
    device = torch.device(test_config.device)

    vocab = myVocab(model_config.vocab_path)

    transformer = TransformerModel(n_layers=model_config.n_layers,
                                   n_embeddings=len(vocab),
                                   n_pos_embeddings=model_config.n_pos_embeddings,
                                   embeddings_size=model_config.embeddings_size,
                                   padding_idx=vocab.pad_id,
                                   n_heads=model_config.n_heads,
                                   dropout=model_config.dropout,
                                   embed_dropout=model_config.embed_dropout,
                                   attn_dropout=model_config.attn_dropout,
                                   ff_dropout=model_config.ff_dropout,
                                   bos_id=vocab.bos_id,
                                   eos_id=vocab.eos_id,
                                   max_seq_len=model_config.max_seq_len,
                                   beam_size=model_config.beam_size,
                                   length_penalty=model_config.length_penalty,
                                   n_segments=model_config.n_segments,
                                   annealing_topk=model_config.annealing_topk,
                                   annealing=model_config.annealing,
                                   diversity_coef=model_config.diversity_coef,
                                   diversity_groups=model_config.diversity_groups)

    transformer = transformer.to(device)
    state_dict = torch.load(test_config.last_checkpoint_path, map_location=device)

    # Strip the 'module' name component that DataParallel/DistributedDataParallel
    # wrappers add to parameter names before loading the weights.
    temp = dict(state_dict['model'])
    keys = list(temp.keys())
    for key in keys:
        # new_key = '.'.join([i for i in key.split('.') if i != 'module'])
        new_key = key.replace('.module', '')
        temp[new_key] = temp.pop(key)
    transformer.load_state_dict(temp)
    transformer.eval()
    print('Weights loaded from {}'.format(test_config.last_checkpoint_path))

    def answer(message):
        # space-separate the characters of the input (char-level tokenization)
        message = ' '.join(message)
        message = vocab.string2ids(message)
        message = [vocab.bos_id] + message + [vocab.eos_id]
        message = message[:60]
        # print(message)
        contexts = [torch.tensor([c], dtype=torch.long, device=device)
                    for c in [message] if len(c) > 0]
        prediction = transformer.predict(contexts)[0]
        prediction_str = vocab.ids2string(prediction)
        return prediction_str

    def answer_beams(message):
        # same as answer(), but returns the full list of beam hypotheses
        message = ' '.join(message)
        message = vocab.string2ids(message)
        message = [vocab.bos_id] + message + [vocab.eos_id]
        message = message[:30]
        # print(message)
        contexts = [torch.tensor([c], dtype=torch.long, device=device)
                    for c in [message] if len(c) > 0]
        predictions = transformer.predict_beams(contexts)[0]
        prediction_strs = [vocab.ids2string(prediction) for prediction in predictions]
        return prediction_strs

    '''
    with open('data/test200_output_noinit_noweight.txt', 'w', encoding='utf8') as fw:
        with open('data/test200.txt', 'r', encoding='utf8') as fr:
            lines = fr.readlines()
            for line in lines:
                post, response = line.strip('\n').replace(' ', '').split('\t')
                ans = answer(post)
                fw.write('source:' + post + '\t' + 'target:' + response + '\t' + 'answer:' + ans + '\n')
    '''
    '''
    while True:
        message = input('>')
        ans = answer(message)
        print(ans)
    '''
    while True:
        message = input('>')
        ans = answer_beams(message)
        for i in ans:
            print(i)
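# The key-renaming loop in main() above is needed because checkpoints saved
# from a model wrapped in DataParallel/DistributedDataParallel carry a
# 'module' component in every parameter name. A minimal standalone sketch of
# the same idea (hypothetical helper, not part of the original source):
def strip_module_prefix(state_dict):
    """Return a copy of state_dict with any '.module' name components removed."""
    return {key.replace('.module', ''): value for key, value in state_dict.items()}

# e.g.: transformer.load_state_dict(strip_module_prefix(state_dict['model']))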
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=-1,
                        help="Distributed training.")
    parser.add_argument('--server_ip', type=str, default='',
                        help="Used for debugging on GPU machine.")
    parser.add_argument('--server_port', type=str, default='',
                        help="Used for debugging on GPU machine.")
    args = parser.parse_args()

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.ERROR)
    logger = logging.getLogger(__file__)

    if args.server_ip and args.server_port and args.local_rank in [-1, 0]:
        # Distant debugging - see
        # https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    model_config = get_model_config()
    trainer_config = get_trainer_config()

    # Log only on main process
    if args.local_rank not in [-1, 0]:
        sys.stdout = open(f"./runs/log_distributed_{args.local_rank}", "w")  # dump stdout
        writer = DummyWriter()
    else:
        writer = SummaryWriter(comment=trainer_config.writer_comment)

    logger.info("model config: {}".format(model_config))
    logger.info("trainer config: {}".format(trainer_config))
    log_dir = writer.log_dir
    interrupt_checkpoint_path = os.path.join(log_dir, trainer_config.interrupt_checkpoint_path)
    last_checkpoint_path = os.path.join(log_dir, trainer_config.last_checkpoint_path)
    logger.info("Logging to {}".format(log_dir))  # Save everything on an experiment in the ./runs/XXX/ directory

    if args.local_rank in [-1, 0]:
        with open(os.path.join(log_dir, "model_config.json"), "w") as f:
            json.dump(model_config, f)
        with open(os.path.join(log_dir, "trainer_config.json"), "w") as f:
            json.dump(trainer_config, f)

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    vocab = BPEVocab.from_files(model_config.bpe_vocab_path,
                                model_config.bpe_codes_path,
                                zero_shot=trainer_config.zero_shot)

    transformer = TransformerModel(n_layers=model_config.n_layers,
                                   n_embeddings=len(vocab),
                                   n_pos_embeddings=model_config.n_pos_embeddings,
                                   embeddings_size=model_config.embeddings_size,
                                   padding_idx=vocab.pad_id,
                                   n_heads=model_config.n_heads,
                                   dropout=model_config.dropout,
                                   embed_dropout=model_config.embed_dropout,
                                   attn_dropout=model_config.attn_dropout,
                                   ff_dropout=model_config.ff_dropout,
                                   normalize_embeddings=model_config.normalize_embeddings,
                                   bos_id=vocab.bos_id,
                                   eos_id=vocab.eos_id,
                                   sent_dialog_id=vocab.sent_dialog_id,
                                   max_seq_len=model_config.max_seq_len,
                                   beam_size=model_config.beam_size,
                                   length_penalty=model_config.length_penalty,
                                   n_segments=model_config.n_segments,
                                   annealing_topk=model_config.annealing_topk,
                                   annealing=model_config.annealing,
                                   diversity_coef=model_config.diversity_coef,
                                   diversity_groups=model_config.diversity_groups,
                                   multiple_choice_head=model_config.multiple_choice_head,
                                   constant_embedding=model_config.constant_embedding,
                                   single_input=model_config.single_input,
                                   dialog_embeddings=model_config.dialog_embeddings,
                                   share_models=model_config.share_models,
                                   successive_attention=model_config.successive_attention,
                                   sparse_embeddings=model_config.sparse_embeddings,
                                   shared_attention=model_config.shared_attention,
                                   bs_temperature=model_config.bs_temperature,
                                   bs_nucleus_p=model_config.bs_nucleus_p,
                                   vocab=None)  # for beam search debugging

    if not trainer_config.load_last:
        load_openai_weights(transformer.transformer_module,
                            trainer_config.openai_parameters_dir,
                            n_special_tokens=vocab.n_special_tokens)
        if not model_config.share_models:
            load_openai_weights(transformer.encoder_module,
                                trainer_config.openai_parameters_dir,
                                n_special_tokens=vocab.n_special_tokens)
        logger.info('OpenAI weights loaded from {}, model shared: {}'.format(
            trainer_config.openai_parameters_dir, model_config.share_models))

    logger.info('loading datasets')
    train_dataset = FacebookDataset(trainer_config.train_datasets, vocab,
                                    max_lengths=(transformer.n_pos_embeddings - 1) // (3 if model_config.single_input else 1),  # A bit restrictive here
                                    dialog_embeddings=model_config.dialog_embeddings,
                                    cache=trainer_config.train_datasets_cache,
                                    use_start_end=model_config.use_start_end,
                                    negative_samples=trainer_config.negative_samples,
                                    augment=trainer_config.persona_augment,
                                    aug_syn_proba=trainer_config.persona_aug_syn_proba,
                                    limit_size=trainer_config.limit_train_size)
    test_dataset = FacebookDataset(trainer_config.test_datasets, vocab,
                                   max_lengths=(transformer.n_pos_embeddings - 1) // (3 if model_config.single_input else 1),  # A bit restrictive here
                                   dialog_embeddings=model_config.dialog_embeddings,
                                   cache=trainer_config.test_datasets_cache,
                                   use_start_end=model_config.use_start_end,
                                   negative_samples=-1,  # Keep all negative samples
                                   augment=False,
                                   aug_syn_proba=0.0,
                                   limit_size=trainer_config.limit_eval_size)
    logger.info(f'train dataset {len(train_dataset)} test dataset {len(test_dataset)}')

    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
        transformer.distribute(device)

    model_trainer = Trainer(transformer,
                            train_dataset,
                            writer,
                            test_dataset,
                            train_batch_size=trainer_config.train_batch_size,
                            batch_split=trainer_config.batch_split,
                            test_batch_size=trainer_config.test_batch_size,
                            lr=trainer_config.lr,
                            lr_warmup=trainer_config.lr_warmup,
                            weight_decay=trainer_config.weight_decay,
                            s2s_weight=trainer_config.s2s_weight,
                            lm_weight=trainer_config.lm_weight,
                            risk_weight=trainer_config.risk_weight,
                            hits_weight=trainer_config.hits_weight,
                            single_input=model_config.single_input,
                            n_jobs=trainer_config.n_jobs,
                            clip_grad=trainer_config.clip_grad,
                            device=device,
                            ignore_idxs=vocab.special_tokens_ids,
                            local_rank=args.local_rank,
                            apex_level=model_config.apex_level,
                            apex_loss_scale=trainer_config.apex_loss_scale,
                            linear_schedule=trainer_config.linear_schedule,
                            n_epochs=trainer_config.n_epochs,
                            evaluate_full_sequences=trainer_config.evaluate_full_sequences)

    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.load_last, map_location=device)
        model_trainer.load_state_dict(state_dict)
        logger.info('Weights loaded from {}'.format(trainer_config.load_last))

    # helpers -----------------------------------------------------
    def external_metrics_func(full_references, full_predictions, epoch, metric=None):
        references_file_path = os.path.join(writer.log_dir,
                                            trainer_config.eval_references_file + "_{}".format(epoch))
        predictions_file_path = os.path.join(writer.log_dir,
                                             trainer_config.eval_predictions_file + "_{}".format(epoch))
        with open(references_file_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(full_references))
        with open(predictions_file_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(full_predictions))

        if metric is not None:
            return specified_nlp_metric([references_file_path], predictions_file_path, metric)

        nist, bleu, meteor, entropy, div, avg_len = nlp_metrics([references_file_path], predictions_file_path)

        metrics = {'meteor': meteor, 'avg_len': avg_len}
        for name, metric in (('nist', nist), ('entropy', entropy), ('div', div), ('bleu', bleu)):
            for i, m in enumerate(metric, 1):
                metrics['{}_{}'.format(name, i)] = m

        return metrics

    def save_func(epoch):
        if epoch != -1:
            torch.save(model_trainer.state_dict(), last_checkpoint_path)

    def sample_text_func(epoch):
        n_samples = 0
        model_trainer.model.eval()
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for persona_info, dialog, target, _ in samples:
            contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device)
                        for c in [persona_info, dialog] if len(c) > 0]
            prediction = model_trainer.model.predict(contexts)[0]

            persona_info_str = vocab.ids2string(persona_info[1:-1])
            dialog_str = vocab.ids2string(dialog)
            dialog_str = dialog_str.replace(vocab.talker1_bos, '\n\t- ').replace(vocab.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(vocab.talker1_eos, '').replace(vocab.talker2_eos, '')
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)

            logger.info('\n')
            logger.info('Persona info:\n\t{}'.format(persona_info_str))
            logger.info('Dialog:{}'.format(dialog_str))
            logger.info('Target:\n\t{}'.format(target_str))
            logger.info('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs, external_metrics_func, epoch)

    def f1_risk(predictions, targets):
        scores = f1_score(predictions, targets, average=False)
        assert all([0 <= s <= 1.0 for s in scores])
        return [1 - s for s in scores]

    def get_risk_metric_func(risk_metric):
        """ risk_metric selected in:
            f1, meteor, avg_len, nist_{1, 2, 3, 4},
            entropy_{1, 2, 3, 4}, div_{1, 2}, bleu_{1, 2, 3, 4}
        """
        def external_metric_risk(predictions, targets):
            string_targets = list(vocab.ids2string(t) for t in targets)
            string_predictions = list(vocab.ids2string(t) for t in predictions)
            metrics = [external_metrics_func([t], [p], epoch=-1, metric=risk_metric)
                       for p, t in zip(string_predictions, string_targets)]

            if any([s in risk_metric for s in ['entropy', 'nist', 'avg_len']]):
                return [-m for m in metrics]

            assert all([0 <= s <= 1.0 for s in metrics]), metrics

            return [1 - m for m in metrics]

        if risk_metric == 'f1':
            return f1_risk

        return external_metric_risk
    # helpers -----------------------------------------------------

    try:
        model_trainer.train(after_epoch_funcs=[save_func, sample_text_func, test_func],
                            risk_func=get_risk_metric_func(trainer_config.risk_metric))
    except (KeyboardInterrupt, Exception, RuntimeError) as e:
        if args.local_rank in [-1, 0]:
            torch.save(model_trainer.state_dict(), interrupt_checkpoint_path)
        raise e
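# Launch sketch (assumption, not from the original source): this main() reads
# --local_rank and calls init_process_group(init_method='env://'), which is the
# contract used by PyTorch's distributed launcher. A single-node, 4-GPU run
# would look roughly like:
#
#   python -m torch.distributed.launch --nproc_per_node=4 train.py
#
# The launcher spawns one process per GPU and passes each its --local_rank;
# rank 0 keeps the real SummaryWriter while the other ranks fall back to
# DummyWriter and redirect stdout to ./runs/log_distributed_<rank>.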
def main():
    model_config = get_model_config()
    trainer_config = get_trainer_config()

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    vocab = BPEVocab.from_files(model_config.bpe_vocab_path, model_config.bpe_codes_path)

    transformer = TransformerModel(n_layers=model_config.n_layers,
                                   n_embeddings=len(vocab),
                                   n_pos_embeddings=model_config.n_pos_embeddings,
                                   embeddings_size=model_config.embeddings_size,
                                   padding_idx=vocab.pad_id,
                                   n_heads=model_config.n_heads,
                                   dropout=model_config.dropout,
                                   embed_dropout=model_config.embed_dropout,
                                   attn_dropout=model_config.attn_dropout,
                                   ff_dropout=model_config.ff_dropout,
                                   bos_id=vocab.bos_id,
                                   eos_id=vocab.eos_id,
                                   max_seq_len=model_config.max_seq_len,
                                   beam_size=model_config.beam_size,
                                   length_penalty=model_config.length_penalty,
                                   n_segments=model_config.n_segments,
                                   annealing_topk=model_config.annealing_topk,
                                   annealing=model_config.annealing,
                                   diversity_coef=model_config.diversity_coef,
                                   diversity_groups=model_config.diversity_groups)

    if not trainer_config.load_last:
        load_openai_weights(transformer.transformer_module,
                            trainer_config.openai_parameters_dir,
                            n_special_tokens=vocab.n_special_tokens)
        print('OpenAI weights loaded from {}'.format(trainer_config.openai_parameters_dir))

    train_dataset = FacebookDataset(trainer_config.train_datasets, vocab, transformer.n_pos_embeddings - 1)
    test_dataset = FacebookDataset(trainer_config.test_datasets, vocab, transformer.n_pos_embeddings - 1)

    model_trainer = Trainer(transformer,
                            train_dataset,
                            test_dataset,
                            batch_size=trainer_config.batch_size,
                            batch_split=trainer_config.batch_split,
                            lr=trainer_config.lr,
                            lr_warmup=trainer_config.lr_warmup,
                            lm_weight=trainer_config.lm_weight,
                            risk_weight=trainer_config.risk_weight,
                            n_jobs=trainer_config.n_jobs,
                            clip_grad=trainer_config.clip_grad,
                            device=device,
                            ignore_idxs=vocab.special_tokens_ids)

    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.last_checkpoint_path, map_location=device)
        model_trainer.load_state_dict(state_dict)
        print('Weights loaded from {}'.format(trainer_config.last_checkpoint_path))

    # helpers -----------------------------------------------------
    def save_func(epoch):
        torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path)

    def sample_text_func(epoch):
        n_samples = 5
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for persona_info, dialog, target in samples:
            contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device)
                        for c in [persona_info, dialog] if len(c) > 0]
            prediction = model_trainer.model.predict(contexts)[0]

            persona_info_str = vocab.ids2string(persona_info[1:-1])
            dialog_str = vocab.ids2string(dialog)
            dialog_str = dialog_str.replace(vocab.talker1_bos, '\n\t- ').replace(vocab.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(vocab.talker1_eos, '').replace(vocab.talker2_eos, '')
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)

            print('\n')
            print('Persona info:\n\t{}'.format(persona_info_str))
            print('Dialog:{}'.format(dialog_str))
            print('Target:\n\t{}'.format(target_str))
            print('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs)

    def f1_risk(predictions, targets):
        scores = f1_score(predictions, targets, average=False)
        return [1 - s for s in scores]
    # helpers -----------------------------------------------------

    try:
        model_trainer.train(trainer_config.n_epochs,
                            after_epoch_funcs=[save_func, sample_text_func, test_func],
                            risk_func=f1_risk)
    except (KeyboardInterrupt, Exception, RuntimeError) as e:
        torch.save(model_trainer.state_dict(), trainer_config.interrupt_checkpoint_path)
        raise e
def main():
    model_config = get_model_config_dialog()
    trainer_config = get_trainer_config_dialog()

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    # zrs
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()
    distributed = (args.local_rank != -1)
    if distributed:
        print(args.local_rank)
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl', init_method='env://')

    vocab = myVocab(model_config.vocab_path)

    transformer = TransformerModel(n_layers=model_config.n_layers,
                                   n_embeddings=len(vocab),
                                   n_pos_embeddings=model_config.n_pos_embeddings,
                                   embeddings_size=model_config.embeddings_size,
                                   padding_idx=vocab.pad_id,
                                   n_heads=model_config.n_heads,
                                   dropout=model_config.dropout,
                                   embed_dropout=model_config.embed_dropout,
                                   attn_dropout=model_config.attn_dropout,
                                   ff_dropout=model_config.ff_dropout,
                                   bos_id=vocab.bos_id,
                                   eos_id=vocab.eos_id,
                                   max_seq_len=model_config.max_seq_len,
                                   beam_size=model_config.beam_size,
                                   length_penalty=model_config.length_penalty,
                                   n_segments=model_config.n_segments,
                                   annealing_topk=model_config.annealing_topk,
                                   temperature=model_config.temperature,
                                   annealing=model_config.annealing,
                                   diversity_coef=model_config.diversity_coef,
                                   diversity_groups=model_config.diversity_groups)

    if not trainer_config.load_last:
        openai_model = torch.load(trainer_config.openai_parameters_dir, map_location=device)
        openai_model.pop('decoder.pre_softmax.weight')
        # Keep only the first model_config.n_layers transformer layers of the
        # pretrained 12-layer checkpoint; strip the top-level name component
        # from the keys that survive.
        b = list(openai_model.keys())
        for i in b:
            temp = i.split('.')
            keep = True
            for j in range(model_config.n_layers, 12):
                if str(j) in temp:
                    keep = False
                    break
            if keep:
                openai_model[i.split('.', 1)[1]] = openai_model.pop(i)
            else:
                print(i)
                openai_model.pop(i)
                # openai_model[i.split('.', 1)[1]] = openai_model.pop(i)
        transformer.transformer_module.load_state_dict(openai_model, strict=True)
        # load_openai_weights_chinese(transformer.transformer_module, trainer_config.openai_parameters_dir)
        print('OpenAI weights chinese loaded from {}'.format(trainer_config.openai_parameters_dir))

    train_dataset = S2sDataset_dialog(trainer_config.train_datasets, vocab, transformer.n_pos_embeddings - 1)
    test_dataset = S2sDataset_dialog(trainer_config.test_datasets, vocab, transformer.n_pos_embeddings - 1)

    model_trainer = Trainer(transformer,
                            train_dataset,
                            test_dataset,
                            batch_size=trainer_config.batch_size,
                            batch_split=trainer_config.batch_split,
                            lr=trainer_config.lr,
                            lr_warmup=trainer_config.lr_warmup,
                            lm_weight=trainer_config.lm_weight,
                            risk_weight=trainer_config.risk_weight,
                            n_jobs=trainer_config.n_jobs,
                            clip_grad=trainer_config.clip_grad,
                            # label_smoothing=trainer_config.label_smoothing,
                            device=device,
                            ignore_idxs=vocab.special_tokens_ids,
                            distributed=distributed)

    if distributed:
        model_trainer.model.transformer_module = DistributedDataParallel(model_trainer.model.transformer_module,
                                                                         device_ids=[args.local_rank],
                                                                         output_device=args.local_rank)

    start_epoch = 0
    init_epoch = 0
    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.last_checkpoint_path + str(init_epoch - 1), map_location=device)
        model_trainer.load_state_dict(state_dict)
        # start_epoch = int(cop.sub('', trainer_config.last_checkpoint_path.split('/')[-1])) + 1
        start_epoch = init_epoch
        print('Weights loaded from {}'.format(trainer_config.last_checkpoint_path + str(init_epoch - 1)))

    # helpers -----------------------------------------------------
    def save_func(epoch):
        dirs = '/'.join(trainer_config.last_checkpoint_path.split('/')[:-1])
        if not os.path.exists(dirs):
            os.makedirs(dirs)
        torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path)
        # Keep a rolling window of per-epoch checkpoints: save this epoch,
        # drop the one from 100 epochs ago.
        torch.save(model_trainer.state_dict(), trainer_config.last_checkpoint_path + str(epoch))
        if os.path.exists(trainer_config.last_checkpoint_path + str(epoch - 100)):
            os.remove(trainer_config.last_checkpoint_path + str(epoch - 100))

    def sample_text_func(epoch):
        n_samples = 5
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for source, target in samples:
            contexts = [torch.tensor([c], dtype=torch.long, device=model_trainer.device)
                        for c in [source] if len(c) > 0]
            prediction = model_trainer.model.predict(contexts)[0]

            source_str = vocab.ids2string(source)
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)

            print('\n')
            print('Source:{}'.format(source_str))
            print('Target:\n\t{}'.format(target_str))
            print('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs)

    def f1_risk(predictions, targets):
        scores = f1_score(predictions, targets, average=False)
        return [1 - s for s in scores]
    # helpers -----------------------------------------------------

    # model_trainer.model.transformer_module = nn.DataParallel(model_trainer.model.transformer_module, device_ids=[0, 1])

    try:
        if args.local_rank in [-1, 0]:
            model_trainer.train(start_epoch, trainer_config.n_epochs,
                                after_epoch_funcs=[save_func, sample_text_func, test_func],
                                risk_func=f1_risk)
        else:
            model_trainer.train(start_epoch, trainer_config.n_epochs)
    except (KeyboardInterrupt, Exception, RuntimeError) as e:
        torch.save(model_trainer.state_dict(), trainer_config.interrupt_checkpoint_path)
        raise e
class TransformerAgent(Agent):
    @staticmethod
    def add_cmdline_args(argparser):
        agent_args = argparser.add_argument_group('Agent parameters')
        agent_args.add_argument('-gpu', '--gpu', type=int, default=-1,
                                help='which GPU to use')
        agent_args.add_argument('--no-cuda', type=bool, default=False,
                                help='disable GPUs even if available. otherwise, will use GPUs if '
                                     'available on the device.')
        agent_args.add_argument('--rank_candidates', type=bool, default=False,
                                help='Whether the model should parse candidates for ranking.')
        agent_args.add_argument('--sample', type=bool, default=False,
                                help='Sampling of beam from beam search')
        agent_args.add_argument('--max_seq_len', type=int, default=128, help='')
        agent_args.add_argument('--beam_size', type=int, default=1, help='')
        agent_args.add_argument('--diversity_coef', type=float, default=0, help='')
        agent_args.add_argument('--diversity_groups', type=int, default=1, help='')
        agent_args.add_argument('--annealing_topk', type=float, default=None, help='')
        agent_args.add_argument('--annealing', type=float, default=0.0, help='')
        agent_args.add_argument('--length_penalty', type=float, default=0.6, help='')
        return argparser

    def __init__(self, opt, shared=None):
        super(TransformerAgent, self).__init__(opt, shared)

        self.use_cuda = not self.opt.get('no_cuda') and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(self.opt['gpu'])

        torch.set_grad_enabled(False)

        model_config = get_model_config()
        self.vocab = BPEVocab.from_files(model_config.bpe_vocab_path, model_config.bpe_codes_path)

        self.dialog_embeddings = model_config.dialog_embeddings
        self.use_start_end = model_config.use_start_end
        self.single_input = model_config.single_input
        self.apex_level = model_config.apex_level

        # Default decoding parameters:
        # 'max_seq_len': 128,
        # 'beam_size': 1,
        # 'diversity_coef': 0,
        # 'diversity_groups': 1,
        # 'annealing_topk': None,
        # 'annealing': 0,
        # 'length_penalty': 0.6,

        if self.opt['annealing_topk'] is not None:
            assert self.opt['annealing_topk'] > self.opt['beam_size']

        assert self.opt['diversity_coef'] >= 0
        assert self.opt['beam_size'] % self.opt['diversity_groups'] == 0

        if shared is None:
            self.model = TransformerModel(n_layers=model_config.n_layers,
                                          n_embeddings=len(self.vocab),
                                          n_pos_embeddings=model_config.n_pos_embeddings,
                                          embeddings_size=model_config.embeddings_size,
                                          padding_idx=self.vocab.pad_id,
                                          n_heads=model_config.n_heads,
                                          dropout=model_config.dropout,
                                          embed_dropout=model_config.embed_dropout,
                                          attn_dropout=model_config.attn_dropout,
                                          ff_dropout=model_config.ff_dropout,
                                          bos_id=self.vocab.bos_id,
                                          eos_id=self.vocab.eos_id,
                                          sent_dialog_id=self.vocab.sent_dialog_id,
                                          max_seq_len=self.opt['max_seq_len'],
                                          beam_size=self.opt['beam_size'],
                                          length_penalty=self.opt['length_penalty'],
                                          n_segments=model_config.n_segments,
                                          sample=self.opt['sample'],
                                          annealing_topk=self.opt['annealing_topk'],
                                          annealing=self.opt['annealing'],
                                          diversity_coef=self.opt['diversity_coef'],
                                          diversity_groups=self.opt['diversity_groups'],
                                          normalize_embeddings=model_config.normalize_embeddings,
                                          multiple_choice_head=model_config.multiple_choice_head,
                                          constant_embedding=model_config.constant_embedding,
                                          vocab=self.vocab,
                                          single_input=model_config.single_input,
                                          dialog_embeddings=model_config.dialog_embeddings,
                                          share_models=model_config.share_models,
                                          successive_attention=model_config.successive_attention,
                                          sparse_embeddings=model_config.sparse_embeddings,
                                          shared_attention=model_config.shared_attention,
                                          bs_temperature=model_config.bs_temperature,
                                          bs_nucleus_p=model_config.bs_nucleus_p)

            state_dict = torch.load(model_config.checkpoint_path, map_location=lambda storage, loc: storage)
            if 'model' in state_dict:
                state_dict = state_dict['model']

            self.model.load_state_dict(state_dict)
            print('Weights loaded from {}'.format(model_config.checkpoint_path))

            if self.use_cuda:
                self.model = self.model.cuda()

            self.model.eval()
            self.model = apex_model(self.model, apex_level=self.apex_level)
        else:
            self.model = shared['model']

        self.reset()

    def _parse(self, text):
        persona_info = []
        dialog = []
        for subtext in text.split('\n'):
            subtext = subtext.strip()
            if subtext.startswith('your persona:'):
                subtext = subtext.replace('your persona:', '').strip()
                persona_info.append(subtext)
            else:
                dialog.append(subtext)
        return persona_info, dialog

    def _process_info(self, info):
        info = self._add_start_end(info[:self.model.n_pos_embeddings - (2 if self.use_start_end else 0)],
                                   self.vocab.info_bos_id,
                                   self.vocab.info_eos_id)
        info = self._add_dialog_embeddings(info, self.vocab.info_dialog_id)
        return info

    def _process_1st_replica(self, replica):
        replica = self._add_start_end(replica, self.vocab.talker1_bos_id, self.vocab.talker1_eos_id)
        replica = self._add_dialog_embeddings(replica, self.vocab.talker1_dialog_id)
        return replica

    def _process_2nd_replica(self, replica):
        replica = self._add_start_end(replica, self.vocab.talker2_bos_id, self.vocab.talker2_eos_id)
        replica = self._add_dialog_embeddings(replica, self.vocab.talker2_dialog_id)
        return replica

    def _add_dialog_embeddings(self, toks, dialog_tok):
        if self.dialog_embeddings:
            toks = [[t, dialog_tok] for t in toks]
        return toks

    def _add_start_end(self, toks, start, end):
        if self.use_start_end:
            toks = [start] + toks + [end]
        return toks

    def observe(self, observation):
        if self.episode_done:
            self.reset()

        if 'text' in observation:
            text = observation['text']
            info, dialog = self._parse(text)

            info = sum([self.vocab.string2ids(i) for i in info], [])
            if info:
                prev_info = [h[0] for h in self.history['info']] if self.dialog_embeddings else self.history['info']
                self.history['info'] = self._process_info(prev_info[1:-1] + info)

            for i, replica in enumerate(dialog, 1):
                replica = self.vocab.string2ids(replica)
                replica = self._process_1st_replica(replica) if i % 2 == 1 else self._process_2nd_replica(replica)
                self.history['dialog'].extend(replica)

        observation['agent'] = self
        self.episode_done = observation['episode_done']
        self.observation = observation
        return observation

    def act(self):
        return self.batch_act([self.observation])[0]

    def batch_act(self, observations):
        def is_valid_history(history):
            return len(history['dialog'])

        def to_tensor(string):
            ids = [self.vocab.bos_id] + self.vocab.string2ids(string) + [self.vocab.eos_id]
            ids = self._add_dialog_embeddings(ids, self.vocab.sent_dialog_id)
            return torch.tensor(ids, dtype=torch.long)

        def to_cuda(data):
            if not self.use_cuda:
                return data
            if isinstance(data, (list, tuple, map)):
                return list(map(lambda x: x.cuda(), data))
            return data.cuda()

        batch_reply = [{'id': self.getID(), 'text': '', 'text_candidates': []} for _ in range(len(observations))]
        valid_ids = [i for i, obs in enumerate(observations) if is_valid_history(obs['agent'].history)]
        batch_size = len(valid_ids)

        if batch_size == 0:
            return batch_reply

        try:
            valid_observations = [observations[i] for i in valid_ids]

            infos = [obs['agent'].history['info'] for obs in valid_observations]
            dialogs = [list(obs['agent'].history['dialog'])[-self.model.n_pos_embeddings + 1:] for obs in valid_observations]

            contexts = []

            if max(map(len, infos)) > 0:
                infos = [torch.tensor(i, dtype=torch.long) for i in infos]
                contexts.append(infos)

            if max(map(len, dialogs)) > 0:
                dialogs = [torch.tensor(d, dtype=torch.long) for d in dialogs]
                contexts.append(dialogs)

            if self.single_input:
                contexts = [torch.cat(c, dim=0) for c in zip(*contexts)]
                raw_context = contexts if self.opt['rank_candidates'] else None
                contexts = pad_sequence(contexts, batch_first=True, padding_value=self.model.padding_idx, left=True)
            else:
                contexts = [pad_sequence(c, batch_first=True, padding_value=self.model.padding_idx)
                            for c in contexts]

            contexts = to_cuda(contexts)

            pred_texts = self.model.predict(contexts)
            for i in range(batch_size):
                pred_toks = self._process_2nd_replica(pred_texts[i])
                valid_observations[i]['agent'].history['dialog'].extend(pred_toks)
                batch_reply[valid_ids[i]]['text'] = self.vocab.ids2string(pred_texts[i])
                batch_reply[valid_ids[i]]['episode_done'] = valid_observations[i]['agent'].episode_done

            if self.opt['rank_candidates']:
                enc_contexts = [self.model.encode(c) for c in contexts] if not self.single_input else []

                candidates = [list(obs.get('label_candidates', [])) for obs in valid_observations]
                lens_candidates = [len(c) for c in candidates]

                if max(lens_candidates) > 0:
                    candidates = [c + ['' for _ in range(max(lens_candidates) - len(c))] for c in candidates]
                    scores = [[] for _ in range(len(candidates))]

                    for i in range(max(lens_candidates)):
                        current_cands = [to_tensor(c[i])[:self.model.n_pos_embeddings - 1] for c in candidates]

                        lens = None
                        if self.single_input:
                            # candidate lengths before the dialog context is prepended
                            lens = [c.size(0) for c in current_cands]
                            current_cands = [torch.cat(c, dim=0)[-self.model.n_pos_embeddings:]
                                             for c in zip(raw_context, current_cands)]

                        current_cands = to_cuda(current_cands)
                        current_cands = pad_sequence(current_cands, batch_first=True, padding_value=self.model.padding_idx)

                        logits = self.model.decode(current_cands[:, :-1], enc_contexts)

                        if current_cands.dim() == 3:  # dialog embeddings: keep only the token channel
                            current_cands = current_cands[:, :, 0]

                        log_probas = F.log_softmax(logits, dim=-1)
                        log_probas = torch.gather(log_probas, -1, current_cands[:, 1:].unsqueeze(-1)).squeeze(-1)

                        if self.single_input:
                            # zero context so only candidate tokens are scored
                            for j, l in enumerate(lens):
                                current_cands[j, :-l + 1] = self.model.padding_idx

                        log_probas.masked_fill_(current_cands[:, 1:].eq(self.model.padding_idx), 0)
                        current_lens = current_cands[:, 1:].ne(self.model.padding_idx).float().sum(dim=-1)
                        current_scores = log_probas.sum(dim=-1) / current_lens

                        for k, s in enumerate(current_scores):
                            if i < lens_candidates[k]:
                                scores[k].append(s.item())

                    ranked_ids = [sorted(range(len(s)), key=lambda k: s[k], reverse=True) for s in scores]
                    ranked_strings = [[c[i] for i in ids] for ids, c in zip(ranked_ids, candidates)]

                    for i in range(batch_size):
                        batch_reply[valid_ids[i]]['text_candidates'] = ranked_strings[i]

        except Exception as e:
            # raise e
            print(e)

        return batch_reply

    def share(self):
        shared = super(TransformerAgent, self).share()
        shared['opt'] = self.opt
        shared['model'] = self.model
        return shared

    def reset(self):
        self.history = {'info': [], 'dialog': deque(maxlen=self.model.n_pos_embeddings - 1)}
        self.episode_done = True
        self.observation = None
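# The candidate-ranking loop above scores each candidate by its
# length-normalized log-likelihood under the decoder:
#
#     score(c) = (1 / |c|) * sum_t log P(c_t | c_<t, context)
#
# with padding positions masked out before the sum. A minimal self-contained
# sketch of that scoring step (hypothetical inputs: logits of shape
# [batch, seq_len, vocab_size] and target ids of shape [batch, seq_len]):
import torch
import torch.nn.functional as F

def length_normalized_scores(logits, targets, padding_idx):
    log_probas = F.log_softmax(logits, dim=-1)
    # gather the log-probability assigned to each target token
    log_probas = torch.gather(log_probas, -1, targets.unsqueeze(-1)).squeeze(-1)
    mask = targets.ne(padding_idx)
    log_probas = log_probas.masked_fill(~mask, 0)
    return log_probas.sum(dim=-1) / mask.float().sum(dim=-1)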
def get_trainer():
    model_config = get_model_config()
    trainer_config = get_trainer_config()

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    vocab = BPEVocab.from_files(model_config.bpe_vocab_path, model_config.bpe_codes_path)

    transformer = TransformerModel(n_layers=model_config.n_layers,
                                   n_embeddings=len(vocab),
                                   n_pos_embeddings=model_config.n_pos_embeddings,
                                   embeddings_size=model_config.embeddings_size,
                                   padding_idx=vocab.pad_id,
                                   n_heads=model_config.n_heads,
                                   dropout=model_config.dropout,
                                   embed_dropout=model_config.embed_dropout,
                                   attn_dropout=model_config.attn_dropout,
                                   ff_dropout=model_config.ff_dropout,
                                   bos_id=vocab.bos_id,
                                   eos_id=vocab.eos_id,
                                   max_seq_len=model_config.max_seq_len,
                                   beam_size=model_config.beam_size,
                                   length_penalty=model_config.length_penalty,
                                   n_segments=model_config.n_segments,
                                   annealing_topk=model_config.annealing_topk,
                                   annealing=model_config.annealing,
                                   diversity_coef=model_config.diversity_coef,
                                   diversity_groups=model_config.diversity_groups)

    if not trainer_config.load_last:
        load_openai_weights(transformer.transformer_module,
                            trainer_config.openai_parameters_dir,
                            n_special_tokens=vocab.n_special_tokens)
        print('OpenAI weights loaded from {}'.format(trainer_config.openai_parameters_dir))

    train_dataset = FacebookDataset(trainer_config.train_datasets, vocab, transformer.n_pos_embeddings - 1)
    test_dataset = FacebookDataset(trainer_config.test_datasets, vocab, transformer.n_pos_embeddings - 1)

    model_trainer = Trainer(transformer,
                            train_dataset,
                            test_dataset,
                            batch_size=trainer_config.batch_size,
                            batch_split=trainer_config.batch_split,
                            lr=trainer_config.lr,
                            lr_warmup=trainer_config.lr_warmup,
                            lm_weight=trainer_config.lm_weight,
                            risk_weight=trainer_config.risk_weight,
                            n_jobs=trainer_config.n_jobs,
                            clip_grad=trainer_config.clip_grad,
                            device=device,
                            ignore_idxs=vocab.special_tokens_ids)

    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.last_checkpoint_path, map_location=device)
        model_trainer.load_state_dict(state_dict)
        print('Weights loaded from {}'.format(trainer_config.last_checkpoint_path))

    return model_trainer
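# Usage sketch (assumption, not from the original source): get_trainer() wires
# the model, datasets, and Trainer together, so a training entry point can stay
# minimal. The positional n_epochs argument follows the Trainer.train calls in
# the main() variants above.
if __name__ == '__main__':
    model_trainer = get_trainer()
    model_trainer.train(get_trainer_config().n_epochs)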