Ejemplo n.º 1
0
    def __init__(self, model, discriminator, loss, optimizer, translator,
                 logger, opt, training_data, validation_data, src_vocab,
                 tgt_vocab):
        """Store the training components and derive per-run settings.

        Args:
            model: the network being trained.
            discriminator: reward model(s); stored as-is.
            loss: supervised loss object.
            optimizer: optimizer wrapper.
            translator: decoder/evaluator used for validation output.
            logger: logger for progress messages.
            opt: options namespace; ``answer``, ``defined_slf_attn_mask`` are
                read here and ``opt`` is also passed to ``NLLLoss``.
            training_data, validation_data: batch providers.
            src_vocab: source vocabulary, only consulted to look up the
                separator id when ``opt.answer == 'sep'``.
            tgt_vocab: target vocabulary (stored).
        """
        self.model, self.discriminator = model, discriminator
        self.loss = loss
        # Second loss object built with do_reduce=False, presumably so that
        # per-token losses can be re-weighted later -- confirm in NLLLoss.
        self.rl_loss = NLLLoss(opt, do_reduce=False)
        self.optimizer, self.translator = optimizer, translator
        self.logger, self.opt = logger, opt
        self.tgt_vocab = tgt_vocab

        self.training_data, self.validation_data = training_data, validation_data

        # opt.answer selects how the answer is fed to the model:
        # 'sep' -> separator token in the source, 'enc' -> separate encoder input.
        answer_mode = opt.answer
        self.separate = answer_mode == 'sep'
        self.answer = answer_mode == 'enc'
        if self.separate:
            self.sep_id = src_vocab.lookup(Constants.SEP_WORD)
        else:
            self.sep_id = Constants.SEP

        self.is_attn_mask = bool(opt.defined_slf_attn_mask)

        # Step counter and best-so-far metrics; best_ppl starts at a huge value
        # so that the first evaluation always counts as an improvement.
        self.cntBatch = 0
        self.best_metric = 0
        self.best_ppl = math.exp(100)
        self.best_bleu = 0
Ejemplo n.º 2
0
def main(opt, logger):
    """Train a dialogue model from the options in ``opt``.

    Seeds the RNGs, loads the preprocessed data file ``opt.data``, wraps the
    train/valid splits in ``DialogueDataset`` objects, builds the model,
    optimizer and PAD-masked NLL loss, creates forward and backward
    ``DialogueTranslator`` instances and runs
    ``DialogueSupervisedTrainer.train``.

    Args:
        opt: parsed options. Derived fields are written back onto it:
            ``src_vocab_size``, ``tgt_vocab_size`` and, when
            ``opt.pre_trained_vocab`` is set, ``pre_trained_tgt_emb`` (plus
            ``pre_trained_src_emb`` when ``opt.pretrained`` is falsy).
        logger: logger receiving progress messages.

    Raises:
        ValueError: if the data file names a pretrained source vocabulary that
            differs from ``opt.pretrained``.
    """
    logger.info('My PID is {0}'.format(os.getpid()))
    logger.info('PyTorch version: {0}'.format(str(torch.__version__)))
    logger.info(opt)

    if torch.cuda.is_available() and not opt.gpus:
        logger.info("WARNING: You have a CUDA device, so you should probably run with -gpus 0")
    if opt.seed > 0:
        torch.manual_seed(opt.seed)
    if opt.gpus and opt.cuda_seed > 0:
        torch.cuda.manual_seed(opt.cuda_seed)
    logger.info('My seed is {0}'.format(torch.initial_seed()))
    # torch.cuda.initial_seed() initialises CUDA and raises when no usable GPU
    # is present; querying it unconditionally made CPU-only runs crash here.
    if torch.cuda.is_available():
        logger.info('My cuda seed is {0}'.format(torch.cuda.initial_seed()))

    ###### ==================== Loading Dataset ==================== ######
    data = torch.load(opt.data)
    vocabularies = data['dict']
    if isinstance(vocabularies['src'], str):
        # The data file stores only the name of a pretrained vocabulary;
        # rebuild the Vocab object from it.
        if vocabularies['src'] != opt.pretrained:
            raise ValueError(
                'data was preprocessed with pretrained vocabulary {0!r}, '
                'but opt.pretrained is {1!r}'.format(vocabularies['src'],
                                                     opt.pretrained))
        options = {'transf': True, 'separate': False, 'tgt': False}
        vocabularies['src'] = Vocab.from_opt(pretrained=opt.pretrained, opt=options)
    train_data, valid_data = data['train'], data['valid']

    ### ===== load pre-trained vocabulary ===== ###
    if opt.pre_trained_vocab:
        if not opt.pretrained:
            opt.pre_trained_src_emb = vocabularies['pre-trained']['src']
        opt.pre_trained_tgt_emb = vocabularies['pre-trained']['tgt']

    ### ===== wrap datasets ===== ###
    attn_mask_file = '' if not opt.defined_slf_attn_mask else opt.defined_slf_attn_mask + '.train.npy'
    # GPT-2 style source vocabularies are padded with '<|endoftext|>' here.
    # `or ''` keeps this working when opt.pretrained is None.
    is_gpt2 = 'gpt2' in (opt.pretrained or '')
    pad_id = vocabularies['src'].lookup('<|endoftext|>') if is_gpt2 else Constants.PAD
    trainData = DialogueDataset(train_data, opt.batch_size, copy=opt.copy,
                                attn_mask_file=attn_mask_file,
                                opt_cuda=opt.gpus, pad=pad_id)
    validData = DialogueDataset(valid_data, opt.eval_batch_size, copy=opt.copy,
                                attn_mask_file=attn_mask_file,
                                opt_cuda=opt.gpus, pad=pad_id)

    opt.src_vocab_size, opt.tgt_vocab_size = vocabularies['src'].size, vocabularies['tgt'].size

    logger.info(' * vocabulary size. source = %d; target = %d' % (opt.src_vocab_size, opt.tgt_vocab_size))
    logger.info(' * number of training batches. %d' % len(trainData))
    logger.info(' * maximum batch size. %d' % opt.batch_size)

    ##### =================== Prepare Model =================== #####
    separate = -1  # no separator token is used by the dialogue model
    device = torch.device('cuda:' + str(opt.gpus[0]) if len(opt.gpus) else 'cpu')
    checkpoint = torch.load(opt.checkpoint) if opt.checkpoint else None
    model, parameters_cnt = build_dialogue_model(opt, device, separate=separate, checkpoint=checkpoint)
    logger.info(' * Number of parameters to learn = %d' % parameters_cnt)

    ##### ==================== Prepare Optimizer ==================== #####
    optimizer = Optimizer.from_opt(model, opt)

    ##### ==================== Prepare Loss ==================== #####
    weight = torch.ones(opt.tgt_vocab_size)
    weight[Constants.PAD] = 0  # PAD targets contribute nothing to the loss
    loss = NLLLoss(opt, weight=weight, size_average=False)
    if opt.gpus:
        # The default CUDA device is selected here, after the model is built.
        cuda.set_device(opt.gpus[0])
        loss.cuda()

    ##### ==================== Prepare Translator ==================== #####
    forward_translator = DialogueTranslator(opt, vocabularies['tgt'], data['valid']['tokens'], vocabularies['src'])
    backward_translator = DialogueTranslator(opt, vocabularies['src'], data['valid']['tokens'], vocabularies['tgt'], reverse=True)

    ##### ==================== Training ==================== #####
    trainer = DialogueSupervisedTrainer(model, loss, optimizer,
                                        forward_translator, backward_translator,
                                        logger, opt, trainData, validData)
    trainer.train(device)
