class TransformerAgent(Agent):
    """Dialog agent that answers with a pretrained ``TransformerModel``.

    Incoming text is split into persona lines (prefixed ``your persona:``)
    and dialog lines, BPE-encoded and accumulated in ``self.history``.
    Replies come from ``TransformerModel.beam_search`` and are post-processed
    (repeat / n-gram replacement, an optional retrieved question,
    detokenization, an optional emoji). With ``--rank_candidates`` the
    ``label_candidates`` of every observation are additionally ranked by
    their mean per-token log-probability under the model.
    """

    @staticmethod
    def add_cmdline_args(argparser):
        """Register the agent's command-line options on *argparser*.

        Boolean options accept true/false, yes/no, t/f, y/n and 1/0
        (case-insensitive).

        Returns:
            The same *argparser*, so calls can be chained.
        """
        def boolean(value):
            # ``type=bool`` is a trap with argparse: bool('False') is True, so
            # any explicitly passed value would switch the option on. Parse the
            # textual spelling instead; a ValueError makes argparse report
            # "invalid boolean value: ...".
            if isinstance(value, bool):
                return value
            lowered = str(value).strip().lower()
            if lowered in ('1', 'true', 't', 'yes', 'y'):
                return True
            if lowered in ('0', 'false', 'f', 'no', 'n'):
                return False
            raise ValueError(value)

        agent_args = argparser.add_argument_group('Agent parameters')
        agent_args.add_argument('-gpu', '--gpu', type=int, default=-1,
                                help='which GPU to use')
        agent_args.add_argument('--no-cuda', type=boolean, default=False,
                                help='disable GPUs even if available. otherwise, will use GPUs if '
                                     'available on the device.')
        agent_args.add_argument('--rank_candidates', type=boolean, default=False,
                                help='Whether the model should parse candidates for ranking.')
        agent_args.add_argument('--sample', type=boolean, default=False,
                                help='Sampling of beam from beam search')
        # wild_mode: treat every line of the first observation of an episode
        # as persona information (see _parse).
        agent_args.add_argument('--wild_mode', type=boolean, default=False,
                                help='')
        # Reply post-processing switches (see _postprocess_text):
        # replace_repeat -> ReplyChecker.check_reply, replace_ngram/ngram_size ->
        # ngram_replaser, detokenize -> detokenize(), emoji_prob -> chance of
        # appending pick_emoji(), add_questions -> chance of appending a
        # retrieved question (only when beam_size > 1 and no '?' in the reply).
        agent_args.add_argument('--replace_repeat', type=boolean, default=True,
                                help='')
        agent_args.add_argument('--replace_ngram', type=boolean, default=True,
                                help='')
        agent_args.add_argument('--detokenize', type=boolean, default=True,
                                help='')
        agent_args.add_argument('--emoji_prob', type=float, default=0.5,
                                help='')
        agent_args.add_argument('--ngram_size', type=int, default=3,
                                help='')
        agent_args.add_argument('--add_questions', type=float, default=0.3,
                                help='')
        # Input pre-processing switches (see _preprocess_text):
        # clean_emoji -> clean_emoji(), check_grammar -> syntax_fix() + lower().
        agent_args.add_argument('--clean_emoji', type=boolean, default=True,
                                help='')
        agent_args.add_argument('--check_grammar', type=boolean, default=True,
                                help='')
        # Forwarded to ReplyChecker.
        agent_args.add_argument('--correct_generative', type=boolean, default=True,
                                help='')
        agent_args.add_argument('--split_into_sentences', type=boolean, default=True,
                                help='')

        # Decoding options forwarded to TransformerModel. annealing_topk (when
        # set) must be >= beam_size and diversity_groups must divide beam_size;
        # both are checked in __init__.
        agent_args.add_argument('--max_seq_len', type=int, default=128,
                                help='')
        agent_args.add_argument('--beam_size', type=int, default=1,
                                help='')
        agent_args.add_argument('--diversity_coef', type=float, default=0,
                                help='')
        agent_args.add_argument('--diversity_groups', type=int, default=1,
                                help='')
        agent_args.add_argument('--annealing_topk', type=float, default=None,
                                help='')
        agent_args.add_argument('--annealing', type=float, default=0.0,
                                help='')
        agent_args.add_argument('--length_penalty', type=float, default=0.6,
                                help='')

        return argparser

    def __init__(self, opt, shared=None):
        """Build (or reuse) the model and initialise per-agent state.

        Args:
            opt: option dict holding the keys registered in
                ``add_cmdline_args``.
            shared: ``None`` for the first instance, which builds the model,
                loads ``model_config.checkpoint_path`` and creates the
                ``RetrievalBot``; otherwise the dict produced by ``share()``,
                whose model and retrieval bot are reused.
        """
        super(TransformerAgent, self).__init__(opt, shared)

        # Use a GPU only when one is available and --no-cuda was not given.
        self.use_cuda = not self.opt.get('no_cuda') and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(self.opt['gpu'])

        # Inference only: disable autograd globally.
        torch.set_grad_enabled(False)

        model_config = get_model_config()
        self.vocab = BPEVocab.from_files(model_config.bpe_vocab_path, model_config.bpe_codes_path)
        self.reply_checker = ReplyChecker(correct_generative=self.opt['correct_generative'],
                                          split_into_sentences=self.opt['split_into_sentences'])

        self.replace_repeat = self.opt['replace_repeat']
        self.replace_ngram = self.opt['replace_ngram']
        self.ngram_size = self.opt['ngram_size']
        self.detokenize = self.opt['detokenize']
        self.emoji_prob = self.opt['emoji_prob']
        self.add_questions = self.opt['add_questions']
        self.beam_size = self.opt['beam_size']

        self.clean_emoji = self.opt['clean_emoji']
        self.check_grammar = self.opt['check_grammar']

        # Sanity-check decoding options before building the model.
        if self.opt['annealing_topk'] is not None:
            assert self.opt['annealing_topk'] >= self.opt['beam_size']

        assert self.opt['diversity_coef'] >= 0
        assert self.opt['beam_size'] % self.opt['diversity_groups'] == 0

        if shared is None:
            self.model = TransformerModel(n_layers=model_config.n_layers,
                                          n_embeddings=len(self.vocab),
                                          n_pos_embeddings=model_config.n_pos_embeddings,
                                          embeddings_size=model_config.embeddings_size,
                                          padding_idx=self.vocab.pad_id,
                                          n_heads=model_config.n_heads,
                                          dropout=model_config.dropout,
                                          embed_dropout=model_config.embed_dropout,
                                          attn_dropout=model_config.attn_dropout,
                                          ff_dropout=model_config.ff_dropout,
                                          bos_id=self.vocab.bos_id,
                                          eos_id=self.vocab.eos_id,
                                          max_seq_len=self.opt['max_seq_len'],
                                          beam_size=self.opt['beam_size'],
                                          length_penalty=self.opt['length_penalty'],
                                          n_segments=model_config.n_segments,
                                          sample=self.opt['sample'],
                                          annealing_topk=self.opt['annealing_topk'],
                                          annealing=self.opt['annealing'],
                                          diversity_coef=self.opt['diversity_coef'],
                                          diversity_groups=self.opt['diversity_groups'])
            self.retrieval_bot = RetrievalBot()

            # Load on CPU first; moved to the GPU below if requested.
            state_dict = torch.load(model_config.checkpoint_path, map_location=lambda storage, loc: storage)
            if 'model' in state_dict:
                state_dict = state_dict['model']

            self.model.load_state_dict(state_dict)
            print('Weights loaded from {}'.format(model_config.checkpoint_path))

            if self.use_cuda:
                self.model = self.model.cuda()

            self.model.eval()

        else:
            self.model = shared['model']
            self.retrieval_bot = shared['retrieval']

        self.reset()

    def _preprocess_text(self, text):
        """Apply the optional emoji cleaning and grammar fixing to *text*.

        Note that lower-casing only happens together with ``check_grammar``.
        """
        if self.clean_emoji:
            text = clean_emoji(text)

        if self.check_grammar:
            text = syntax_fix(text).lower()

        return text

    def _parse(self, text):
        """Split raw observation text into persona lines and dialog lines.

        Lines starting with ``your persona:`` (prefix removed) are persona
        information; every other line is a dialog utterance. In wild mode,
        when the history is still empty, every line is treated as persona.

        Returns:
            ``(persona_info, dialog)`` -- two lists of pre-processed strings.
        """
        # TODO: fix grammar mistakes?
        persona_info = []
        dialog = []
        # The history does not change while parsing, so decide once whether
        # this (first, wild-mode) observation is all persona description.
        force_persona = (self.opt['wild_mode']
                         and len(self.history['info']) == 0
                         and len(self.history['dialog']) == 0)
        for subtext in text.split('\n'):
            subtext = subtext.strip()

            if force_persona:
                subtext = 'your persona: ' + subtext

            if subtext.startswith('your persona:'):
                subtext = subtext.replace('your persona:', '').strip()
                subtext = self._preprocess_text(subtext).strip()
                persona_info.append(subtext)
            else:
                subtext = self._preprocess_text(subtext).strip()
                dialog.append(subtext)

        return persona_info, dialog

    def observe(self, observation):
        """Record *observation* in the history and tag it with this agent.

        Persona lines replace ``str_info`` and extend ``info`` (token ids);
        dialog lines extend ``str_dialog`` and, wrapped in alternating
        talker1/talker2 BOS/EOS ids (first line of the message = talker1),
        extend ``dialog``. ``observation['agent']`` is set to ``self`` so
        ``batch_act`` can find the history to use.
        """
        if self.episode_done:
            self.reset()

        if 'text' in observation:
            text = observation['text']
            info, dialog = self._parse(text)

            if info:
                self.history['str_info'] = ' '.join(info)
            self.history['str_dialog'].extend(dialog)

            info = sum([self.vocab.string2ids(i) for i in info], [])
            self.history['info'].extend(info)

            for i, d in enumerate(dialog, 1):
                d = self.vocab.string2ids(d)
                if i % 2 == 1:
                    d = [self.vocab.talker1_bos_id] + d + [self.vocab.talker1_eos_id]
                else:
                    d = [self.vocab.talker2_bos_id] + d + [self.vocab.talker2_eos_id]

                self.history['dialog'].extend(d)

        observation['agent'] = self

        self.episode_done = observation['episode_done']
        self.observation = observation

        return observation

    def act(self):
        """Reply to the last observation (a batch of one)."""
        return self.batch_act([self.observation])[0]

    def _postprocess_text(self, reply, agent):
        """Turn generated token ids into the final reply text.

        Args:
            reply: token ids produced by beam search.
            agent: the agent whose history the reply belongs to (in batched
                mode this may be a different instance than ``self``).

        Returns:
            ``(str_reply, reply_ids)`` -- the text to send, and the ids of the
            reply after repeat / n-gram / question edits but before
            detokenization and emoji; ``batch_act`` appends those ids to the
            dialog history.
        """
        str_reply = self.vocab.ids2string(reply)

        if self.replace_repeat:
            str_reply = agent.reply_checker.check_reply(str_reply,
                                                        agent.history['str_dialog'][-1],
                                                        agent.history['str_info'])

        # Occasionally append a retrieved question to replies that have none.
        if self.beam_size > 1 and random.uniform(0, 1) < self.add_questions and '?' not in str_reply:
            question = self.retrieval_bot.generate_question(list(agent.history['str_dialog']),
                                                            agent.history['str_info'])
            if question is not None and question not in str_reply:
                str_reply = ' '.join([str_reply, question])

        if self.replace_ngram:
            str_reply = ngram_replaser(agent.history['str_info'], str_reply, n=self.ngram_size)

        reply = self.vocab.string2ids(str_reply)

        if self.detokenize:
            str_reply = detokenize(str_reply)

        if random.uniform(0, 1) < self.emoji_prob:
            str_reply = ' '.join([str_reply, pick_emoji(str_reply)])

        return str_reply, reply

    def batch_act(self, observations):
        """Generate replies (and optionally rank candidates) for a batch.

        Each observation must carry its agent under ``'agent'`` (set by
        ``observe``); that agent's history is the model context and the
        generated reply is appended to it. Observations with an empty dialog
        history get an empty reply.

        Returns:
            One reply dict per observation, in input order, with keys ``id``,
            ``text`` and ``text_candidates`` (plus ``episode_done`` for
            answered observations).
        """
        import traceback  # only needed for error reporting below

        def is_valid_history(history):
            return len(history['dialog'])

        def to_tensor(string):
            ids = [self.vocab.bos_id] + self.vocab.string2ids(string) + [self.vocab.eos_id]
            return torch.tensor(ids, dtype=torch.long)

        batch_reply = [{'id': self.getID(), 'text': '', 'text_candidates': []} for _ in range(len(observations))]
        valid_ids = [i for i, obs in enumerate(observations) if is_valid_history(obs['agent'].history)]
        batch_size = len(valid_ids)

        if batch_size == 0:
            return batch_reply

        try:
            valid_observations = [observations[i] for i in valid_ids]

            # Keep contexts within the positional-embedding budget: persona info
            # gets 2 extra ids (info BOS/EOS), dialog keeps its most recent ids.
            infos = [obs['agent'].history['info'][:self.model.n_pos_embeddings-3] for obs in valid_observations]
            infos = [([self.vocab.info_bos_id] + ifo + [self.vocab.info_eos_id] if len(ifo) else ifo) for ifo in infos]
            dialogs = [list(obs['agent'].history['dialog'])[-self.model.n_pos_embeddings+1:] for obs in valid_observations]
            contexts = []

            if max(map(len, infos)) > 0:
                infos = [torch.tensor(i, dtype=torch.long) for i in infos]
                infos = pad_sequence(infos, batch_first=True, padding_value=self.model.padding_idx)
                if self.use_cuda:
                    infos = infos.cuda()
                contexts.append(infos)

            if max(map(len, dialogs)) > 0:
                dialogs = [torch.tensor(d, dtype=torch.long) for d in dialogs]
                dialogs = pad_sequence(dialogs, batch_first=True, padding_value=self.model.padding_idx)
                if self.use_cuda:
                    dialogs = dialogs.cuda()
                contexts.append(dialogs)

            enc_contexts = [self.model.encode(c) for c in contexts]
            pred_texts = self.model.beam_search(enc_contexts)

            for i in range(batch_size):
                pred_text_str, pred_text = self._postprocess_text(pred_texts[i], valid_observations[i]['agent'])

                valid_observations[i]['agent'].history['dialog'].extend([self.vocab.talker2_bos_id] +
                                                                        pred_text +
                                                                        [self.vocab.talker2_eos_id])
                batch_reply[valid_ids[i]]['text'] = pred_text_str
                batch_reply[valid_ids[i]]['episode_done'] = valid_observations[i]['agent'].episode_done

            if self.opt['rank_candidates']:
                candidates = [list(obs.get('label_candidates', [])) for obs in valid_observations]
                lens_candidates = [len(c) for c in candidates]

                if max(lens_candidates) > 0:
                    # Pad every candidate list to the same length with '' so the
                    # i-th candidates of all observations can be scored together.
                    candidates = [c + ['' for _ in range(max(lens_candidates) - len(c))] for c in candidates]
                    scores = [[] for _ in range(len(candidates))]

                    for i in range(max(lens_candidates)):
                        current_cands = [to_tensor(c[i])[:self.model.n_pos_embeddings-1] for c in candidates]
                        current_cands = pad_sequence(current_cands, batch_first=True, padding_value=self.model.padding_idx)
                        if self.use_cuda:
                            current_cands = current_cands.cuda()

                        # Teacher-forced log-probability of each candidate token,
                        # averaged over the non-padding positions.
                        logits = self.model.decode(current_cands[:, :-1], enc_contexts)
                        log_probas = F.log_softmax(logits, dim=-1)
                        log_probas = torch.gather(log_probas, -1, current_cands[:, 1:].unsqueeze(-1)).squeeze(-1)
                        log_probas.masked_fill_(current_cands[:, 1:].eq(self.model.padding_idx), 0)

                        current_lens = current_cands[:, 1:].ne(self.model.padding_idx).float().sum(dim=-1)
                        current_scores = log_probas.sum(dim=-1) / current_lens

                        for k, s in enumerate(current_scores):
                            # Skip the '' fillers added above.
                            if i < lens_candidates[k]:
                                scores[k].append(s.item())

                    ranked_ids = [sorted(range(len(s)), key=lambda k: s[k], reverse=True) for s in scores]
                    ranked_strings = [[c[i] for i in ids] for ids, c in zip(ranked_ids, candidates)]

                    for i in range(batch_size):
                        batch_reply[valid_ids[i]]['text_candidates'] = ranked_strings[i]

        except Exception:
            # Best effort: a failure must not abort the conversation loop, so the
            # (possibly partially filled) replies are still returned. Print the
            # full traceback, not just the message, so failures are debuggable.
            traceback.print_exc()

        return batch_reply

    def share(self):
        """Return the state other instances reuse: opt, model and retrieval bot."""
        shared = super(TransformerAgent, self).share()
        shared['opt'] = self.opt
        shared['model'] = self.model
        shared['retrieval'] = self.retrieval_bot

        return shared

    def reset(self):
        """Clear the conversation history and the reply checker state.

        ``str_dialog`` is pre-filled with ``DIALOG_SIZE`` placeholder strings,
        so ``[-1]`` is always valid; ``dialog`` keeps at most
        ``n_pos_embeddings - 1`` of the most recent token ids.
        """
        self.history = {'str_info': None, 'str_dialog': deque(DIALOG_SIZE * ['None'], maxlen=DIALOG_SIZE),
                        'info': [], 'dialog': deque(maxlen=self.model.n_pos_embeddings-1)}
        self.episode_done = True
        self.observation = None
        self.reply_checker.clean()