Ejemplo n.º 3
0
def main(opt, logger):
    """Train the model built by ``build_model``, optionally with RL rewards.

    Seeds the RNGs, loads ``opt.data``, wraps the train/valid splits in
    ``Dataset`` objects, builds the model, optimizer, PAD-masked NLL loss and
    ``Translator``. Training then runs with ``RLTrainer`` when ``opt.rl`` is
    set (after loading the reward models via ``load_rl_model``), otherwise
    with ``SupervisedTrainer``.

    Args:
        opt: parsed options. Derived fields are written back onto it
            (``src_vocab_size``, ``tgt_vocab_size``, ``feat_vocab``,
            ``ans_feat_vocab``, ``rl_device`` and the ``pre_trained_*_emb``
            entries when ``opt.pre_trained_vocab`` is set).
        logger: logger receiving progress messages.

    Raises:
        ValueError: if the data file names a pretrained source vocabulary that
            differs from ``opt.pretrained``.
    """
    logger.info('My PID is {0}'.format(os.getpid()))
    logger.info('PyTorch version: {0}'.format(str(torch.__version__)))
    logger.info(opt)

    if torch.cuda.is_available() and not opt.gpus:
        logger.info("WARNING: You have a CUDA device, so you should probably run with -gpus 0")
    if opt.seed > 0:
        torch.manual_seed(opt.seed)
    if opt.gpus and opt.cuda_seed > 0:
        torch.cuda.manual_seed(opt.cuda_seed)
    logger.info('My seed is {0}'.format(torch.initial_seed()))
    # torch.cuda.initial_seed() initialises CUDA and raises when no usable GPU
    # is present; querying it unconditionally made CPU-only runs crash here.
    if torch.cuda.is_available():
        logger.info('My cuda seed is {0}'.format(torch.cuda.initial_seed()))

    ###### ==================== Loading Dataset ==================== ######
    data = torch.load(opt.data)
    vocabularies = data['dict']
    if isinstance(vocabularies['src'], str):
        # The data file stores only the name of a pretrained vocabulary;
        # rebuild the Vocab object from it.
        if vocabularies['src'] != opt.pretrained:
            raise ValueError(
                'data was preprocessed with pretrained vocabulary {0!r}, '
                'but opt.pretrained is {1!r}'.format(vocabularies['src'],
                                                     opt.pretrained))
        options = {'transf': opt.answer != 'enc',
                   'separate': opt.answer == 'sep',
                   'tgt': False}
        vocabularies['src'] = Vocab.from_opt(pretrained=opt.pretrained, opt=options)
    train_data, valid_data = data['train'], data['valid']

    ### ===== load pre-trained vocabulary ===== ###
    if opt.pre_trained_vocab:
        if not opt.pretrained:
            opt.pre_trained_src_emb = vocabularies['pre-trained']['src']
        opt.pre_trained_tgt_emb = vocabularies['pre-trained']['tgt']
        if opt.answer == 'enc':
            opt.pre_trained_ans_emb = vocabularies['pre-trained']['ans']

    ### ===== wrap datasets ===== ###
    attn_mask_file = '' if not opt.defined_slf_attn_mask else opt.defined_slf_attn_mask + '.train.npy'
    # GPT-2 style source vocabularies are padded with '<|endoftext|>' here.
    # `or ''` keeps this working when opt.pretrained is None.
    is_gpt2 = 'gpt2' in (opt.pretrained or '')
    pad_id = vocabularies['src'].lookup('<|endoftext|>') if is_gpt2 else Constants.PAD
    answer_enc = opt.answer == 'enc'
    trainData = Dataset(train_data, opt.batch_size, copy=opt.copy,
                        answer=answer_enc, ans_feature=opt.ans_feature,
                        feature=opt.feature, attn_mask_file=attn_mask_file,
                        opt_cuda=opt.gpus, pad=pad_id)
    validData = Dataset(valid_data, opt.eval_batch_size, copy=opt.copy,
                        answer=answer_enc, ans_feature=opt.ans_feature,
                        feature=opt.feature, attn_mask_file=attn_mask_file,
                        opt_cuda=opt.gpus, pad=pad_id)

    opt.src_vocab_size = vocabularies['src'].size
    opt.tgt_vocab_size = vocabularies['tgt'].size
    opt.feat_vocab = [fv.size for fv in vocabularies['feature']] if opt.feature else None
    opt.ans_feat_vocab = [fv.size for fv in vocabularies['ans_feature']] if opt.ans_feature else None

    logger.info(' * vocabulary size. source = %d; target = %d' % (opt.src_vocab_size, opt.tgt_vocab_size))
    logger.info(' * number of training batches. %d' % len(trainData))
    logger.info(' * maximum batch size. %d' % opt.batch_size)

    ##### =================== Prepare Model =================== #####
    # Separator token id when the answer is inlined with a separator, else -1.
    separate = vocabularies['src'].lookup(Constants.SEP_WORD) if opt.answer == 'sep' else -1
    device = torch.device('cuda:' + str(opt.gpus[0]) if len(opt.gpus) else 'cpu')
    checkpoint = torch.load(opt.checkpoint) if opt.checkpoint else None
    discriminator = None
    if opt.rl:
        # opt.rl_device maps each reward name in opt.rl to the GPU at the
        # same position in opt.rl_gpu.
        rl_devices = [torch.device('cuda:' + str(gpu)) for gpu in opt.rl_gpu]
        opt.rl_device = dict(zip(opt.rl, rl_devices))
        discriminator = load_rl_model(opt, device, opt.rl_device)
    model, parameters_cnt = build_model(opt, device, separate=separate, checkpoint=checkpoint)
    logger.info(' * Number of parameters to learn = %d' % parameters_cnt)

    ##### ==================== Prepare Optimizer ==================== #####
    optimizer = Optimizer.from_opt(model, opt)

    ##### ==================== Prepare Loss ==================== #####
    weight = torch.ones(opt.tgt_vocab_size)
    weight[Constants.PAD] = 0  # PAD targets contribute nothing to the loss
    loss = NLLLoss(opt, weight=weight, size_average=False)
    if opt.gpus:
        cuda.set_device(opt.gpus[0])
        loss.cuda()

    ##### ==================== Prepare Translator ==================== #####
    translator = Translator(opt, vocabularies['tgt'], data['valid']['tokens'], vocabularies['src'])

    ##### ==================== Training ==================== #####
    if opt.rl:
        trainer = RLTrainer(model, discriminator, loss, optimizer, translator, logger,
                            opt, trainData, validData, vocabularies['src'], vocabularies['tgt'])
    else:
        trainer = SupervisedTrainer(model, loss, optimizer, translator, logger,
                                    opt, trainData, validData, vocabularies['src'])
    trainer.train(device)
Ejemplo n.º 4
0
def main(opt):
    """Train a model on combined sequential and graph-structured data.

    Seeds the RNGs, optionally loads ``opt.checkpoint``, reads vocabularies
    from ``opt.sequence_data`` and ``opt.graph_data``, loads the pre-built
    ``Dataset`` objects from ``opt.train_dataset`` / ``opt.valid_dataset``,
    then builds the model, optimizer, PAD-masked NLL loss and ``Translator``
    and runs ``SupervisedTrainer.train``. Progress goes to the root
    ``logging`` logger.

    Args:
        opt: parsed options. Derived fields are written back onto it
            (``sparse``, ``src_vocab_size``, ``tgt_vocab_size``,
            ``feat_vocab``, ``edge_vocab_size``, ``node_feat_vocab`` and the
            ``pre_trained_*_emb`` entries when ``opt.pre_trained_vocab`` is
            set).
    """
    logging.info('My PID is {0}'.format(os.getpid()))
    logging.info('PyTorch version: {0}'.format(str(torch.__version__)))
    logging.info(opt)

    if torch.cuda.is_available() and not opt.gpus:
        logging.info(
            "WARNING: You have a CUDA device, so you should probably run with -gpus 0"
        )
    if opt.seed > 0:
        torch.manual_seed(opt.seed)
    if opt.gpus:
        if opt.cuda_seed > 0:
            torch.cuda.manual_seed(opt.cuda_seed)
        cuda.set_device(opt.gpus[0])
    logging.info('My seed is {0}'.format(torch.initial_seed()))
    # torch.cuda.initial_seed() initialises CUDA and raises when no usable GPU
    # is present; querying it unconditionally made CPU-only runs crash here.
    if torch.cuda.is_available():
        logging.info('My cuda seed is {0}'.format(torch.cuda.initial_seed()))

    ###### ==================== Loading Options ==================== ######
    # Bind `checkpoint` on every path instead of leaving it undefined when
    # no checkpoint is requested.
    checkpoint = torch.load(opt.checkpoint) if opt.checkpoint else None

    ###### ==================== Loading Dataset ==================== ######
    opt.sparse = bool(opt.sparse)

    ### ===== load pre-trained vocabulary ===== ###
    logging.info('Loading sequential data ......')
    sequences = torch.load(opt.sequence_data)
    seq_vocabularies = sequences['dict']
    logging.info('Loading pre-trained vocabulary ......')
    if opt.pre_trained_vocab:
        if not opt.pretrained:
            opt.pre_trained_src_emb = seq_vocabularies['pre-trained']['src']
        opt.pre_trained_tgt_emb = seq_vocabularies['pre-trained']['tgt']
        if opt.answer:
            # NOTE(review): the answer embedding reuses the *source*
            # pre-trained embedding here -- confirm this is intended.
            opt.pre_trained_ans_emb = seq_vocabularies['pre-trained']['src']

    ### ===== wrap datasets ===== ###
    logging.info('Loading Dataset objects ......')
    trainData = torch.load(opt.train_dataset)
    validData = torch.load(opt.valid_dataset)
    # Both splits use opt.batch_size; recompute batch counts to match.
    trainData.batchSize = validData.batchSize = opt.batch_size
    trainData.numBatches = math.ceil(len(trainData.src) / trainData.batchSize)
    validData.numBatches = math.ceil(len(validData.src) / validData.batchSize)

    logging.info('Preparing vocabularies ......')
    opt.src_vocab_size = seq_vocabularies['src'].size
    opt.tgt_vocab_size = seq_vocabularies['tgt'].size
    opt.feat_vocab = [fv.size for fv in seq_vocabularies['feature']
                      ] if opt.feature else None

    logging.info('Loading structural data ......')
    graphs = torch.load(opt.graph_data)
    graph_vocabularies = graphs['dict']
    del graphs  # only the vocabularies are needed from the graph file

    opt.edge_vocab_size = graph_vocabularies['edge']['in'].size
    opt.node_feat_vocab = [
        fv.size for fv in graph_vocabularies['feature'][1:-1]
    ] if opt.node_feature else None

    logging.info(' * vocabulary size. source = %d; target = %d' %
                 (opt.src_vocab_size, opt.tgt_vocab_size))
    logging.info(' * number of training batches. %d' % len(trainData))
    logging.info(' * maximum batch size. %d' % opt.batch_size)

    ##### =================== Prepare Model =================== #####
    device = torch.device('cuda' if opt.gpus else 'cpu')
    trainData.device = validData.device = device

    model, parameters_cnt = build_model(opt, device, checkpoint=checkpoint)
    del checkpoint

    logging.info(' * Number of parameters to learn = %d' % parameters_cnt)

    ##### ==================== Prepare Optimizer ==================== #####
    optimizer = Optimizer.from_opt(model, opt)

    ##### ==================== Prepare Loss ==================== #####
    weight = torch.ones(opt.tgt_vocab_size)
    weight[Constants.PAD] = 0  # PAD targets contribute nothing to the loss
    loss = NLLLoss(opt, weight, size_average=False)
    if opt.gpus:
        loss.cuda()

    ##### ==================== Prepare Translator ==================== #####
    translator = Translator(opt, seq_vocabularies['tgt'],
                            sequences['valid']['tokens'],
                            seq_vocabularies['src'])

    ##### ==================== Training ==================== #####
    trainer = SupervisedTrainer(model, loss, optimizer, translator, opt,
                                trainData, validData, seq_vocabularies['src'],
                                graph_vocabularies['feature'])
    # Drop this function's references (the trainer keeps its own), presumably
    # to let memory be reclaimed during the long training run.
    del model
    del trainData
    del validData
    del seq_vocabularies['src']
    del graph_vocabularies['feature']
    trainer.train(device)