# Example #3
# 0
def main():
    """Interactive console demo for a dialog ``TransformerModel``.

    Builds the model from ``get_model_config_dialog()``, loads the
    checkpoint at ``get_test_config_dialog().last_checkpoint_path`` (with the
    ``.module`` segment removed from parameter names), then reads messages
    from stdin and prints every beam hypothesis for each one. Exits cleanly
    on EOF (Ctrl-D) or Ctrl-C.
    """
    model_config = get_model_config_dialog()
    test_config = get_test_config_dialog()

    set_seed(test_config.seed)
    device = torch.device(test_config.device)

    vocab = myVocab(model_config.vocab_path)

    transformer = TransformerModel(
        n_layers=model_config.n_layers,
        n_embeddings=len(vocab),
        n_pos_embeddings=model_config.n_pos_embeddings,
        embeddings_size=model_config.embeddings_size,
        padding_idx=vocab.pad_id,
        n_heads=model_config.n_heads,
        dropout=model_config.dropout,
        embed_dropout=model_config.embed_dropout,
        attn_dropout=model_config.attn_dropout,
        ff_dropout=model_config.ff_dropout,
        bos_id=vocab.bos_id,
        eos_id=vocab.eos_id,
        max_seq_len=model_config.max_seq_len,
        beam_size=model_config.beam_size,
        length_penalty=model_config.length_penalty,
        n_segments=model_config.n_segments,
        annealing_topk=model_config.annealing_topk,
        annealing=model_config.annealing,
        diversity_coef=model_config.diversity_coef,
        diversity_groups=model_config.diversity_groups)

    transformer = transformer.to(device)
    state_dict = torch.load(test_config.last_checkpoint_path,
                            map_location=device)
    # Drop the '.module' segment from parameter names so they match the
    # unwrapped model (presumably the checkpoint was saved from a wrapped,
    # e.g. DataParallel, model -- confirm against the training script).
    model_state = {key.replace('.module', ''): value
                   for key, value in state_dict['model'].items()}
    transformer.load_state_dict(model_state)
    transformer.eval()
    print('Weights loaded from {}'.format(test_config.last_checkpoint_path))

    def encode(message, max_len):
        """Encode *message* as a single-context batch of one.

        For a ``str`` message, ``' '.join`` separates its characters with
        spaces before ``string2ids`` (character-level input). BOS/EOS are
        added and the id list is cut to *max_len*.
        NOTE(review): the cut is applied after EOS is appended, so long
        inputs lose their EOS -- confirm this matches the training format.
        """
        ids = vocab.string2ids(' '.join(message))
        ids = ([vocab.bos_id] + ids + [vocab.eos_id])[:max_len]
        return [torch.tensor([ids], dtype=torch.long, device=device)]

    def answer(message):
        """Return the model's single prediction for *message* as a string."""
        prediction = transformer.predict(encode(message, 60))[0]
        return vocab.ids2string(prediction)

    def answer_beams(message):
        """Return every beam hypothesis for *message* as a list of strings."""
        predictions = transformer.predict_beams(encode(message, 30))[0]
        return [vocab.ids2string(prediction) for prediction in predictions]

    '''
    with open('data/test200_output_noinit_noweight.txt', 'w', encoding='utf8') as fw:
        with open('data/test200.txt', 'r', encoding='utf8') as fr:
            lines = fr.readlines()
            for line in lines:
                post, response = line.strip('\n').replace(' ', '').split('\t')
                ans = answer(post)
                fw.write('source:' + post + '\t' + 'target:' + response + '\t' + 'answer:' + ans + '\n')
    '''
    '''
    while True:
        message = input('>')
        ans = answer(message)
        print(ans)
    '''

    while True:
        try:
            message = input('>')
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C: leave the REPL instead of dumping a traceback.
            print()
            break
        for reply in answer_beams(message):
            print(reply)