Ejemplo n.º 5
0
class RLTrainer(object):
    def __init__(self, model, discriminator, loss, optimizer, translator,
                 logger, opt, training_data, validation_data, src_vocab,
                 tgt_vocab):
        """Set up an RL trainer.

        Args:
            model: generator being trained.
            discriminator: mapping from reward name ('fluency', 'relevance',
                'answerability') to the corresponding reward model.
            loss: supervised loss object (used by ``cal_performance``).
            optimizer: optimizer wrapper exposing ``backward``/``step``.
            translator: evaluator used to compute BLEU on validation data.
            logger: logger for progress messages.
            opt: options namespace.
            training_data, validation_data: indexable batch providers.
            src_vocab: source vocabulary; only used to look up the separator
                id when ``opt.answer == 'sep'``.
            tgt_vocab: target vocabulary, passed to ``preprocess_rl_batch``.
        """
        self.model, self.discriminator = model, discriminator
        self.loss = loss
        # Unreduced NLL: cal_rl_loss reshapes its output per sample and
        # re-weights it by the reward.
        self.rl_loss = NLLLoss(opt, do_reduce=False)
        self.optimizer, self.translator = optimizer, translator
        self.logger, self.opt = logger, opt
        self.tgt_vocab = tgt_vocab

        self.training_data, self.validation_data = training_data, validation_data

        # opt.answer selects how the answer is fed to the model:
        # 'sep' -> separator token in the source, 'enc' -> separate encoder input.
        answer_mode = opt.answer
        self.separate = answer_mode == 'sep'
        self.answer = answer_mode == 'enc'
        if self.separate:
            self.sep_id = src_vocab.lookup(Constants.SEP_WORD)
        else:
            self.sep_id = Constants.SEP

        self.is_attn_mask = bool(opt.defined_slf_attn_mask)

        # Step counter and best-so-far metrics; best_ppl starts at a huge value
        # so that any real perplexity is smaller.
        self.cntBatch = 0
        self.best_metric = 0
        self.best_ppl = math.exp(100)
        self.best_bleu = 0

    def cal_performance(self, loss_input):
        """Compute the supervised loss and token accuracy for one batch.

        Args:
            loss_input: dict holding at least 'pred' (scores whose third
                dimension is the vocabulary) and 'gold' (target ids). It is
                passed unchanged to ``self.loss.cal_loss``.

        Returns:
            (loss, n_correct, pred): the value returned by
            ``self.loss.cal_loss``, the number of non-PAD positions where the
            arg-max prediction equals the gold id, and the flattened arg-max
            predictions.
        """
        loss = self.loss.cal_loss(loss_input)

        scores = loss_input['pred']
        vocab_dim = scores.size(2)
        flat_pred = scores.contiguous().view(-1, vocab_dim).max(1)[1]
        flat_gold = loss_input['gold'].contiguous().view(-1)

        # Only count positions that are not padding.
        keep = flat_gold.ne(Constants.PAD)
        hits = flat_pred.eq(flat_gold).masked_select(keep)
        n_correct = hits.sum().item()

        return loss, n_correct, flat_pred

    def cal_rl_loss(self, pred, decoded_text, flu_rl_inputs, rel_rl_inputs,
                    ans_rl_inputs, flu_discriminator, rel_discriminator,
                    ans_discriminator):
        def _get_n_best(logits, n_best_size):
            index_and_score = sorted(enumerate(logits),
                                     key=lambda x: x[1],
                                     reverse=True)

            best_indexes = []
            for i in range(len(index_and_score)):
                if i >= n_best_size:
                    break
                best_indexes.append(index_and_score[i])

            return best_indexes

        def _get_score(start_logits, end_logits, indexes):
            start_logits = _get_n_best(start_logits[indexes[0]:indexes[1]], 5)
            end_logits = _get_n_best(end_logits[indexes[0]:indexes[1]], 5)
            b_scores = [[0, 0]]

            for start in start_logits:
                for end in end_logits:
                    if start[0] <= end[0] and end[0] - start[
                            0] < 64:  # TODO: magic number
                        score = torch.tensor([start[1], end[1]])
                        score = math.pow(score[0].item() * score[1].item(),
                                         0.5)
                        b_scores.append([score, end[0] - start[0] + 1])
            b_scores.sort(key=lambda x: x[0], reverse=True)
            return b_scores[0][0]

        batch_size, seq_length, vocab_size = pred.size()
        gold = decoded_text[:, 1:].contiguous()

        ##=== fluency ===##
        flu_reward = 0
        if 'fluency' in self.opt.rl:
            with torch.no_grad():
                output_pred_dicts = flu_discriminator(
                    flu_rl_inputs[0], attention_mask=flu_rl_inputs[1])

                reward_fct = NLLLoss(self.opt, do_reduce=False)

                lm_loss = reward_fct.cal_simple_nll(output_pred_dicts[0],
                                                    flu_rl_inputs[2])
                lm_loss = lm_loss.view(pred.size(0), -1).mean(-1)
                scores = torch.exp(lm_loss).to(pred.device)

            flu_reward = scores.data.sum().item()
            flu_scores_scale = self.opt.flu_alpha - scores.data

        ##=== relevance ===##
        rel_reward = 0
        if 'relevance' in self.opt.rl:
            with torch.no_grad():
                output = rel_discriminator(rel_rl_inputs[0],
                                           token_type_ids=rel_rl_inputs[1])
                # get the output logits for [CLS]
                logits = output[0].contiguous().to(pred.device)
            scores = torch.softmax(logits, dim=1).transpose(0,
                                                            1)[1].contiguous()

            rel_reward = scores.data.sum().item()
            rel_scores_scale = torch.log(self.opt.rel_alpha /
                                         (1 - scores.data + 1e-16))

        ##=== answerability ===##
        ans_reward = 0
        if 'answerability' in self.opt.rl:
            with torch.no_grad():
                batch_start_logits, batch_end_logits = ans_discriminator(
                    ans_rl_inputs[0], ans_rl_inputs[1], ans_rl_inputs[2])
            scores, rand_scores = [], []
            for b in range(batch_start_logits.size(0)):
                start_logits = torch.softmax(batch_start_logits[b], dim=-1)
                end_logits = torch.softmax(batch_end_logits[b], dim=-1)
                score = _get_score(start_logits.detach().cpu().tolist(),
                                   end_logits.detach().cpu().tolist(),
                                   ans_rl_inputs[3][b])
                scores.append(score)
            scores = torch.tensor(scores, device=pred.device)

            ans_reward = scores.data.sum().item()
            ans_scores_scale = torch.log(self.opt.ans_alpha /
                                         (1 - scores.data + 1e-16))

        ##=== combination ===##
        scores = 0
        if 'fluency' in self.opt.rl:
            scores += flu_scores_scale * self.opt.flu_gamma
        if 'relevance' in self.opt.rl:
            scores += rel_scores_scale * self.opt.rel_gamma
        if 'answerability' in self.opt.rl:
            scores += ans_scores_scale * self.opt.ans_gamma

        n_correct = scores.gt(0).float().sum().item()
        weights = [(batch_size - n_correct) / batch_size,
                   n_correct / batch_size]
        weights = [1 / 3, 2 / 3] if weights[0] > 1 / 3 else weights
        scores_scale_rgt = scores.gt(0).float() * scores * weights[1]
        scores_scale_wrg = scores.lt(0).float() * scores * weights[0]
        scores_scale = scores_scale_rgt + scores_scale_wrg

        log_prb = self.rl_loss.cal_simple_nll(pred.contiguous(),
                                              gold.contiguous()).view(
                                                  batch_size, -1).mean(-1)
        loss = torch.sum(scores_scale * log_prb)

        return loss, [
            n_correct, batch_size, [flu_reward, rel_reward, ans_reward]
        ]

    def save_model(self, better, bleu):
        model_state_dict = self.model.module.state_dict() if len(
            self.opt.gpus) > 1 else self.model.state_dict()
        model_state_dict = collections.OrderedDict([
            (x, y.cpu()) for x, y in model_state_dict.items()
        ])
        checkpoint = {
            'model state dict': model_state_dict,
            'options': self.opt
        }

        if self.opt.save_mode == 'all':
            model_name = self.opt.save_model + '_ppl_{ppl:2.5f}.chkpt'.format(
                ppl=self.best_ppl)
            torch.save(checkpoint, model_name)
        elif self.opt.save_mode == 'best':
            model_name = self.opt.save_model + '.chkpt'
            if better:
                torch.save(checkpoint, model_name)
                print('    - [Info] The checkpoint file has been updated.')

        if bleu != 'unk' and bleu > self.best_bleu:
            self.best_bleu = bleu
            model_name = self.opt.save_model + '_' + str(round(
                bleu * 100, 5)) + '_bleu4.chkpt'
            torch.save(checkpoint, model_name)

    def eval_step(self, device, epoch):
        '''Run one pass over the validation set and collect metrics.

        For every validation batch the model is run twice under
        ``torch.no_grad()``. The first run is a plain forward pass whose
        'pred' is scored against ``gold`` with ``cal_performance``. The
        second run passes ``rl_type=self.opt.rl``, and its decoded text is
        scored by the reward discriminators through ``cal_rl_loss``.

        Args:
            device: device passed to ``preprocess_batch``.
            epoch: not used in this method.

        Returns:
            ([loss_per_word, loss_per_sample], nll_accuracy, rl_accuracy,
            bleu, [flu_reward, rel_reward, ans_reward]). The rewards and
            ``loss_per_sample`` are averaged over validation samples.
            ``bleu`` is the string 'unk' unless ``self.translator.eval_all``
            was run on this call.
        '''
        self.model.eval()

        with torch.no_grad():
            max_length = 0  # overwritten by preprocess_batch for each batch
            total_nll_loss, n_word_total, n_word_correct = 0, 0, 0
            # n_rl_correct = [samples with positive combined score, total samples]
            total_rl_loss, n_rl_correct = 0, [0, 0]
            flu_reward, rel_reward, ans_reward = 0, 0, 0
            valid_length = len(self.validation_data)
            eval_index_list = range(valid_length)

            for idx in tqdm(eval_index_list,
                            mininterval=2,
                            desc='  - (Validation) ',
                            leave=False):
                batch = self.validation_data[idx]
                inputs, max_length, gold, copy = preprocess_batch(
                    batch,
                    separate=self.separate,
                    enc_rnn=self.opt.enc_rnn != '',
                    dec_rnn=self.opt.dec_rnn != '',
                    feature=self.opt.feature,
                    dec_feature=self.opt.dec_feature,
                    answer=self.answer,
                    ans_feature=self.opt.ans_feature,
                    sep_id=self.sep_id,
                    copy=self.opt.copy,
                    attn_mask=self.is_attn_mask,
                    device=device)
                copy_gold, copy_switch = copy[0], copy[1]

                ### forward ###
                # Pass 1: supervised forward, scored against gold.
                rst = self.model(inputs, max_length=max_length)

                loss_input = {}
                loss_input['pred'], loss_input['gold'] = rst['pred'], gold
                if self.opt.copy:
                    loss_input['copy_pred'], loss_input['copy_gate'] = rst[
                        'copy_pred'], rst['copy_gate']
                    loss_input['copy_gold'], loss_input[
                        'copy_switch'] = copy_gold, copy_switch
                if self.opt.coverage:
                    loss_input['coverage_pred'] = rst['coverage_pred']
                nll_loss, n_correct, _ = self.cal_performance(loss_input)

                # Pass 2: forward with rl_type; the result provides
                # 'decoded_text' and 'rand_decoded_text' for the rewards.
                rst = self.model(inputs,
                                 max_length=max_length,
                                 rl_type=self.opt.rl)

                flu_rl_inputs, flu_discriminator = None, None
                if 'fluency' in self.opt.rl:
                    flu_rl_inputs = preprocess_rl_batch(
                        rst['decoded_text'], rst['rand_decoded_text'],
                        'fluency', self.tgt_vocab,
                        self.opt.rl_device['fluency'])
                    flu_discriminator = self.discriminator['fluency']
                    flu_discriminator.eval()

                rel_rl_inputs, rel_discriminator = None, None
                if 'relevance' in self.opt.rl:
                    inputing = (inputs['encoder']['src_seq'],
                                rst['decoded_text'],
                                inputs['encoder']['lengths'],
                                rst['rand_decoded_text'])
                    rel_rl_inputs = preprocess_rl_batch(
                        inputing, None, 'relevance', self.tgt_vocab,
                        self.opt.rl_device['relevance'])
                    rel_discriminator = self.discriminator['relevance']
                    rel_discriminator.eval()

                ans_rl_inputs, ans_discriminator = None, None
                if 'answerability' in self.opt.rl:
                    inputing = (inputs['encoder']['src_seq'],
                                rst['decoded_text'],
                                inputs['encoder']['lengths'],
                                rst['rand_decoded_text'])
                    # NOTE(review): rand_rl_inputs is never used below.
                    ans_rl_inputs, rand_rl_inputs = preprocess_rl_batch(
                        inputing, None, 'answerability', self.tgt_vocab,
                        self.opt.rl_device['answerability'])
                    ans_discriminator = self.discriminator['answerability']
                    ans_discriminator.eval()

                # NOTE(review): rl_loss / rl_n_correct are only bound when
                # self.opt.rl is truthy, but they are read unconditionally
                # below; this method assumes RL is always enabled.
                if self.opt.rl:
                    rl_loss, rl_n_correct = self.cal_rl_loss(
                        rst['pred'], rst['decoded_text'], flu_rl_inputs,
                        rel_rl_inputs, ans_rl_inputs, flu_discriminator,
                        rel_discriminator, ans_discriminator)

                non_pad_mask = gold.ne(Constants.PAD)
                n_word = non_pad_mask.sum().item()

                total_nll_loss += nll_loss.item()
                n_word_total += n_word
                n_word_correct += n_correct

                total_rl_loss += rl_loss.item()
                n_rl_correct[0] += rl_n_correct[0]
                n_rl_correct[1] += rl_n_correct[1]

                # rl_n_correct[2] = summed (fluency, relevance, answerability) rewards
                flu_reward += rl_n_correct[2][0]
                rel_reward += rl_n_correct[2][1]
                ans_reward += rl_n_correct[2][2]

            loss_per_word = total_nll_loss / n_word_total
            nll_accuracy = n_word_correct / n_word_total
            loss_per_sample = total_rl_loss / n_rl_correct[1]
            bleu = 'unk'
            # The exponent is capped at 16 to avoid overflow for large losses.
            perplexity = math.exp(min(loss_per_word, 16))
            rl_accuracy = n_rl_correct[0] / n_rl_correct[1]

            flu_reward /= n_rl_correct[1]
            rel_reward /= n_rl_correct[1]
            ans_reward /= n_rl_correct[1]

            # Run the (expensive) translator only every translate_steps
            # batches and only under the perplexity condition below.
            # NOTE(review): `perplexity > self.best_ppl` triggers translation
            # when perplexity got *worse*; confirm this is intended (vs `<`).
            if (perplexity <= self.opt.translate_ppl
                    or perplexity > self.best_ppl):
                if self.cntBatch % self.opt.translate_steps == 0:
                    bleu = self.translator.eval_all(self.model,
                                                    self.validation_data)

        return [loss_per_word,
                loss_per_sample], nll_accuracy, rl_accuracy, bleu, [
                    flu_reward, rel_reward, ans_reward
                ]

    def train_epoch(self, device, epoch):
        ''' Epoch operation in training phase'''
        if self.opt.extra_shuffle and epoch > self.opt.curriculum:
            self.logger.info('Shuffling...')
            self.training_data.shuffle()

        self.model.train()

        total_rl_loss, total_nll_loss, n_sample_total, n_sample_correct = 0, 0, 0, 0
        n_word_total, n_word_correct, report_n_word_total, report_n_word_correct = 0, 0, 0, 0
        report_total_rl_loss, report_total_nll_loss, report_n_sample_total, report_n_sample_correct = 0, 0, 0, 0

        batch_order = torch.randperm(len(self.training_data))

        for idx in tqdm(range(len(self.training_data)),
                        mininterval=2,
                        desc='  - (Training)   ',
                        leave=False):

            batch_idx = batch_order[idx] if epoch > self.opt.curriculum else idx
            batch = self.training_data[batch_idx]

            ##### ==================== prepare data ==================== #####
            inputs, max_length, gold, copy = preprocess_batch(
                batch,
                separate=self.separate,
                enc_rnn=self.opt.enc_rnn != '',
                dec_rnn=self.opt.dec_rnn != '',
                feature=self.opt.feature,
                dec_feature=self.opt.dec_feature,
                answer=self.answer,
                ans_feature=self.opt.ans_feature,
                sep_id=self.sep_id,
                copy=self.opt.copy,
                attn_mask=self.is_attn_mask,
                device=device)
            copy_gold, copy_switch = copy[0], copy[1]

            ##### ==================== forward ==================== #####
            self.model.zero_grad()
            self.optimizer.zero_grad()

            rst = self.model(inputs,
                             max_length=max_length,
                             rl_type=self.opt.rl)

            ##### ==================== backward ==================== #####
            ##=== rl loss ===##
            flu_rl_inputs, flu_discriminator = None, None
            if 'fluency' in self.opt.rl:
                flu_rl_inputs = preprocess_rl_batch(
                    rst['decoded_text'], rst['rand_decoded_text'], 'fluency',
                    self.tgt_vocab, self.opt.rl_device['fluency'])
                flu_discriminator = self.discriminator['fluency']
                flu_discriminator.eval()

            rel_rl_inputs, rel_discriminator = None, None
            if 'relevance' in self.opt.rl:
                inputing = (inputs['encoder']['src_seq'], rst['decoded_text'],
                            inputs['encoder']['lengths'],
                            rst['rand_decoded_text'])
                rel_rl_inputs = preprocess_rl_batch(
                    inputing, None, 'relevance', self.tgt_vocab,
                    self.opt.rl_device['relevance'])
                rel_discriminator = self.discriminator['relevance']
                rel_discriminator.eval()

            ans_rl_inputs, ans_discriminator = None, None
            if 'answerability' in self.opt.rl:
                inputing = (inputs['encoder']['src_seq'], rst['decoded_text'],
                            inputs['encoder']['lengths'],
                            rst['rand_decoded_text'])
                ans_rl_inputs, rand_rl_inputs = preprocess_rl_batch(
                    inputing, None, 'answerability', self.tgt_vocab,
                    self.opt.rl_device['answerability'])
                ans_discriminator = self.discriminator['answerability']
                ans_discriminator.eval()

            if self.opt.rl:
                rl_loss, n_correct = self.cal_rl_loss(
                    rst['pred'], rst['decoded_text'], flu_rl_inputs,
                    rel_rl_inputs, ans_rl_inputs, flu_discriminator,
                    rel_discriminator, ans_discriminator)

            if len(self.opt.gpus) > 1:
                rl_loss = rl_loss.mean()  # mean() to average on multi-gpu.
            loss = rl_loss

            ###=== NLL loss ===##
            rst = self.model(inputs, max_length=max_length)

            loss_input = {}
            loss_input['pred'], loss_input['gold'] = rst['pred'], gold
            if self.opt.copy:
                loss_input['copy_pred'], loss_input['copy_gate'] = rst[
                    'copy_pred'], rst['copy_gate']
                loss_input['copy_gold'], loss_input[
                    'copy_switch'] = copy_gold, copy_switch
            if self.opt.coverage:
                loss_input['coverage_pred'] = rst['coverage_pred']

            nll_loss, word_correct, _ = self.cal_performance(loss_input)
            if len(self.opt.gpus) > 1:
                nll_loss = nll_loss.mean()  # mean() to average on multi-gpu.

            self.cntBatch += 1
            if self.cntBatch % 4 == 0 or loss.item() < -10:
                loss = loss + nll_loss

            ##=== backward ===##
            if math.isnan(loss):
                print('loss catch NaN')
                import ipdb
                ipdb.set_trace()

            self.optimizer.backward(loss)
            self.optimizer.step()

            if math.isnan(self.model.generator.weight.data.contiguous().view(
                    -1).sum().item()):
                print('parameter catch NaN')
                import ipdb
                ipdb.set_trace()

            ##### ==================== note for epoch report & step report ==================== #####
            n_word = gold.ne(Constants.PAD).float().sum().item()
            total_nll_loss += nll_loss.item()
            n_word_total += n_word
            n_word_correct += word_correct

            total_rl_loss += rl_loss.item()
            n_sample_total += n_correct[1]
            n_sample_correct += n_correct[0]

            report_total_nll_loss += nll_loss.item()
            report_n_word_total += n_word
            report_n_word_correct += word_correct

            report_total_rl_loss += rl_loss.item()
            report_n_sample_total += n_correct[1]
            report_n_sample_correct += n_correct[0]

            ##### ==================== evaluation ==================== #####
            if self.cntBatch % self.opt.valid_steps == 0:
                ### ========== evaluation on dev ========== ###
                valid_loss, valid_nll_accu, valid_rl_accu, valid_bleu, rewards = self.eval_step(
                    device, epoch)
                valid_ppl = math.exp(min(valid_loss[0], 16))

                report_avg_nll_loss = report_total_nll_loss / report_n_word_total
                report_avg_rl_loss = report_total_rl_loss / report_n_sample_total
                report_avg_ppl = math.exp(min(report_avg_nll_loss, 16))
                report_avg_nll_accu = report_n_word_correct / report_n_word_total
                report_avg_rl_accu = report_n_sample_correct / report_n_sample_total

                better = False
                # metric = (valid_rl_accu if self.opt.rl in ['relevance', 'answerability', ''] else 1 / (valid_loss[1] + 1e-16)) / valid_ppl
                metric = valid_rl_accu / valid_ppl
                if metric >= self.best_metric:
                    self.best_metric = metric
                    better = True

                report_total_nll_loss, report_total_rl_loss = 0, 0
                report_n_word_total, report_n_sample_total = 0, 0
                report_n_word_correct, report_n_sample_correct = 0, 0

                ### ========== update learning rate ========== ###
                self.optimizer.update_learning_rate(better)

                record_log(self.opt.logfile_train,
                           step=self.cntBatch,
                           rl_loss=report_avg_rl_loss,
                           rl_accu=report_avg_rl_accu,
                           loss=report_avg_nll_loss,
                           accu=report_avg_nll_accu,
                           ppl=math.exp(min(report_avg_nll_loss, 16)),
                           bad_cnt=self.optimizer._bad_cnt,
                           lr=self.optimizer._learning_rate)
                dev_record_log(self.opt.logfile_dev,
                               step=self.cntBatch,
                               loss=valid_loss[0],
                               accu=valid_nll_accu,
                               ppl=valid_ppl,
                               bleu=valid_bleu,
                               rl_loss=valid_loss[1],
                               rl_accu=valid_rl_accu,
                               bad_cnt=self.optimizer._bad_cnt,
                               lr=self.optimizer._learning_rate,
                               flu=rewards[0],
                               rel=rewards[1],
                               ans=rewards[2])

                if self.opt.save_model:
                    self.save_model(better, valid_bleu)

                self.model.train()

        loss_per_word = total_nll_loss / n_word_total
        nll_accuracy = n_word_correct / n_word_total * 100
        rl_accuracy = n_sample_correct / n_sample_total * 100

        return math.exp(min(loss_per_word, 16)), nll_accuracy, rl_accuracy

    def train(self, device):
        """Run the whole training loop for ``self.opt.epoch`` epochs.

        For every epoch this delegates to ``self.train_epoch`` (which is given
        a 1-based epoch number), logs the returned training perplexity and the
        NLL / RL accuracies, and prints the wall-clock duration of the epoch.

        Args:
            device: forwarded unchanged to ``self.train_epoch``.
        """
        # Dump the model architecture once, before the first epoch.
        self.logger.info(self.model)

        summary_fmt = ' *  - (Training)   ppl: {ppl: 8.5f}, accuracy: nll - {nll:3.3f} %; rl - {rl:3.3f} %'
        for epoch in range(self.opt.epoch):
            self.logger.info('')
            self.logger.info(' *  [ Epoch {0} ]:   '.format(epoch))
            epoch_start = time.time()

            # The header above counts from 0, train_epoch counts from 1.
            results = self.train_epoch(device, epoch + 1)
            ppl, nll_accu, rl_accu = results

            self.logger.info(summary_fmt.format(ppl=ppl, nll=nll_accu,
                                                rl=rl_accu))
            elapsed = time.time() - epoch_start
            print('                ' + str(elapsed) + ' seconds for epoch ' +
                  str(epoch))