# Example #4
# 0
def main():
    """Train a ``TransformerModel`` on ``FacebookDataset`` dialog data.

    The command line only covers distributed training (``--local_rank``) and
    remote debugging (``--server_ip`` / ``--server_port``); everything else
    comes from ``get_model_config()`` and ``get_trainer_config()``. Configs,
    checkpoints and evaluation files are written under the summary writer's
    ``log_dir``. If training raises (including Ctrl-C), the main process
    saves an interrupt checkpoint and re-raises.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank',
                        type=int,
                        default=-1,
                        help="Distributed training.")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Used for debugging on GPU machine.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Used for debugging on GPU machine.")
    args = parser.parse_args()

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.ERROR)
    logger = logging.getLogger(__file__)
    if args.server_ip and args.server_port and args.local_rank in [-1, 0]:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        # argparse delivers the port as a string; (host, port) socket
        # addresses need an integer port.
        ptvsd.enable_attach(address=(args.server_ip, int(args.server_port)),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    model_config = get_model_config()
    trainer_config = get_trainer_config()

    # Log only on main process
    if args.local_rank not in [-1, 0]:
        # Non-main ranks send stdout to a per-rank file for the whole run.
        sys.stdout = open(f"./runs/log_distributed_{args.local_rank}",
                          "w")  # dump sdtout
        writer = DummyWriter()
    else:
        writer = SummaryWriter(comment=trainer_config.writer_comment)

    logger.info("model config: {}".format(model_config))
    logger.info("trainer config: {}".format(trainer_config))
    log_dir = writer.log_dir
    interrupt_checkpoint_path = os.path.join(
        log_dir, trainer_config.interrupt_checkpoint_path)
    last_checkpoint_path = os.path.join(log_dir,
                                        trainer_config.last_checkpoint_path)
    logger.info(
        "Logging to {}".format(log_dir)
    )  # Let's save everything on an experiment in the ./runs/XXX/directory
    if args.local_rank in [-1, 0]:
        with open(os.path.join(log_dir, "model_config.json"), "w") as f:
            json.dump(model_config, f)
        with open(os.path.join(log_dir, "trainer_config.json"), "w") as f:
            json.dump(trainer_config, f)

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    vocab = BPEVocab.from_files(model_config.bpe_vocab_path,
                                model_config.bpe_codes_path,
                                zero_shot=trainer_config.zero_shot)

    transformer = TransformerModel(
        n_layers=model_config.n_layers,
        n_embeddings=len(vocab),
        n_pos_embeddings=model_config.n_pos_embeddings,
        embeddings_size=model_config.embeddings_size,
        padding_idx=vocab.pad_id,
        n_heads=model_config.n_heads,
        dropout=model_config.dropout,
        embed_dropout=model_config.embed_dropout,
        attn_dropout=model_config.attn_dropout,
        ff_dropout=model_config.ff_dropout,
        normalize_embeddings=model_config.normalize_embeddings,
        bos_id=vocab.bos_id,
        eos_id=vocab.eos_id,
        sent_dialog_id=vocab.sent_dialog_id,
        max_seq_len=model_config.max_seq_len,
        beam_size=model_config.beam_size,
        length_penalty=model_config.length_penalty,
        n_segments=model_config.n_segments,
        annealing_topk=model_config.annealing_topk,
        annealing=model_config.annealing,
        diversity_coef=model_config.diversity_coef,
        diversity_groups=model_config.diversity_groups,
        multiple_choice_head=model_config.multiple_choice_head,
        constant_embedding=model_config.constant_embedding,
        single_input=model_config.single_input,
        dialog_embeddings=model_config.dialog_embeddings,
        share_models=model_config.share_models,
        successive_attention=model_config.successive_attention,
        sparse_embeddings=model_config.sparse_embeddings,
        shared_attention=model_config.shared_attention,
        bs_temperature=model_config.bs_temperature,
        bs_nucleus_p=model_config.bs_nucleus_p,
        vocab=None)  # for beam search debugging

    # Start from the OpenAI weights unless a full checkpoint is loaded later.
    if not trainer_config.load_last:
        load_openai_weights(transformer.transformer_module,
                            trainer_config.openai_parameters_dir,
                            n_special_tokens=vocab.n_special_tokens)
        if not model_config.share_models:
            load_openai_weights(transformer.encoder_module,
                                trainer_config.openai_parameters_dir,
                                n_special_tokens=vocab.n_special_tokens)
        logger.info('OpenAI weights loaded from {}, model shared: {}'.format(
            trainer_config.openai_parameters_dir, model_config.share_models))

    logger.info('loading datasets')
    train_dataset = FacebookDataset(
        trainer_config.train_datasets,
        vocab,
        max_lengths=(transformer.n_pos_embeddings - 1) //
        (3 if model_config.single_input else 1),  # A bit restrictive here
        dialog_embeddings=model_config.dialog_embeddings,
        cache=trainer_config.train_datasets_cache,
        use_start_end=model_config.use_start_end,
        negative_samples=trainer_config.negative_samples,
        augment=trainer_config.persona_augment,
        aug_syn_proba=trainer_config.persona_aug_syn_proba,
        limit_size=trainer_config.limit_train_size)
    test_dataset = FacebookDataset(
        trainer_config.test_datasets,
        vocab,
        max_lengths=(transformer.n_pos_embeddings - 1) //
        (3 if model_config.single_input else 1),  # A bit restrictive here
        dialog_embeddings=model_config.dialog_embeddings,
        cache=trainer_config.test_datasets_cache,
        use_start_end=model_config.use_start_end,
        negative_samples=-1,  # Keep all negative samples
        augment=False,
        aug_syn_proba=0.0,
        limit_size=trainer_config.limit_eval_size)
    logger.info(
        f'train dataset {len(train_dataset)} test dataset {len(test_dataset)}')

    if args.local_rank != -1:

        torch.cuda.set_device(args.local_rank)
        device = torch.device('cuda', args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')
        transformer.distribute(device)

    model_trainer = Trainer(
        transformer,
        train_dataset,
        writer,
        test_dataset,
        train_batch_size=trainer_config.train_batch_size,
        batch_split=trainer_config.batch_split,
        test_batch_size=trainer_config.test_batch_size,
        lr=trainer_config.lr,
        lr_warmup=trainer_config.lr_warmup,
        weight_decay=trainer_config.weight_decay,
        s2s_weight=trainer_config.s2s_weight,
        lm_weight=trainer_config.lm_weight,
        risk_weight=trainer_config.risk_weight,
        hits_weight=trainer_config.hits_weight,
        single_input=model_config.single_input,
        n_jobs=trainer_config.n_jobs,
        clip_grad=trainer_config.clip_grad,
        device=device,
        ignore_idxs=vocab.special_tokens_ids,
        local_rank=args.local_rank,
        apex_level=model_config.apex_level,
        apex_loss_scale=trainer_config.apex_loss_scale,
        linear_schedule=trainer_config.linear_schedule,
        n_epochs=trainer_config.n_epochs,
        evaluate_full_sequences=trainer_config.evaluate_full_sequences)

    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.load_last, map_location=device)
        model_trainer.load_state_dict(state_dict)
        logger.info('Weights loaded from {}'.format(trainer_config.load_last))

    # helpers -----------------------------------------------------
    def external_metrics_func(full_references,
                              full_predictions,
                              epoch,
                              metric=None):
        """Write references/predictions for *epoch* to log_dir and score them.

        Returns ``specified_nlp_metric(...)`` when *metric* is given, else a
        dict with ``meteor``, ``avg_len`` and numbered ``nist_*``,
        ``entropy_*``, ``div_*`` and ``bleu_*`` entries.
        """
        references_file_path = os.path.join(
            writer.log_dir,
            trainer_config.eval_references_file + "_{}".format(epoch))
        predictions_file_path = os.path.join(
            writer.log_dir,
            trainer_config.eval_predictions_file + "_{}".format(epoch))
        # str is already text in Python 3; no unicode() conversion needed.
        with open(references_file_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(full_references))
        with open(predictions_file_path, 'w', encoding='utf-8') as f:
            f.write('\n'.join(full_predictions))

        if metric is not None:
            return specified_nlp_metric([references_file_path],
                                        predictions_file_path, metric)

        nist, bleu, meteor, entropy, div, avg_len = nlp_metrics(
            [references_file_path], predictions_file_path)

        metrics = {'meteor': meteor, 'avg_len': avg_len}
        for name, values in (('nist', nist), ('entropy', entropy),
                             ('div', div), ('bleu', bleu)):
            for i, value in enumerate(values, 1):
                metrics['{}_{}'.format(name, i)] = value

        return metrics

    def save_func(epoch):
        """Save the trainer state to last_checkpoint_path (skipped for epoch -1)."""
        if epoch != -1:
            torch.save(model_trainer.state_dict(), last_checkpoint_path)

    def sample_text_func(epoch):
        """Log model predictions for a few random test samples.

        n_samples is 0, so currently this only switches the model to eval mode.
        """
        n_samples = 0
        model_trainer.model.eval()
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for persona_info, dialog, target, _ in samples:
            contexts = [
                torch.tensor([c],
                             dtype=torch.long,
                             device=model_trainer.device)
                for c in [persona_info, dialog] if len(c) > 0
            ]
            prediction = model_trainer.model.predict(contexts)[0]

            persona_info_str = vocab.ids2string(persona_info[1:-1])
            dialog_str = vocab.ids2string(dialog)
            dialog_str = dialog_str.replace(vocab.talker1_bos,
                                            '\n\t- ').replace(
                                                vocab.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(vocab.talker1_eos,
                                            '').replace(vocab.talker2_eos, '')
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)

            logger.info('\n')
            logger.info('Persona info:\n\t{}'.format(persona_info_str))
            logger.info('Dialog:{}'.format(dialog_str))
            logger.info('Target:\n\t{}'.format(target_str))
            logger.info('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        """Run evaluation every trainer_config.test_period epochs."""
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs, external_metrics_func, epoch)

    def f1_risk(predictions, targets):
        """Risk = 1 - F1 for each prediction/target pair."""
        scores = f1_score(predictions, targets, average=False)
        assert all([0 <= s <= 1.0 for s in scores])
        return [1 - s for s in scores]

    def get_risk_metric_func(risk_metric):
        """ risk_metric selected in:
            f1, meteor, avg_len, nist_{1, 2, 3, 4}, entropy_{1, 2, 3, 4}, div_{1, 2}, bleu_{1, 2, 3, 4}
        """
        def external_metric_risk(predictions, targets):
            string_targets = list(vocab.ids2string(t) for t in targets)
            string_predictions = list(vocab.ids2string(t) for t in predictions)
            metrics = [
                external_metrics_func([t], [p], epoch=-1, metric=risk_metric)
                for p, t in zip(string_predictions, string_targets)
            ]

            # Unbounded metrics: higher is better, so the risk is the negation.
            if any([s in risk_metric for s in ['entropy', 'nist', 'avg_len']]):
                return [-m for m in metrics]

            assert all([0 <= s <= 1.0 for s in metrics]), metrics

            return [1 - m for m in metrics]

        if risk_metric == 'f1':
            return f1_risk

        return external_metric_risk

    # helpers -----------------------------------------------------

    try:
        model_trainer.train(
            after_epoch_funcs=[save_func, sample_text_func, test_func],
            risk_func=get_risk_metric_func(trainer_config.risk_metric))
    except (KeyboardInterrupt, Exception):
        # KeyboardInterrupt is not an Exception subclass, so it is listed
        # explicitly. Save progress on the main process, then re-raise with
        # the original traceback.
        if args.local_rank in [-1, 0]:
            torch.save(model_trainer.state_dict(), interrupt_checkpoint_path)
        raise
def main():
    """Train a TransformerModel on Facebook-format (persona + dialog) datasets.

    Builds the BPE vocabulary and the model from ``get_model_config()``.
    Unless ``trainer_config.load_last`` is set, the transformer is initialised
    from the OpenAI weights in ``trainer_config.openai_parameters_dir``;
    otherwise the full trainer state is restored from
    ``trainer_config.last_checkpoint_path``. Training then runs for
    ``trainer_config.n_epochs`` epochs, saving a checkpoint, printing a few
    sample predictions and (every ``test_period`` epochs) computing F1 after
    each epoch.

    Any exception raised during training (including ``KeyboardInterrupt``)
    causes the trainer state to be written to
    ``trainer_config.interrupt_checkpoint_path`` before it is re-raised.
    """
    model_config = get_model_config()
    trainer_config = get_trainer_config()

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)

    vocab = BPEVocab.from_files(model_config.bpe_vocab_path,
                                model_config.bpe_codes_path)

    transformer = TransformerModel(
        n_layers=model_config.n_layers,
        n_embeddings=len(vocab),
        n_pos_embeddings=model_config.n_pos_embeddings,
        embeddings_size=model_config.embeddings_size,
        padding_idx=vocab.pad_id,
        n_heads=model_config.n_heads,
        dropout=model_config.dropout,
        embed_dropout=model_config.embed_dropout,
        attn_dropout=model_config.attn_dropout,
        ff_dropout=model_config.ff_dropout,
        bos_id=vocab.bos_id,
        eos_id=vocab.eos_id,
        max_seq_len=model_config.max_seq_len,
        beam_size=model_config.beam_size,
        length_penalty=model_config.length_penalty,
        n_segments=model_config.n_segments,
        annealing_topk=model_config.annealing_topk,
        annealing=model_config.annealing,
        diversity_coef=model_config.diversity_coef,
        diversity_groups=model_config.diversity_groups)

    if not trainer_config.load_last:
        load_openai_weights(transformer.transformer_module,
                            trainer_config.openai_parameters_dir,
                            n_special_tokens=vocab.n_special_tokens)
        print('OpenAI weights loaded from {}'.format(
            trainer_config.openai_parameters_dir))

    # Sequences are limited to n_pos_embeddings - 1 tokens.
    train_dataset = FacebookDataset(trainer_config.train_datasets, vocab,
                                    transformer.n_pos_embeddings - 1)
    test_dataset = FacebookDataset(trainer_config.test_datasets, vocab,
                                   transformer.n_pos_embeddings - 1)

    model_trainer = Trainer(transformer,
                            train_dataset,
                            test_dataset,
                            batch_size=trainer_config.batch_size,
                            batch_split=trainer_config.batch_split,
                            lr=trainer_config.lr,
                            lr_warmup=trainer_config.lr_warmup,
                            lm_weight=trainer_config.lm_weight,
                            risk_weight=trainer_config.risk_weight,
                            n_jobs=trainer_config.n_jobs,
                            clip_grad=trainer_config.clip_grad,
                            device=device,
                            ignore_idxs=vocab.special_tokens_ids)

    if trainer_config.load_last:
        state_dict = torch.load(trainer_config.last_checkpoint_path,
                                map_location=device)
        model_trainer.load_state_dict(state_dict)
        print('Weights loaded from {}'.format(
            trainer_config.last_checkpoint_path))

    # helpers -----------------------------------------------------
    def save_func(epoch):
        """Overwrite the "last" checkpoint with the current trainer state."""
        torch.save(model_trainer.state_dict(),
                   trainer_config.last_checkpoint_path)

    def sample_text_func(epoch):
        """Print persona, dialog, target and prediction for random test items."""
        # random.sample raises ValueError when k exceeds the population size,
        # so cap the sample count for very small test sets.
        n_samples = min(5, len(test_dataset))
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for persona_info, dialog, target in samples:
            # Empty context parts are skipped rather than fed as empty tensors.
            contexts = [
                torch.tensor([c],
                             dtype=torch.long,
                             device=model_trainer.device)
                for c in [persona_info, dialog] if len(c) > 0
            ]
            prediction = model_trainer.model.predict(contexts)[0]

            # [1:-1] drops the first and last ids (presumably the start/end
            # markers added by the dataset -- confirm in FacebookDataset).
            persona_info_str = vocab.ids2string(persona_info[1:-1])
            dialog_str = vocab.ids2string(dialog)
            dialog_str = dialog_str.replace(vocab.talker1_bos,
                                            '\n\t- ').replace(
                                                vocab.talker2_bos, '\n\t- ')
            dialog_str = dialog_str.replace(vocab.talker1_eos,
                                            '').replace(vocab.talker2_eos, '')
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)

            print('\n')
            print('Persona info:\n\t{}'.format(persona_info_str))
            print('Dialog:{}'.format(dialog_str))
            print('Target:\n\t{}'.format(target_str))
            print('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        """Evaluate F1 on the test set every ``test_period`` epochs."""
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs)

    def f1_risk(predictions, targets):
        """Per-sample risk: 1 - F1 between each prediction and its target."""
        scores = f1_score(predictions, targets, average=False)
        return [1 - s for s in scores]

    # helpers -----------------------------------------------------

    try:
        model_trainer.train(
            trainer_config.n_epochs,
            after_epoch_funcs=[save_func, sample_text_func, test_func],
            risk_func=f1_risk)
    except (KeyboardInterrupt, Exception):
        # RuntimeError is already covered by Exception. Persist progress, then
        # re-raise the original exception with its traceback intact.
        torch.save(model_trainer.state_dict(),
                   trainer_config.interrupt_checkpoint_path)
        raise
def main():
    """Train a dialog (source -> target) TransformerModel, optionally distributed.

    A ``--local_rank`` other than -1 enables distributed training: the process
    binds to that CUDA device, joins an NCCL process group initialised from
    environment variables, and wraps the transformer in
    ``DistributedDataParallel``. Only ranks -1/0 run the per-epoch helpers
    (checkpointing, sampling, evaluation) and write the interrupt checkpoint.

    Unless ``trainer_config.load_last`` is set, the transformer is initialised
    from the pretrained state dict at ``trainer_config.openai_parameters_dir``,
    keeping only the first ``model_config.n_layers`` layers.
    """
    model_config = get_model_config_dialog()
    trainer_config = get_trainer_config_dialog()

    set_seed(trainer_config.seed)
    device = torch.device(trainer_config.device)
    # zrs
    parser = argparse.ArgumentParser()
    parser.add_argument("--local_rank", type=int, default=-1)
    args = parser.parse_args()
    distributed = (args.local_rank != -1)
    if distributed:
        print(args.local_rank)
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl',
                                             init_method='env://')

    vocab = myVocab(model_config.vocab_path)

    transformer = TransformerModel(
        n_layers=model_config.n_layers,
        n_embeddings=len(vocab),
        n_pos_embeddings=model_config.n_pos_embeddings,
        embeddings_size=model_config.embeddings_size,
        padding_idx=vocab.pad_id,
        n_heads=model_config.n_heads,
        dropout=model_config.dropout,
        embed_dropout=model_config.embed_dropout,
        attn_dropout=model_config.attn_dropout,
        ff_dropout=model_config.ff_dropout,
        bos_id=vocab.bos_id,
        eos_id=vocab.eos_id,
        max_seq_len=model_config.max_seq_len,
        beam_size=model_config.beam_size,
        length_penalty=model_config.length_penalty,
        n_segments=model_config.n_segments,
        annealing_topk=model_config.annealing_topk,
        temperature=model_config.temperature,
        annealing=model_config.annealing,
        diversity_coef=model_config.diversity_coef,
        diversity_groups=model_config.diversity_groups)

    if not trainer_config.load_last:
        openai_model = torch.load(trainer_config.openai_parameters_dir,
                                  map_location=device)
        openai_model.pop('decoder.pre_softmax.weight')
        # Drop every key that has a layer index in [n_layers, 12) as one of its
        # dot-separated components, so a shallower model can reuse the first
        # n_layers pretrained layers.
        dropped_layer_ids = {str(j) for j in range(model_config.n_layers, 12)}
        for key in list(openai_model.keys()):
            if dropped_layer_ids.isdisjoint(key.split('.')):
                # Strip the leading "<prefix>." so names line up with
                # transformer_module's own parameter names (strict load below).
                openai_model[key.split('.', 1)[1]] = openai_model.pop(key)
            else:
                print(key)
                openai_model.pop(key)
        transformer.transformer_module.load_state_dict(openai_model,
                                                       strict=True)
        print('OpenAI weights chinese loaded from {}'.format(
            trainer_config.openai_parameters_dir))

    train_dataset = S2sDataset_dialog(trainer_config.train_datasets, vocab,
                                      transformer.n_pos_embeddings - 1)
    test_dataset = S2sDataset_dialog(trainer_config.test_datasets, vocab,
                                     transformer.n_pos_embeddings - 1)

    model_trainer = Trainer(
        transformer,
        train_dataset,
        test_dataset,
        batch_size=trainer_config.batch_size,
        batch_split=trainer_config.batch_split,
        lr=trainer_config.lr,
        lr_warmup=trainer_config.lr_warmup,
        lm_weight=trainer_config.lm_weight,
        risk_weight=trainer_config.risk_weight,
        n_jobs=trainer_config.n_jobs,
        clip_grad=trainer_config.clip_grad,
        device=device,
        ignore_idxs=vocab.special_tokens_ids,
        distributed=distributed)
    if distributed:
        model_trainer.model.transformer_module = DistributedDataParallel(
            model_trainer.model.transformer_module,
            device_ids=[args.local_rank],
            output_device=args.local_rank)

    start_epoch = 0
    # Epoch to resume from when load_last is set; edit by hand to resume a run.
    init_epoch = 0

    if trainer_config.load_last:
        # save_func writes both ``last_checkpoint_path`` and an epoch-suffixed
        # copy. Resume from the copy of the epoch before init_epoch when there
        # is one; with init_epoch == 0 that suffix would be "-1", which is never
        # written, so fall back to the plain "last" checkpoint.
        if init_epoch > 0:
            checkpoint_path = (trainer_config.last_checkpoint_path +
                               str(init_epoch - 1))
        else:
            checkpoint_path = trainer_config.last_checkpoint_path
        state_dict = torch.load(checkpoint_path, map_location=device)
        model_trainer.load_state_dict(state_dict)
        start_epoch = init_epoch
        print('Weights loaded from {}'.format(checkpoint_path))

    # helpers -----------------------------------------------------
    def save_func(epoch):
        """Save the trainer state as the "last" checkpoint and an epoch copy.

        The epoch-suffixed copy from 100 epochs earlier is deleted, so at most
        100 consecutive per-epoch copies are kept on disk.
        """
        checkpoint_dir = os.path.dirname(trainer_config.last_checkpoint_path)
        # os.makedirs('') raises, so only create a directory when there is one.
        if checkpoint_dir:
            os.makedirs(checkpoint_dir, exist_ok=True)
        state = model_trainer.state_dict()
        torch.save(state, trainer_config.last_checkpoint_path)
        torch.save(state, trainer_config.last_checkpoint_path + str(epoch))
        stale_path = trainer_config.last_checkpoint_path + str(epoch - 100)
        if os.path.exists(stale_path):
            os.remove(stale_path)

    def sample_text_func(epoch):
        """Print source, target and model prediction for random test samples."""
        # random.sample raises ValueError when k exceeds the population size.
        n_samples = min(5, len(test_dataset))
        samples_idxs = random.sample(range(len(test_dataset)), n_samples)
        samples = [test_dataset[idx] for idx in samples_idxs]
        for source, target in samples:
            contexts = [
                torch.tensor([c],
                             dtype=torch.long,
                             device=model_trainer.device) for c in [source]
                if len(c) > 0
            ]
            prediction = model_trainer.model.predict(contexts)[0]
            source_str = vocab.ids2string(source)
            target_str = vocab.ids2string(target[1:-1])
            prediction_str = vocab.ids2string(prediction)
            print('\n')
            print('Source:{}'.format(source_str))
            print('Target:\n\t{}'.format(target_str))
            print('Prediction:\n\t{}'.format(prediction_str))

    def test_func(epoch):
        """Evaluate F1 on the test set every ``test_period`` epochs."""
        if (epoch + 1) % trainer_config.test_period == 0:
            metric_funcs = {'f1_score': f1_score}
            model_trainer.test(metric_funcs)

    def f1_risk(predictions, targets):
        """Per-sample risk: 1 - F1 between each prediction and its target."""
        scores = f1_score(predictions, targets, average=False)
        return [1 - s for s in scores]

    # helpers -----------------------------------------------------

    is_main_process = args.local_rank in [-1, 0]
    try:
        if is_main_process:
            model_trainer.train(
                start_epoch,
                trainer_config.n_epochs,
                after_epoch_funcs=[save_func, sample_text_func, test_func],
                risk_func=f1_risk)
        else:
            model_trainer.train(start_epoch, trainer_config.n_epochs)
    except (KeyboardInterrupt, Exception):
        # Only the main process writes the interrupt checkpoint, so distributed
        # workers do not race on the same file.
        if is_main_process:
            torch.save(model_trainer.state_dict(),
                       trainer_config.interrupt_checkpoint_path)
        raise
    def __init__(self, opt, shared=None):
        """Set up the agent's vocabulary and model (built fresh or shared).

        Args:
            opt: ParlAI option dict. Reads ``no_cuda``, ``gpu``,
                ``max_seq_len``, ``beam_size``, ``length_penalty``, ``sample``,
                ``annealing_topk``, ``annealing``, ``diversity_coef`` and
                ``diversity_groups``.
            shared: when given, ``shared['model']`` is reused instead of
                building a model and loading ``model_config.checkpoint_path``.

        Raises:
            AssertionError: if ``annealing_topk`` is set but not greater than
                ``beam_size``, if ``diversity_coef`` is negative, or if
                ``beam_size`` is not divisible by ``diversity_groups``.
        """
        super(TransformerAgent, self).__init__(opt, shared)

        self.use_cuda = not self.opt.get('no_cuda') and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(self.opt['gpu'])

        # Inference-only agent: globally disable autograd.
        torch.set_grad_enabled(False)

        model_config = get_model_config()
        # The vocabulary is loaded exactly once; nothing below depends on a
        # second, identical instance.
        self.vocab = BPEVocab.from_files(model_config.bpe_vocab_path, model_config.bpe_codes_path)

        self.dialog_embeddings = model_config.dialog_embeddings
        self.use_start_end = model_config.use_start_end
        self.single_input = model_config.single_input
        self.apex_level = model_config.apex_level

        if self.opt['annealing_topk'] is not None:
            assert self.opt['annealing_topk'] > self.opt['beam_size']

        assert self.opt['diversity_coef'] >= 0
        assert self.opt['beam_size'] % self.opt['diversity_groups'] == 0

        if shared is None:
            self.model = TransformerModel(n_layers=model_config.n_layers,
                                          n_embeddings=len(self.vocab),
                                          n_pos_embeddings=model_config.n_pos_embeddings,
                                          embeddings_size=model_config.embeddings_size,
                                          padding_idx=self.vocab.pad_id,
                                          n_heads=model_config.n_heads,
                                          dropout=model_config.dropout,
                                          embed_dropout=model_config.embed_dropout,
                                          attn_dropout=model_config.attn_dropout,
                                          ff_dropout=model_config.ff_dropout,
                                          bos_id=self.vocab.bos_id,
                                          eos_id=self.vocab.eos_id,
                                          sent_dialog_id=self.vocab.sent_dialog_id,
                                          max_seq_len=self.opt['max_seq_len'],
                                          beam_size=self.opt['beam_size'],
                                          length_penalty=self.opt['length_penalty'],
                                          n_segments=model_config.n_segments,
                                          sample=self.opt['sample'],
                                          annealing_topk=self.opt['annealing_topk'],
                                          annealing=self.opt['annealing'],
                                          diversity_coef=self.opt['diversity_coef'],
                                          diversity_groups=self.opt['diversity_groups'],
                                          normalize_embeddings=model_config.normalize_embeddings,
                                          multiple_choice_head=model_config.multiple_choice_head,
                                          constant_embedding=model_config.constant_embedding,
                                          vocab=self.vocab,
                                          single_input=model_config.single_input,
                                          dialog_embeddings=model_config.dialog_embeddings,
                                          share_models=model_config.share_models,
                                          successive_attention=model_config.successive_attention,
                                          sparse_embeddings=model_config.sparse_embeddings,
                                          # NOTE(review): fed from sparse_embeddings, not a
                                          # shared_attention setting -- looks like a copy-paste
                                          # slip, but kept so existing checkpoints still load.
                                          shared_attention=model_config.sparse_embeddings,
                                          bs_temperature=model_config.bs_temperature,
                                          bs_nucleus_p=model_config.bs_nucleus_p
                                          )

            # Load on CPU first; moved to GPU below if enabled.
            state_dict = torch.load(model_config.checkpoint_path, map_location=lambda storage, loc: storage)
            # Checkpoints may wrap the weights under a 'model' key.
            if 'model' in state_dict:
                state_dict = state_dict['model']

            self.model.load_state_dict(state_dict)
            print('Weights loaded from {}'.format(model_config.checkpoint_path))

            if self.use_cuda:
                self.model = self.model.cuda()

            self.model.eval()

            self.model = apex_model(self.model, apex_level=self.apex_level)

        else:
            self.model = shared['model']

        self.reset()
class TransformerAgent(Agent):
    """ParlAI agent that replies using a pretrained ``TransformerModel``.

    Per-episode history lives in ``self.history``: ``'info'`` holds the
    processed persona tokens and ``'dialog'`` a bounded deque of processed
    dialog tokens. Replies come from ``model.predict``; with
    ``--rank_candidates`` the observation's ``label_candidates`` are also
    ranked by mean token log-probability.
    """

    @staticmethod
    def add_cmdline_args(argparser):
        """Register this agent's options in an 'Agent parameters' group.

        Returns:
            The same ``argparser`` object.
        """
        def str2bool(value):
            # ``type=bool`` would call bool() on the raw string, which turns
            # "False" into True; parse the common spellings explicitly.
            if isinstance(value, bool):
                return value
            lowered = str(value).strip().lower()
            if lowered in ('true', 't', 'yes', 'y', '1'):
                return True
            if lowered in ('false', 'f', 'no', 'n', '0'):
                return False
            raise ValueError('expected a boolean value, got {!r}'.format(value))

        agent_args = argparser.add_argument_group('Agent parameters')
        agent_args.add_argument('-gpu', '--gpu', type=int, default=-1,
                                help='which GPU to use')
        agent_args.add_argument('--no-cuda', type=str2bool, default=False,
                                help='disable GPUs even if available. otherwise, will use GPUs if '
                                     'available on the device.')
        agent_args.add_argument('--rank_candidates', type=str2bool, default=False,
                                help='Whether the model should parse candidates for ranking.')
        agent_args.add_argument('--sample', type=str2bool, default=False,
                                help='Sampling of beam from beam search')
        agent_args.add_argument('--max_seq_len', type=int, default=128,
                                help='')
        agent_args.add_argument('--beam_size', type=int, default=1,
                                help='')
        agent_args.add_argument('--diversity_coef', type=float, default=0,
                                help='')
        agent_args.add_argument('--diversity_groups', type=int, default=1,
                                help='')
        # A top-k count, so it must be an integer.
        agent_args.add_argument('--annealing_topk', type=int, default=None,
                                help='')
        agent_args.add_argument('--annealing', type=float, default=0.0,
                                help='')
        agent_args.add_argument('--length_penalty', type=float, default=0.6,
                                help='')

        return argparser

    def __init__(self, opt, shared=None):
        """Set up the agent's vocabulary and model (built fresh or shared).

        Args:
            opt: ParlAI option dict; reads the options registered in
                ``add_cmdline_args``.
            shared: when given, ``shared['model']`` is reused instead of
                building a model and loading ``model_config.checkpoint_path``.

        Raises:
            AssertionError: if ``annealing_topk`` is set but not greater than
                ``beam_size``, if ``diversity_coef`` is negative, or if
                ``beam_size`` is not divisible by ``diversity_groups``.
        """
        super(TransformerAgent, self).__init__(opt, shared)

        self.use_cuda = not self.opt.get('no_cuda') and torch.cuda.is_available()
        if self.use_cuda:
            torch.cuda.set_device(self.opt['gpu'])

        # Inference-only agent: globally disable autograd.
        torch.set_grad_enabled(False)

        model_config = get_model_config()
        self.vocab = BPEVocab.from_files(model_config.bpe_vocab_path, model_config.bpe_codes_path)

        self.dialog_embeddings = model_config.dialog_embeddings
        self.use_start_end = model_config.use_start_end
        self.single_input = model_config.single_input
        self.apex_level = model_config.apex_level

        if self.opt['annealing_topk'] is not None:
            assert self.opt['annealing_topk'] > self.opt['beam_size']

        assert self.opt['diversity_coef'] >= 0
        assert self.opt['beam_size'] % self.opt['diversity_groups'] == 0

        if shared is None:
            self.model = TransformerModel(n_layers=model_config.n_layers,
                                          n_embeddings=len(self.vocab),
                                          n_pos_embeddings=model_config.n_pos_embeddings,
                                          embeddings_size=model_config.embeddings_size,
                                          padding_idx=self.vocab.pad_id,
                                          n_heads=model_config.n_heads,
                                          dropout=model_config.dropout,
                                          embed_dropout=model_config.embed_dropout,
                                          attn_dropout=model_config.attn_dropout,
                                          ff_dropout=model_config.ff_dropout,
                                          bos_id=self.vocab.bos_id,
                                          eos_id=self.vocab.eos_id,
                                          sent_dialog_id=self.vocab.sent_dialog_id,
                                          max_seq_len=self.opt['max_seq_len'],
                                          beam_size=self.opt['beam_size'],
                                          length_penalty=self.opt['length_penalty'],
                                          n_segments=model_config.n_segments,
                                          sample=self.opt['sample'],
                                          annealing_topk=self.opt['annealing_topk'],
                                          annealing=self.opt['annealing'],
                                          diversity_coef=self.opt['diversity_coef'],
                                          diversity_groups=self.opt['diversity_groups'],
                                          normalize_embeddings=model_config.normalize_embeddings,
                                          multiple_choice_head=model_config.multiple_choice_head,
                                          constant_embedding=model_config.constant_embedding,
                                          vocab=self.vocab,
                                          single_input=model_config.single_input,
                                          dialog_embeddings=model_config.dialog_embeddings,
                                          share_models=model_config.share_models,
                                          successive_attention=model_config.successive_attention,
                                          sparse_embeddings=model_config.sparse_embeddings,
                                          # NOTE(review): fed from sparse_embeddings, not a
                                          # shared_attention setting -- looks like a copy-paste
                                          # slip, but kept so existing checkpoints still load.
                                          shared_attention=model_config.sparse_embeddings,
                                          bs_temperature=model_config.bs_temperature,
                                          bs_nucleus_p=model_config.bs_nucleus_p
                                          )

            # Load on CPU first; moved to GPU below if enabled.
            state_dict = torch.load(model_config.checkpoint_path, map_location=lambda storage, loc: storage)
            # Checkpoints may wrap the weights under a 'model' key.
            if 'model' in state_dict:
                state_dict = state_dict['model']

            self.model.load_state_dict(state_dict)
            print('Weights loaded from {}'.format(model_config.checkpoint_path))

            if self.use_cuda:
                self.model = self.model.cuda()

            self.model.eval()

            self.model = apex_model(self.model, apex_level=self.apex_level)

        else:
            self.model = shared['model']

        self.reset()

    def _parse(self, text):
        """Split observation text into persona lines and dialog lines.

        Lines starting with 'your persona:' become persona entries (prefix
        removed); every other line, stripped, is a dialog utterance.
        """
        persona_info = []
        dialog = []

        for subtext in text.split('\n'):
            subtext = subtext.strip()

            if subtext.startswith('your persona:'):
                subtext = subtext.replace('your persona:', '').strip()
                persona_info.append(subtext)
            else:
                dialog.append(subtext)

        return persona_info, dialog

    def _process_info(self, info):
        """Truncate persona ids to fit the model, then add markers/embeddings."""
        info = self._add_start_end(info[:self.model.n_pos_embeddings - (2 if self.use_start_end else 0)],
                                   self.vocab.info_bos_id,
                                   self.vocab.info_eos_id)
        info = self._add_dialog_embeddings(info, self.vocab.info_dialog_id)

        return info

    def _process_1st_replica(self, replica):
        """Wrap an utterance with talker-1 markers and dialog embeddings."""
        replica = self._add_start_end(replica, self.vocab.talker1_bos_id, self.vocab.talker1_eos_id)
        replica = self._add_dialog_embeddings(replica, self.vocab.talker1_dialog_id)
        return replica

    def _process_2nd_replica(self, replica):
        """Wrap an utterance with talker-2 markers and dialog embeddings."""
        replica = self._add_start_end(replica, self.vocab.talker2_bos_id, self.vocab.talker2_eos_id)
        replica = self._add_dialog_embeddings(replica, self.vocab.talker2_dialog_id)
        return replica

    def _add_dialog_embeddings(self, toks, dialog_tok):
        """Pair each token with *dialog_tok* when dialog embeddings are enabled."""
        if self.dialog_embeddings:
            toks = [[t, dialog_tok] for t in toks]
        return toks

    def _add_start_end(self, toks, start, end):
        """Surround *toks* with *start*/*end* when use_start_end is enabled."""
        if self.use_start_end:
            toks = [start] + toks + [end]
        return toks

    def observe(self, observation):
        """Fold an observation's persona and dialog text into the history.

        Dialog lines alternate talker 1 / talker 2, starting with talker 1.
        The observation is tagged with ``observation['agent'] = self`` so that
        ``batch_act`` can reach each observation's history.
        """
        if self.episode_done:
            self.reset()

        if 'text' in observation:
            text = observation['text']
            info, dialog = self._parse(text)

            info = sum([self.vocab.string2ids(i) for i in info], [])
            if info:
                prev_info = [h[0] for h in self.history['info']] if self.dialog_embeddings else self.history['info']
                # Remove the old start/end markers before re-wrapping; they
                # exist only when use_start_end is on, otherwise slicing would
                # drop real persona tokens.
                if self.use_start_end:
                    prev_info = prev_info[1:-1]
                self.history['info'] = self._process_info(prev_info + info)

            for i, replica in enumerate(dialog, 1):
                replica = self.vocab.string2ids(replica)
                replica = self._process_1st_replica(replica) if i % 2 == 1 else self._process_2nd_replica(replica)
                self.history['dialog'].extend(replica)

        observation['agent'] = self

        self.episode_done = observation['episode_done']
        self.observation = observation

        return observation

    def act(self):
        """Reply to the last observation (a batch of one)."""
        return self.batch_act([self.observation])[0]

    def batch_act(self, observations):
        """Generate replies (and optionally rank candidates) for a batch.

        Observations whose agent has an empty dialog history get an empty
        reply. Failures during generation/ranking are reported with a
        traceback and the partially filled replies are returned.
        """
        import traceback

        def is_valid_history(history):
            return len(history['dialog'])

        def to_tensor(string):
            ids = [self.vocab.bos_id] + self.vocab.string2ids(string) + [self.vocab.eos_id]
            ids = self._add_dialog_embeddings(ids, self.vocab.sent_dialog_id)
            return torch.tensor(ids, dtype=torch.long)

        def to_cuda(data):
            if not self.use_cuda:
                return data

            if isinstance(data, (list, tuple, map)):
                return list(map(lambda x: x.cuda(), data))

            return data.cuda()

        batch_reply = [{'id': self.getID(), 'text': '', 'text_candidates': []} for _ in range(len(observations))]
        valid_ids = [i for i, obs in enumerate(observations) if is_valid_history(obs['agent'].history)]
        batch_size = len(valid_ids)

        if batch_size == 0:
            return batch_reply

        try:
            valid_observations = [observations[i] for i in valid_ids]

            infos = [obs['agent'].history['info'] for obs in valid_observations]
            dialogs = [list(obs['agent'].history['dialog'])[-self.model.n_pos_embeddings+1:] for obs in valid_observations]
            contexts = []

            if max(map(len, infos)) > 0:
                infos = [torch.tensor(i, dtype=torch.long) for i in infos]
                contexts.append(infos)

            if max(map(len, dialogs)) > 0:
                dialogs = [torch.tensor(d, dtype=torch.long) for d in dialogs]
                contexts.append(dialogs)

            if self.single_input:
                contexts = [torch.cat(c, dim=0) for c in zip(*contexts)]
                raw_context = contexts if self.opt['rank_candidates'] else None
                contexts = pad_sequence(contexts, batch_first=True, padding_value=self.model.padding_idx, left=True)
            else:
                # A list rather than a lazy map: the padded contexts are
                # iterated again below when encoding them for ranking, and a
                # map would already be exhausted by predict() on CPU.
                contexts = [pad_sequence(x, batch_first=True, padding_value=self.model.padding_idx)
                            for x in contexts]

            contexts = to_cuda(contexts)

            pred_texts = self.model.predict(contexts)

            for i in range(batch_size):
                pred_toks = self._process_2nd_replica(pred_texts[i])
                valid_observations[i]['agent'].history['dialog'].extend(pred_toks)
                batch_reply[valid_ids[i]]['text'] = self.vocab.ids2string(pred_texts[i])
                batch_reply[valid_ids[i]]['episode_done'] = valid_observations[i]['agent'].episode_done

            if self.opt['rank_candidates']:
                enc_contexts = [self.model.encode(c) for c in contexts] if not self.single_input else []

                candidates = [list(obs.get('label_candidates', [])) for obs in valid_observations]
                lens_candidates = [len(c) for c in candidates]

                if max(lens_candidates) > 0:
                    # Pad candidate lists to equal length; scores for padded
                    # slots are not recorded (see the i < lens_candidates check).
                    candidates = [c + ['' for _ in range(max(lens_candidates) - len(c))] for c in candidates]
                    scores = [[] for _ in range(len(candidates))]

                    for i in range(max(lens_candidates)):
                        current_cands = [to_tensor(c[i])[:self.model.n_pos_embeddings-1] for c in candidates]

                        lens = None
                        if self.single_input:
                            # Candidate lengths before the context is prepended.
                            lens = [x.size(0) for x in current_cands]
                            current_cands = [torch.cat(c, dim=0)[-self.model.n_pos_embeddings:]
                                             for c in zip(raw_context, current_cands)]

                        current_cands = to_cuda(current_cands)
                        current_cands = pad_sequence(current_cands, batch_first=True,
                                                     padding_value=self.model.padding_idx)

                        logits = self.model.decode(current_cands[:, :-1], enc_contexts)

                        # With dialog embeddings, keep only the token-id channel.
                        if current_cands.dim() == 3:
                            current_cands = current_cands[:, :, 0]

                        log_probas = F.log_softmax(logits, dim=-1)
                        log_probas = torch.gather(log_probas, -1, current_cands[:, 1:].unsqueeze(-1)).squeeze(-1)

                        if self.single_input:
                            # zero context
                            # NOTE(review): assumes each candidate ends at the
                            # last column; rows shorter than the batch maximum
                            # are right-padded here, which breaks that -- verify.
                            for j, l in enumerate(lens):
                                current_cands[j, :-l+1] = self.model.padding_idx

                        log_probas.masked_fill_(current_cands[:, 1:].eq(self.model.padding_idx), 0)

                        # Mean log-probability over non-padding target tokens.
                        current_lens = current_cands[:, 1:].ne(self.model.padding_idx).float().sum(dim=-1)
                        current_scores = log_probas.sum(dim=-1) / current_lens

                        for k, s in enumerate(current_scores):
                            if i < lens_candidates[k]:
                                scores[k].append(s.item())

                    ranked_ids = [sorted(range(len(s)), key=lambda k: s[k], reverse=True) for s in scores]
                    ranked_strings = [[c[i] for i in ids] for ids, c in zip(ranked_ids, candidates)]

                    for i in range(batch_size):
                        batch_reply[valid_ids[i]]['text_candidates'] = ranked_strings[i]

        except Exception:
            # Best effort: report the full failure but return the replies
            # filled so far instead of propagating the error.
            traceback.print_exc()

        return batch_reply

    def share(self):
        """Return shared state so copies of this agent reuse the same model."""
        shared = super(TransformerAgent, self).share()
        shared['opt'] = self.opt
        shared['model'] = self.model

        return shared

    def reset(self):
        """Clear the history; the dialog deque keeps at most n_pos_embeddings-1 items."""
        self.history = {'info': [], 'dialog': deque(maxlen=self.model.n_pos_embeddings-1)}
        self.episode_done = True
        self.observation = None
# Example #9
# 0
def get_trainer():
    """Assemble a ``Trainer`` for the transformer model from the project configs.

    Steps:
      1. Read the model and trainer configs, call ``set_seed`` with the
         configured seed, and build the torch device.
      2. Build the BPE vocabulary and a ``TransformerModel`` sized from it.
      3. If ``trainer_config.load_last`` is false, load OpenAI weights into
         the transformer module.
      4. Build train/test ``FacebookDataset``s, passing the model's
         ``n_pos_embeddings - 1`` as the third argument, and wrap everything
         in a ``Trainer``.
      5. If ``trainer_config.load_last`` is true, restore the trainer state
         from ``trainer_config.last_checkpoint_path``.

    Returns:
        The configured ``Trainer`` instance.
    """
    mcfg = get_model_config()
    tcfg = get_trainer_config()

    set_seed(tcfg.seed)
    target_device = torch.device(tcfg.device)

    bpe_vocab = BPEVocab.from_files(mcfg.bpe_vocab_path, mcfg.bpe_codes_path)

    model = TransformerModel(
        n_layers=mcfg.n_layers,
        n_embeddings=len(bpe_vocab),
        n_pos_embeddings=mcfg.n_pos_embeddings,
        embeddings_size=mcfg.embeddings_size,
        padding_idx=bpe_vocab.pad_id,
        n_heads=mcfg.n_heads,
        dropout=mcfg.dropout,
        embed_dropout=mcfg.embed_dropout,
        attn_dropout=mcfg.attn_dropout,
        ff_dropout=mcfg.ff_dropout,
        bos_id=bpe_vocab.bos_id,
        eos_id=bpe_vocab.eos_id,
        max_seq_len=mcfg.max_seq_len,
        beam_size=mcfg.beam_size,
        length_penalty=mcfg.length_penalty,
        n_segments=mcfg.n_segments,
        annealing_topk=mcfg.annealing_topk,
        annealing=mcfg.annealing,
        diversity_coef=mcfg.diversity_coef,
        diversity_groups=mcfg.diversity_groups,
    )

    # Fresh run: start from the pretrained OpenAI weights. This must happen
    # before the datasets/trainer are built, matching the original ordering.
    if not tcfg.load_last:
        openai_dir = tcfg.openai_parameters_dir
        load_openai_weights(model.transformer_module,
                            openai_dir,
                            n_special_tokens=bpe_vocab.n_special_tokens)
        print('OpenAI weights loaded from {}'.format(openai_dir))

    # Both datasets receive the same limit derived from the model's
    # positional-embedding count.
    dataset_limit = model.n_pos_embeddings - 1
    train_data = FacebookDataset(tcfg.train_datasets, bpe_vocab, dataset_limit)
    test_data = FacebookDataset(tcfg.test_datasets, bpe_vocab, dataset_limit)

    trainer = Trainer(model,
                      train_data,
                      test_data,
                      batch_size=tcfg.batch_size,
                      batch_split=tcfg.batch_split,
                      lr=tcfg.lr,
                      lr_warmup=tcfg.lr_warmup,
                      lm_weight=tcfg.lm_weight,
                      risk_weight=tcfg.risk_weight,
                      n_jobs=tcfg.n_jobs,
                      clip_grad=tcfg.clip_grad,
                      device=target_device,
                      ignore_idxs=bpe_vocab.special_tokens_ids)

    # Resumed run: restore the full trainer state from the last checkpoint,
    # mapping tensors onto the configured device.
    if tcfg.load_last:
        checkpoint_path = tcfg.last_checkpoint_path
        checkpoint = torch.load(checkpoint_path, map_location=target_device)
        trainer.load_state_dict(checkpoint)
        print('Weights loaded from {}'.format(checkpoint_path))

    return trainer