Ejemplo n.º 6
0
    def cal_rl_loss(self, pred, decoded_text, flu_rl_inputs, rel_rl_inputs,
                    ans_rl_inputs, flu_discriminator, rel_discriminator,
                    ans_discriminator):
        """Compute the policy-gradient (RL) loss from discriminator rewards.

        Each reward component enabled in ``self.opt.rl`` ('fluency',
        'relevance', 'answerability') is turned into a per-sample scaled
        score. The scaled scores are summed, then positive and negative
        samples are re-weighted by their share of the batch, with the
        negative share capped at 1/3. The result multiplies the per-sample
        mean NLL of ``decoded_text[:, 1:]`` under ``pred``.

        Args:
            pred: decoder output, unpacked as (batch, seq_len, vocab).
            decoded_text: sampled token ids; ``decoded_text[:, 1:]`` is used
                as the target sequence for ``pred``.
            flu_rl_inputs, rel_rl_inputs, ans_rl_inputs: preprocessed
                discriminator inputs (only those of enabled rewards are read).
            flu_discriminator, rel_discriminator, ans_discriminator: reward
                models (only those of enabled rewards are called).

        Returns:
            ``(loss, [n_correct, batch_size, [flu_reward, rel_reward,
            ans_reward]])``. ``n_correct`` counts samples whose combined score
            is positive. Each reward is the batch sum of that component's raw
            score, or 0 when the component is disabled.

        Raises:
            ValueError: if ``self.opt.rl`` enables none of the known rewards.
        """
        def _get_n_best(probs, n_best_size):
            # (index, prob) pairs of the n_best_size highest probabilities;
            # sorted() is stable, so ties keep their original order.
            return sorted(enumerate(probs), key=lambda x: x[1],
                          reverse=True)[:n_best_size]

        def _get_score(start_probs, end_probs, span, n_best_size=5,
                       max_span_len=64):
            # Best geometric-mean probability sqrt(p_start * p_end) over the
            # candidate spans inside the window span[0]:span[1] whose end is
            # not before the start and which are shorter than max_span_len.
            # Returns 0 when no candidate span is valid.
            lo, hi = span[0], span[1]
            best_starts = _get_n_best(start_probs[lo:hi], n_best_size)
            best_ends = _get_n_best(end_probs[lo:hi], n_best_size)
            best = 0
            for start_idx, start_prob in best_starts:
                for end_idx, end_prob in best_ends:
                    if start_idx <= end_idx and end_idx - start_idx < max_span_len:
                        best = max(best, math.sqrt(start_prob * end_prob))
            return best

        known_rewards = ('fluency', 'relevance', 'answerability')
        if not any(name in self.opt.rl for name in known_rewards):
            # Without at least one reward the combined score below would stay
            # the integer 0 and fail with an unhelpful AttributeError.
            raise ValueError(
                'cal_rl_loss: opt.rl must enable at least one of {0}, got {1!r}'
                .format(known_rewards, self.opt.rl))

        batch_size, _, _ = pred.size()
        gold = decoded_text[:, 1:].contiguous()

        ##=== fluency ===##
        flu_reward = 0
        if 'fluency' in self.opt.rl:
            with torch.no_grad():
                output_pred_dicts = flu_discriminator(
                    flu_rl_inputs[0], attention_mask=flu_rl_inputs[1])

                reward_fct = NLLLoss(self.opt, do_reduce=False)

                lm_loss = reward_fct.cal_simple_nll(output_pred_dicts[0],
                                                    flu_rl_inputs[2])
                # exp(per-sample mean NLL) = per-sample perplexity.
                lm_loss = lm_loss.view(pred.size(0), -1).mean(-1)
                flu_scores = torch.exp(lm_loss).to(pred.device)

            flu_reward = flu_scores.data.sum().item()
            # Positive when the perplexity is below flu_alpha.
            flu_scores_scale = self.opt.flu_alpha - flu_scores.data

        ##=== relevance ===##
        rel_reward = 0
        if 'relevance' in self.opt.rl:
            with torch.no_grad():
                output = rel_discriminator(rel_rl_inputs[0],
                                           token_type_ids=rel_rl_inputs[1])
                # get the output logits for [CLS]
                logits = output[0].contiguous().to(pred.device)
            # Per-sample probability of class 1.
            rel_scores = torch.softmax(logits, dim=1)[:, 1].contiguous()

            rel_reward = rel_scores.data.sum().item()
            rel_scores_scale = torch.log(self.opt.rel_alpha /
                                         (1 - rel_scores.data + 1e-16))

        ##=== answerability ===##
        ans_reward = 0
        if 'answerability' in self.opt.rl:
            with torch.no_grad():
                batch_start_logits, batch_end_logits = ans_discriminator(
                    ans_rl_inputs[0], ans_rl_inputs[1], ans_rl_inputs[2])
            # Softmax the whole batch and move it to Python lists once,
            # instead of once per sample.
            start_probs = torch.softmax(batch_start_logits,
                                        dim=-1).detach().cpu().tolist()
            end_probs = torch.softmax(batch_end_logits,
                                      dim=-1).detach().cpu().tolist()
            ans_scores = [
                _get_score(start_probs[b], end_probs[b], ans_rl_inputs[3][b])
                for b in range(batch_start_logits.size(0))
            ]
            ans_scores = torch.tensor(ans_scores, device=pred.device)

            ans_reward = ans_scores.data.sum().item()
            ans_scores_scale = torch.log(self.opt.ans_alpha /
                                         (1 - ans_scores.data + 1e-16))

        ##=== combination ===##
        terms = []
        if 'fluency' in self.opt.rl:
            terms.append(flu_scores_scale * self.opt.flu_gamma)
        if 'relevance' in self.opt.rl:
            terms.append(rel_scores_scale * self.opt.rel_gamma)
        if 'answerability' in self.opt.rl:
            terms.append(ans_scores_scale * self.opt.ans_gamma)
        scores = sum(terms)

        # Weight positive / negative samples by their share of the batch and
        # cap the negative weight at 1/3 (positives then get 2/3). Samples
        # with a score of exactly 0 get weight 0.
        n_correct = scores.gt(0).float().sum().item()
        weights = [(batch_size - n_correct) / batch_size,
                   n_correct / batch_size]
        weights = [1 / 3, 2 / 3] if weights[0] > 1 / 3 else weights
        scores_scale_rgt = scores.gt(0).float() * scores * weights[1]
        scores_scale_wrg = scores.lt(0).float() * scores * weights[0]
        scores_scale = scores_scale_rgt + scores_scale_wrg

        log_prb = self.rl_loss.cal_simple_nll(pred.contiguous(),
                                              gold.contiguous()).view(
                                                  batch_size, -1).mean(-1)
        loss = torch.sum(scores_scale * log_prb)

        return loss, [
            n_correct, batch_size, [flu_reward, rel_reward, ans_reward]
        ]
Ejemplo n.º 7
0
def main(opt):
    """Fine-tune a BERT encoder inside an ``NERTagger`` and save its weights.

    Steps:
      1. Filter the train/validation data and build a pretrained source
         vocabulary plus a size-limited (1000) target vocabulary.
      2. Wrap ``BertModel`` with a linear + softmax classifier.
      3. Train with Adam (hyper-parameters are overwritten on ``opt`` below).
      4. Every 2000 steps, evaluate on the validation set and save the
         encoder state dict when validation accuracy improves.
      5. Save the final encoder to ``opt.output_dir/pytorch_model.bin``.
    """
    tokenizer = BertTokenizer.from_pretrained(opt.pre_model)
    ###========== Load Data ==========###
    train_data = filter_data(opt.train_src, opt.train_tgt, tokenizer)
    valid_data = filter_data(opt.valid_src, opt.valid_tgt, tokenizer)
    ###========== Get Index ==========###
    options = {'transf':False, 'separate':False, 'tgt':False}
    src_vocab = Vocab.from_opt(pretrained=opt.pre_model, opt=options)
    options = {'lower':False, 'mode':'size', 'size':1000, 'frequency':1,
               'transf':False, 'separate':False, 'tgt':False}
    tgt_vocab = Vocab.from_opt(corpus=train_data['tgt'], opt=options)
    train_src_idx = [src_vocab.convertToIdx(sent) for sent in train_data['src']]
    valid_src_idx = [src_vocab.convertToIdx(sent) for sent in valid_data['src']]
    train_tgt_idx = [tgt_vocab.convertToIdx(sent) for sent in train_data['tgt']]
    valid_tgt_idx = [tgt_vocab.convertToIdx(sent) for sent in valid_data['tgt']]
    ###========== Get Data ==========###
    train_data = Dataset({'src':train_src_idx, 'tgt':train_tgt_idx, 'feature':[train_data['idx']]},
                         opt.batch_size, feature=True, opt_cuda=opt.gpus)
    valid_data = Dataset({'src':valid_src_idx, 'tgt':valid_tgt_idx, 'feature':[valid_data['idx']]},
                         opt.batch_size, feature=True, opt_cuda=opt.gpus)
    opt.tgt_vocab_size = tgt_vocab.size
    ###========== Prepare Model ==========###
    device = torch.device('cuda')
    encoder = BertModel.from_pretrained(opt.pre_model)
    classifier = nn.Sequential(
        nn.Linear(768 // opt.maxout_pool_size, opt.tgt_vocab_size),     # TODO: fix this magic number later (hidden size of the model)
        nn.Softmax(dim=1)
    )
    model = NERTagger(encoder, classifier, device).to(device)
    for _, para in model.classifier.named_parameters():
        if para.dim() == 1:
            para.data.normal_(0, math.sqrt(6 / (1 + para.size(0))))
        else:
            # xavier_normal_ is the non-deprecated in-place initializer.
            nn.init.xavier_normal_(para, math.sqrt(3))
    if len(opt.gpus) > 1:
        model = nn.DataParallel(model, device_ids=opt.gpus)
    ###========== Prepare for training ==========###
    opt.optim = 'adam'
    opt.decay_method = ''
    opt.learning_rate = 3e-5
    opt.learning_rate_decay = 1
    opt.decay_steps = 10000000
    opt.start_decay_steps = 10000000000
    opt.max_grad_norm = 5
    opt.max_weight_value = 20
    opt.decay_bad_cnt = 5
    optimizer = Optimizer.from_opt(model, opt)

    weight = torch.ones(opt.tgt_vocab_size)
    weight[0] = 0       # TODO: fix this magic number later (PAD)
    loss = NLLLoss(opt, weight, size_average=False)
    if opt.gpus:
        loss.cuda()
    ###========== Training ==========###
    best_val = 0

    def eval_model(M, D, L):
        """Return (summed loss / prediction rows, accuracy) of ``M`` on ``D``.

        Leaves ``M`` in eval mode; the caller must switch back to train mode.
        """
        M.eval()

        all_loss, all_accu, all_words = 0, 0, 0
        with torch.no_grad():  # evaluation only: do not build autograd graphs
            for i in tqdm(range(len(D)), mininterval=2, desc='  - (Validation)  ', leave=False):
                B = D[i]
                s, t, sid = B['src'][0], B['tgt'], B['feat'][0][0]
                t = t.transpose(0, 1)
                P = M(s, sid)
                lv, G = L.cal_loss_ner(P, t)

                all_loss += lv.item()
                all_words += P.size(0)
                P = P.max(1)[1]
                n_correct = P.eq(G.view(-1))
                n_correct = n_correct.sum().item()
                all_accu += n_correct

        if all_words == 0:  # empty validation set: avoid ZeroDivisionError
            return 0.0, 0.0
        return all_loss/all_words, all_accu/all_words

    def save_model(M, score, best_val, opt):
        """Save the encoder weights when ``score`` beats ``best_val``."""
        if score > best_val:
            model_to_save = M.module.encoder if hasattr(M, 'module') else M.encoder  # Only save the model it-self
            output_model_file = os.path.join(opt.output_dir, "pytorch_model_" + str(round(score * 100, 2)) + ".bin")
            torch.save(model_to_save.state_dict(), output_model_file)
        print('validation', score)

    for _ in range(opt.num_train_epochs):
        train_data.shuffle()
        model.train()
        batch_order = torch.randperm(len(train_data))
        # Running sums for the periodic training report.
        loss_print, words_cnt, accuracy = 0, 0, 0
        for idx in tqdm(range(len(train_data)), mininterval=2, desc='  - (Training)  ', leave=False):
            batch_idx = batch_order[idx]
            batch = train_data[batch_idx]

            src, tgt, src_idx = batch['src'][0], batch['tgt'], batch['feat'][0][0]
            tgt = tgt.transpose(0, 1)

            out = model(src, src_idx)
            loss_val, gold = loss.cal_loss_ner(out, tgt)
            if len(opt.gpus) > 1:
                loss_val = loss_val.mean()  # mean() to average on multi-gpu.
            if math.isnan(loss_val.item()) or loss_val.item() > 1e20:
                print('catch NaN')
                import ipdb; ipdb.set_trace()
            loss_val.backward()

            optimizer.step()
            optimizer.zero_grad()

            loss_print += loss_val.item()
            words_cnt += out.size(0)
            pred = out.max(1)[1]
            n_correct = pred.eq(gold.view(-1))
            n_correct = n_correct.sum().item()
            accuracy += n_correct
            if idx % 1000 == 0:
                loss_print /= words_cnt
                accuracy /= words_cnt
                print('loss', loss_print)
                print('accuracy', accuracy)
                # Reset the report window: it is loss_print (not loss_val)
                # that accumulates, otherwise the loss keeps growing.
                loss_print, words_cnt, accuracy = 0, 0, 0
                if idx % 2000 == 0:
                    _, accuracy_val = eval_model(model, valid_data, loss)
                    save_model(model, accuracy_val, best_val, opt)
                    if accuracy_val > best_val:
                        best_val = accuracy_val
                    # eval_model left the model in eval mode; restore training
                    # mode (dropout etc.) for the rest of the epoch.
                    model.train()

    model_to_save = model.module.encoder if hasattr(model, 'module') else model.encoder  # Only save the model it-self
    output_model_file = os.path.join(opt.output_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)