Example no. 1
0
    def __init__(self):
        """Build the DAMD model, optimizer and evaluator for MultiWOZ training."""
        self.reader = MultiWozReader()
        # Wrap in DataParallel only when more than one CUDA device is
        # configured; otherwise keep the bare module.
        if len(cfg.cuda_device) == 1:
            self.m = DAMD(self.reader)
        else:
            m = DAMD(self.reader)
            self.m = torch.nn.DataParallel(m, device_ids=cfg.cuda_device)
            # print(self.m.module)
        self.evaluator = MultiWozEvaluator(self.reader)  # evaluator class
        if cfg.cuda: self.m = self.m.cuda()  #cfg.cuda_device[0]
        # Optimize only parameters that require gradients; small fixed
        # weight decay.  Built after .cuda() so params are on-device.
        self.optim = Adam(lr=cfg.lr,
                          params=filter(lambda x: x.requires_grad,
                                        self.m.parameters()),
                          weight_decay=5e-5)
        # presumably the epoch of a restored checkpoint (-1 = fresh start) — confirm
        self.base_epoch = -1

        # Pre-convert vocabulary-mask index lists to long tensors (on GPU via
        # cuda_) so decoding does not pay per-step list->tensor conversion.
        if cfg.limit_bspn_vocab:
            self.reader.bspn_masks_tensor = {}
            for key, values in self.reader.bspn_masks.items():
                v_ = cuda_(torch.Tensor(values).long())
                self.reader.bspn_masks_tensor[key] = v_
        if cfg.limit_aspn_vocab:
            self.reader.aspn_masks_tensor = {}
            for key, values in self.reader.aspn_masks.items():
                v_ = cuda_(torch.Tensor(values).long())
                self.reader.aspn_masks_tensor[key] = v_
Example no. 2
0
    def __init__(self, device):
        """Load the GPT-2 tokenizer/model and MultiWOZ reader onto *device*.

        Args:
            device: torch device the model is moved to (single GPU/CPU).
        """
        self.device = device
        # initialize tokenizer
        self.tokenizer = GPT2Tokenizer.from_pretrained(cfg.gpt_path)
        # cfg.tokenizer = tokenizer

        # initialize multiwoz reader
        self.reader = MultiWozReader(self.tokenizer)

        # create model: gpt2
        self.model = GPT2LMHeadModel.from_pretrained(cfg.gpt_path)
        if cfg.mode == 'train':
            # grow the embedding table to cover special tokens the tokenizer added
            self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.to(self.device)  # single gpu

        # evaluator scores decoded results (BLEU / match / success)
        self.evaluator = MultiWozEvaluator(self.reader)
        # TensorBoard writer only when training with logging enabled
        if cfg.save_log and cfg.mode == 'train':
            self.tb_writer = SummaryWriter(log_dir='./log')
        else:
            self.tb_writer = None
Example no. 3
0
class Modal(object):
    def __init__(self, device):
        """Set up tokenizer, data reader, GPT-2 model and optional logging.

        Args:
            device: torch device the model is moved to (single GPU/CPU).
        """
        self.device = device
        # Tokenizer and reader share the vocabulary loaded from cfg.gpt_path.
        self.tokenizer = GPT2Tokenizer.from_pretrained(cfg.gpt_path)
        self.reader = MultiWozReader(self.tokenizer)

        # GPT-2 LM head; when training, grow the embedding table so it
        # covers any special tokens the tokenizer added.
        self.model = GPT2LMHeadModel.from_pretrained(cfg.gpt_path)
        if cfg.mode == 'train':
            self.model.resize_token_embeddings(len(self.tokenizer))
        self.model.to(self.device)  # single gpu

        self.evaluator = MultiWozEvaluator(self.reader)
        # TensorBoard writer only when training with logging enabled.
        want_tb = cfg.save_log and cfg.mode == 'train'
        self.tb_writer = SummaryWriter(log_dir='./log') if want_tb else None

    def get_optimizers(self):
        """Create the AdamW optimizer and linear-warmup LR scheduler.

        Mirrors transformers.Trainer: parameters whose names contain a
        ``no_decay`` marker ("bias", "LayerNorm.weight") get weight_decay
        0.0, all others get ``cfg.weight_decay``.

        Uses from cfg: lr, weight_decay, epoch_num, batch_size,
        gradient_accumulation_steps, warmup_steps (negative means
        "20% of total steps").

        Returns:
            (optimizer, scheduler) tuple.
        """
        no_decay = ("bias", "LayerNorm.weight")
        decay_params, plain_params = [], []
        # Single pass: route each parameter into its weight-decay group.
        for param_name, param in self.model.named_parameters():
            if any(marker in param_name for marker in no_decay):
                plain_params.append(param)
            else:
                decay_params.append(param)
        grouped_parameters = [
            {"params": decay_params, "weight_decay": cfg.weight_decay},
            {"params": plain_params, "weight_decay": 0.0},
        ]
        optimizer = AdamW(grouped_parameters, lr=cfg.lr)

        # Total optimizer steps over the whole run (session-level batching).
        num_training_steps = (
            self.reader.set_stats['train']['num_dials'] * cfg.epoch_num
            // (cfg.gradient_accumulation_steps * cfg.batch_size))
        if cfg.warmup_steps >= 0:
            num_warmup_steps = cfg.warmup_steps
        else:
            num_warmup_steps = int(num_training_steps * 0.2)
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps)
        return optimizer, scheduler

    def log_first_inputs(self, inputs):
        tokenizer = self.tokenizer
        logging.info("**** Input Examples: ****")
        for context in inputs['contexts'][:4]:
            # ubar = tokenizer.convert_ids_to_tokens(context)
            # ubar = tokenizer.convert_tokens_to_string(context)
            # ubar = " ".join(ubar)
            ubar = tokenizer.decode(context)
            logging.info(ubar)

    def add_torch_input(self, inputs):
        # to tensor and to device
        contexts_tensor = torch.from_numpy(inputs['contexts_np']).long()
        contexts_tensor = contexts_tensor.to(self.device)
        inputs['contexts_tensor'] = contexts_tensor
        return inputs

    def add_torch_input_eval(self, inputs):
        # inputs: context
        inputs['context_tensor'] = torch.tensor([inputs['context']
                                                 ]).to(self.device)
        return inputs

    def calculate_loss_and_accuracy(self, outputs, labels):
        # GPT2-chicahat/train.py
        lm_logits = outputs[0]

        shift_logits = lm_logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        pad_id = cfg.pad_id
        loss_fct = nn.CrossEntropyLoss(ignore_index=pad_id, reduction='sum')
        loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                        shift_labels.view(-1))

        # avg loss
        not_ignore = shift_labels.ne(pad_id)
        num_targets = not_ignore.long().sum().item()

        loss /= num_targets
        return loss

    def train_URURU(self):
        """Train with the URURU scheme (turn-level examples).

        Iterates dialogs turn by turn: each example is one turn plus the
        previous context assembled by ``convert_batch_turn``.  Uses gradient
        accumulation, grad-norm clipping, a warmup LR scheduler, CUDA-OOM
        recovery, and saves a checkpoint after every epoch.
        """
        all_batches = self.reader.get_batches('train')
        # compute num_training_steps in get_batches()
        optimizer, scheduler = self.get_optimizers()

        # log info
        set_stats = self.reader.set_stats['train']
        logging.info("***** Running training *****")
        logging.info(
            "  Num Training steps(one turn in a batch of dialogs) per epoch = %d",
            set_stats['num_training_steps_per_epoch'])
        logging.info("  Num Turns = %d", set_stats['num_turns'])
        logging.info("  Num Dialogs = %d", set_stats['num_dials'])
        logging.info("  Num Epochs = %d", cfg.epoch_num)
        logging.info("  Batch size  = %d", cfg.batch_size)
        logging.info("  Gradient Accumulation steps = %d",
                     cfg.gradient_accumulation_steps)
        logging.info(
            "  Total optimization steps = %d",
            set_stats['num_training_steps_per_epoch'] * cfg.epoch_num //
            cfg.gradient_accumulation_steps)

        # tb writer
        if self.tb_writer is not None:
            self.tb_writer.add_text('cfg', json.dumps(cfg.__dict__, indent=2))
            # self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        log_inputs = 2  # log only the first two converted inputs
        global_step = 0  # counts actual optimizer steps, not mini-batches
        sw = time.time()

        for epoch in range(cfg.epoch_num):
            epoch_step = 0
            tr_loss = 0.0
            logging_loss = 0.0
            btm = time.time()
            oom_time = 0
            self.model.zero_grad()

            data_iterator = self.reader.get_data_iterator(all_batches)

            for batch_idx, dial_batch in enumerate(data_iterator):
                pv_batch = None  # previous-turn context; None on the first turn
                for turn_num, turn_batch in enumerate(dial_batch):
                    first_turn = (turn_num == 0)
                    inputs = self.reader.convert_batch_turn(
                        turn_batch, pv_batch, first_turn)
                    pv_batch = inputs['labels']
                    try:  # avoid OOM
                        self.model.train()
                        if log_inputs > 0:  # log inputs for the very first two turns
                            self.log_first_inputs(inputs)
                            log_inputs -= 1

                        # to tensor
                        inputs = self.add_torch_input(inputs)
                        # loss
                        outputs = self.model(inputs['contexts_tensor'])
                        # outputs = self.model(inputs['contexts_tensor']) # debugging with GPT2Model
                        loss = self.calculate_loss_and_accuracy(
                            outputs, labels=inputs['contexts_tensor'])
                        loss.backward()
                        tr_loss += loss.item()
                        # NOTE(review): grads are clipped after every backward,
                        # i.e. also mid-accumulation — confirm intended.
                        torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                       5.0)
                        epoch_step += 1

                        # step, wrt gradient_accumulation_steps, clip grad norm
                        # NOTE(review): epoch_step was already incremented above,
                        # so `epoch_step+1` fires the first optimizer step one
                        # mini-batch earlier than the usual pattern — confirm.
                        if (epoch_step+1) % cfg.gradient_accumulation_steps == 0 or(
                            # end of an epoch
                            (epoch_step + \
                            1) == set_stats['num_training_steps_per_epoch']
                        ):
                            optimizer.step()
                            scheduler.step()
                            optimizer.zero_grad()
                            # global_step: actual step the optimizer took
                            global_step += 1

                            logs = {}  # for tb writer
                            # logging: loss, lr... after certain amount of steps
                            if cfg.report_interval > 0 and global_step % cfg.report_interval == 0:
                                loss_scalar = (tr_loss - logging_loss) / \
                                    cfg.report_interval
                                logging_loss = tr_loss
                                logs['loss'] = loss_scalar
                                logging.info(
                                    'Global step: {}, epoch step: {}, interval loss: {:.4f}'
                                    .format(global_step, epoch_step,
                                            loss_scalar))
                                # validate
                                # add to tensorboard...
                                if cfg.evaluate_during_training and loss_scalar < 10:
                                    results = self.validate()
                                    for k, v in results.items():
                                        eval_key = "eval_{}".format(k)
                                        logs[eval_key] = v

                                if self.tb_writer:
                                    for k, v in logs.items():
                                        self.tb_writer.add_scalar(
                                            k, v, global_step)
                                # save model...

                    except RuntimeError as exception:
                        # Recover from CUDA OOM by skipping the batch;
                        # re-raise anything else.
                        if "out of memory" in str(exception):
                            max_length = max(inputs['lengths'])
                            oom_time += 1
                            logging.info(
                                "WARNING: ran out of memory,times: {}, batch size: {}, max_len: {}"
                                .format(oom_time, cfg.batch_size, max_length))
                            if hasattr(torch.cuda, 'empty_cache'):
                                torch.cuda.empty_cache()
                        else:
                            logging.info(str(exception))
                            raise exception
            logging.info(
                'Train epoch time: {:.2f} min, epoch loss: {:.4f}'.format(
                    (time.time() - btm) / 60, tr_loss))
            # save model after every epoch
            # if epoch > 30 and tr_loss/epoch_step < 0.6:
            self.save_model(epoch, tr_loss / epoch_step)

    def train(self):
        """Train with the UBAR scheme (session-level examples).

        Each example is a whole dialog session flattened into one token
        sequence by ``convert_batch_session``; otherwise mirrors
        ``train_URURU`` (gradient accumulation, clipping, OOM recovery,
        per-epoch checkpointing).
        """
        all_batches = self.reader.get_batches('train')
        # compute num_training_steps in get_batches()
        optimizer, scheduler = self.get_optimizers()

        # log info
        set_stats = self.reader.set_stats['train']
        logging.info("***** Running training *****")
        logging.info(
            "  Num Training steps(one turn in a batch of dialogs) per epoch = %d",
            set_stats['num_training_steps_per_epoch'])
        logging.info("  Num Turns = %d", set_stats['num_turns'])
        logging.info("  Num Dialogs = %d", set_stats['num_dials'])
        logging.info("  Num Epochs = %d", cfg.epoch_num)
        logging.info("  Batch size  = %d", cfg.batch_size)
        logging.info("  Gradient Accumulation steps = %d",
                     cfg.gradient_accumulation_steps)
        logging.info(
            "  Total optimization steps = %d",
            set_stats['num_dials'] * cfg.epoch_num //
            (cfg.gradient_accumulation_steps * cfg.batch_size))

        # tb writer
        if self.tb_writer is not None:
            self.tb_writer.add_text('cfg', json.dumps(cfg.__dict__, indent=2))
            # self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        log_inputs = 2  # log only the first two converted inputs
        global_step = 0  # counts actual optimizer steps, not mini-batches
        sw = time.time()

        for epoch in range(cfg.epoch_num):
            epoch_step = 0
            tr_loss = 0.0
            logging_loss = 0.0
            btm = time.time()
            oom_time = 0
            self.model.zero_grad()

            data_iterator = self.reader.get_nontranspose_data_iterator(
                all_batches)

            for batch_idx, dial_batch in enumerate(data_iterator):
                inputs = self.reader.convert_batch_session(dial_batch)
                try:  # avoid OOM
                    self.model.train()
                    if log_inputs > 0:  # log inputs for the very first two turns
                        self.log_first_inputs(inputs)
                        log_inputs -= 1

                    # to tensor
                    inputs = self.add_torch_input(inputs)
                    # loss
                    outputs = self.model(inputs['contexts_tensor'])
                    # outputs = self.model(inputs['contexts_tensor']) # debugging with GPT2Model
                    loss = self.calculate_loss_and_accuracy(
                        outputs, labels=inputs['contexts_tensor'])
                    loss.backward()
                    tr_loss += loss.item()
                    # NOTE(review): grads are clipped after every backward,
                    # i.e. also mid-accumulation — confirm intended.
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                                   5.0)
                    epoch_step += 1

                    # step, wrt gradient_accumulation_steps, clip grad norm
                    # NOTE(review): epoch_step was already incremented above, so
                    # `epoch_step+1` fires the first optimizer step one
                    # mini-batch earlier than the usual pattern — confirm.
                    if (epoch_step+1) % cfg.gradient_accumulation_steps == 0 or(
                        # end of an epoch
                        (epoch_step + \
                         1) == set_stats['num_training_steps_per_epoch']
                    ):
                        optimizer.step()
                        scheduler.step()
                        optimizer.zero_grad()
                        # global_step: actual step the optimizer took
                        global_step += 1

                        logs = {}  # for tb writer
                        # logging: loss, lr... after certain amount of steps
                        if cfg.report_interval > 0 and global_step % cfg.report_interval == 0:
                            loss_scalar = (tr_loss - logging_loss) / \
                                cfg.report_interval
                            logging_loss = tr_loss
                            logs['loss'] = loss_scalar
                            logging.info(
                                'Global step: {}, epoch step: {}, interval loss: {:.4f}'
                                .format(global_step, epoch_step, loss_scalar))
                            # validate
                            # add to tensorboard...
                            if cfg.evaluate_during_training and loss_scalar < 10:
                                results = self.validate()
                                for k, v in results.items():
                                    eval_key = "eval_{}".format(k)
                                    logs[eval_key] = v

                            if self.tb_writer:
                                for k, v in logs.items():
                                    self.tb_writer.add_scalar(
                                        k, v, global_step)
                            # save model...

                except RuntimeError as exception:
                    # Recover from CUDA OOM by skipping the batch; re-raise
                    # anything else.
                    if "out of memory" in str(exception):
                        max_length = max(inputs['lengths'])
                        oom_time += 1
                        logging.info(
                            "WARNING: ran out of memory,times: {}, batch size: {}, max_len: {}"
                            .format(oom_time, cfg.batch_size, max_length))
                        if hasattr(torch.cuda, 'empty_cache'):
                            torch.cuda.empty_cache()
                    else:
                        logging.info(str(exception))
                        raise exception
            logging.info(
                'Train epoch time: {:.2f} min, epoch loss: {:.4f}'.format(
                    (time.time() - btm) / 60, tr_loss))
            # save model after every epoch
            # if epoch > 10 or tr_loss/epoch_step < 1:
            self.save_model(epoch, tr_loss / epoch_step)

    def save_model(self, epoch, loss):
        save_path = os.path.join(
            cfg.exp_path, 'epoch{}_trloss{:.2f}_gpt2'.format(epoch + 1, loss))
        if not os.path.exists(save_path):
            os.mkdir(save_path)
        logging.info('Saving model checkpoint to %s', save_path)
        # save gpt2
        self.model.save_pretrained(save_path)
        # save tokenizer
        self.tokenizer.save_pretrained(save_path)
        # save cfg

    def validate_URURU(self, data='dev', do_test=False):
        """Decode-and-score evaluation for the URURU scheme.

        Processes one dialog at a time, turn by turn.  With
        ``cfg.use_true_curr_bspn`` it generates act + response directly;
        otherwise it first generates the belief span, queries the DB, then
        generates act + response conditioned on the DB result.  Collected
        results are scored with BLEU / match / success.

        Args:
            data: split to evaluate ('dev' or 'test').
            do_test: accepted but not used in this body — TODO confirm.

        Returns:
            dict with keys 'bleu', 'success', 'match'.
        """
        # predict one dialog/ one turn at a time
        self.model.eval()

        # all_batches = self.reader.get_batches('dev')
        # data_iterator = self.reader.get_data_iterator(all_batches)
        eval_data = self.reader.get_eval_data(data)

        set_stats = self.reader.set_stats[data]
        logging.info("***** Running Evaluation *****")
        logging.info("  Num Turns = %d", set_stats['num_turns'])
        # logging.info("  Num Dialogs = %d", set_stats['num_dials'])

        # valid_losses = []
        btm = time.time()
        result_collection = {}
        with torch.no_grad():
            eval_pbar = eval_data
            for dial_idx, dialog in enumerate(eval_pbar):

                pv_turn = {}
                for turn_idx, turn in enumerate(dialog):
                    first_turn = (turn_idx == 0)
                    inputs = self.reader.convert_turn_eval_URURU(
                        turn, pv_turn, first_turn)
                    inputs = self.add_torch_input_eval(inputs)

                    # fail to generate new tokens, if max_length not set
                    context_length = len(inputs['context'])
                    if cfg.use_true_curr_bspn:  # generate act, response
                        # budget of new tokens: 60 for response only, 80 when
                        # the act must be generated too
                        max_len = 60
                        if not cfg.use_true_curr_aspn:
                            max_len = 80

                        outputs = self.model.generate(
                            input_ids=inputs['context_tensor'],
                            max_length=context_length + max_len,
                            temperature=0.7,  # top_p=0.9, num_beams=4,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.encode(['<eos_r>'])[0])
                        #   no_repeat_ngram_size=4
                        # turn['generated'] = self.tokenizer.decode(outputs[0])

                        # resp_gen, need to trim previous context
                        generated = outputs[0].cpu().numpy().tolist()
                        # slice starts at the last context token, keeping it
                        # plus everything newly generated
                        generated = generated[context_length - 1:]

                        try:
                            decoded = self.decode_generated_act_resp(generated)
                        except ValueError as exception:
                            # malformed generation: log it and fall back to
                            # empty spans so scoring can continue
                            logging.info(str(exception))
                            logging.info(self.tokenizer.decode(generated))
                            decoded = {'resp': [], 'bspn': [], 'aspn': []}

                    else:  # predict bspn, access db, then generate act and resp
                        outputs = self.model.generate(
                            input_ids=inputs['context_tensor'],
                            max_length=context_length + 60,
                            temperature=0.7,  # top_p=0.9, num_beams=4,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.encode(['<eos_b>'])[0])
                        generated_bs = outputs[0].cpu().numpy().tolist()
                        # generated_bs = generated_bs[context_length-1:]
                        bspn_gen = self.decode_generated_bspn(
                            generated_bs[context_length - 1:])
                        # check DB result
                        if cfg.use_true_db_pointer:
                            # db_result = self.reader.bspan_to_DBpointer(self.tokenizer.decode(turn['bspn']), turn['turn_domain'])
                            db = turn['db']
                        else:
                            # query the DB with the generated belief span and
                            # re-tokenize the pointer, appending <sos_a> to
                            # prompt act generation
                            db_result = self.reader.bspan_to_DBpointer(
                                self.tokenizer.decode(bspn_gen),
                                turn['turn_domain'])
                            db = self.tokenizer.convert_tokens_to_ids(
                                self.tokenizer.tokenize(
                                    '<sos_db> ' + db_result +
                                    ' <eos_db>')) + self.tokenizer.encode(
                                        ['<sos_a>'])
                        # second pass: context (minus its final token) +
                        # generated belief span + DB pointer
                        inputs['context_tensor_db'] = torch.tensor([
                            inputs['context'][:-1] + bspn_gen + db
                        ]).to(self.device)
                        context_length = len(inputs['context_tensor_db'][0])
                        outputs_db = self.model.generate(
                            input_ids=inputs['context_tensor_db'],
                            max_length=context_length + 80,
                            temperature=0.7,  # top_p=0.9, num_beams=4,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.encode(['<eos_r>'])[0])
                        generated_ar = outputs_db[0].cpu().numpy().tolist()
                        generated_ar = generated_ar[context_length - 1:]
                        try:
                            decoded = self.decode_generated_act_resp(
                                generated_ar)
                            decoded['bspn'] = bspn_gen
                        except ValueError as exception:
                            logging.info(str(exception))
                            logging.info(self.tokenizer.decode(generated_ar))
                            decoded = {'resp': [], 'bspn': [], 'aspn': []}

                    # record generated spans, falling back to ground truth
                    # where the corresponding use_true_* flag is set
                    turn['resp_gen'] = decoded['resp']
                    turn['bspn_gen'] = turn[
                        'bspn'] if cfg.use_true_curr_bspn else decoded['bspn']
                    turn['aspn_gen'] = turn[
                        'aspn'] if cfg.use_true_curr_aspn else decoded['aspn']
                    turn['dspn_gen'] = turn['dspn']

                    # check DB results
                    # db_result = self.reader.bspan_to_DBpointer(self.tokenizer.decode(turn['bspn']), turn['turn_domain'])
                    # if db_result[0] == 1: # no match
                    #     print('gt:', self.tokenizer.decode(turn['aspn']), '     |gen:', self.tokenizer.decode(decoded['aspn']))
                    #     print('gen_resp: ', self.tokenizer.decode(decoded['resp']))
                    #     print('gt_resp: ', self.tokenizer.decode(turn['resp']), '\n')

                    pv_turn['labels'] = inputs[
                        'labels']  # all true previous context
                    pv_turn['resp'] = turn[
                        'resp'] if cfg.use_true_prev_resp else decoded['resp']
                    # pv_turn['bspn'] = turn['bspn'] if cfg.use_true_prev_bspn else decoded['bspn']
                    # pv_turn['db'] = db
                    # pv_turn['aspn'] = turn['aspn'] if cfg.use_true_prev_aspn else decoded['aspn']
                    # pv_turn = inputs['labels']

                result_collection.update(
                    self.reader.inverse_transpose_turn(dialog))

        logging.info("inference time: {:.2f} min".format(
            (time.time() - btm) / 60))
        # score
        btm = time.time()
        results, _ = self.reader.wrap_result_lm(result_collection)
        bleu, success, match = self.evaluator.validation_metric(results)
        logging.info("Scoring time: {:.2f} min".format(
            (time.time() - btm) / 60))
        score = 0.5 * (success + match) + bleu
        # NOTE(review): valid_loss is assigned but never used below — confirm
        valid_loss = 130 - score
        logging.info(
            'validation [CTR] match: %2.2f  success: %2.2f  bleu: %2.2f    score: %.2f'
            % (match, success, bleu, score))
        eval_results = {}
        eval_results['bleu'] = bleu
        eval_results['success'] = success
        eval_results['match'] = match

        return eval_results

    def validate(self, data='dev', do_test=False):
        # predict one dialog/ one turn at a time
        self.model.eval()

        # all_batches = self.reader.get_batches('dev')
        # data_iterator = self.reader.get_data_iterator(all_batches)
        eval_data = self.reader.get_eval_data(data)

        set_stats = self.reader.set_stats[data]
        logging.info("***** Running Evaluation *****")
        logging.info("  Num Turns = %d", set_stats['num_turns'])
        # logging.info("  Num Dialogs = %d", set_stats['num_dials'])

        # valid_losses = []
        btm = time.time()
        result_collection = {}
        with torch.no_grad():
            for dial_idx, dialog in enumerate(eval_data):

                pv_turn = {}
                for turn_idx, turn in enumerate(dialog):
                    first_turn = (turn_idx == 0)
                    inputs = self.reader.convert_turn_eval(
                        turn, pv_turn, first_turn)
                    inputs = self.add_torch_input_eval(inputs)

                    # fail to generate new tokens, if max_length not set
                    context_length = len(inputs['context'])
                    if cfg.use_true_curr_bspn:  # generate act, response
                        max_len = 60
                        if not cfg.use_true_curr_aspn:
                            max_len = 80

                        outputs = self.model.generate(
                            input_ids=inputs['context_tensor'],
                            max_length=context_length + max_len,
                            temperature=0.7,  # top_p=0.9, num_beams=4,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.encode(['<eos_r>'])[0])
                        #   no_repeat_ngram_size=4
                        # turn['generated'] = self.tokenizer.decode(outputs[0])

                        # resp_gen, need to trim previous context
                        generated = outputs[0].cpu().numpy().tolist()
                        generated = generated[context_length - 1:]

                        try:
                            decoded = self.decode_generated_act_resp(generated)
                        except ValueError as exception:
                            logging.info(str(exception))
                            logging.info(self.tokenizer.decode(generated))
                            decoded = {'resp': [], 'bspn': [], 'aspn': []}

                    else:  # predict bspn, access db, then generate act and resp
                        outputs = self.model.generate(
                            input_ids=inputs['context_tensor'],
                            max_length=context_length + 60,
                            temperature=0.7,  # top_p=0.9, num_beams=4,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.encode(['<eos_b>'])[0])
                        generated_bs = outputs[0].cpu().numpy().tolist()
                        # generated_bs = generated_bs[context_length-1:]
                        bspn_gen = self.decode_generated_bspn(
                            generated_bs[context_length - 1:])
                        # check DB result
                        if cfg.use_true_db_pointer:
                            # db_result = self.reader.bspan_to_DBpointer(self.tokenizer.decode(turn['bspn']), turn['turn_domain'])
                            db = turn['db']
                        else:
                            db_result = self.reader.bspan_to_DBpointer(
                                self.tokenizer.decode(bspn_gen),
                                turn['turn_domain'])
                            db = self.tokenizer.convert_tokens_to_ids(
                                self.tokenizer.tokenize(
                                    '<sos_db> ' + db_result +
                                    ' <eos_db>')) + self.tokenizer.encode(
                                        ['<sos_a>'])
                        inputs['context_tensor_db'] = torch.tensor([
                            inputs['context'][:-1] + bspn_gen + db
                        ]).to(self.device)
                        context_length = len(inputs['context_tensor_db'][0])
                        outputs_db = self.model.generate(
                            input_ids=inputs['context_tensor_db'],
                            max_length=context_length + 80,
                            temperature=0.7,  # top_p=0.9, num_beams=4,
                            pad_token_id=self.tokenizer.eos_token_id,
                            eos_token_id=self.tokenizer.encode(['<eos_r>'])[0])
                        generated_ar = outputs_db[0].cpu().numpy().tolist()
                        generated_ar = generated_ar[context_length - 1:]
                        try:
                            decoded = self.decode_generated_act_resp(
                                generated_ar)
                            decoded['bspn'] = bspn_gen
                        except ValueError as exception:
                            logging.info(str(exception))
                            logging.info(self.tokenizer.decode(generated_ar))
                            decoded = {'resp': [], 'bspn': [], 'aspn': []}

                    turn['resp_gen'] = decoded['resp']
                    turn['bspn_gen'] = turn[
                        'bspn'] if cfg.use_true_curr_bspn else decoded['bspn']
                    turn['aspn_gen'] = turn[
                        'aspn'] if cfg.use_true_curr_aspn else decoded['aspn']
                    turn['dspn_gen'] = turn['dspn']

                    # check DB results
                    # db_result = self.reader.bspan_to_DBpointer(self.tokenizer.decode(turn['bspn']), turn['turn_domain'])
                    # if db_result[0] == 1: # no match
                    #     print('gt:', self.tokenizer.decode(turn['aspn']), '     |gen:', self.tokenizer.decode(decoded['aspn']))
                    #     print('gen_resp: ', self.tokenizer.decode(decoded['resp']))
                    #     print('gt_resp: ', self.tokenizer.decode(turn['resp']), '\n')

                    pv_turn['labels'] = inputs[
                        'labels']  # all true previous context
                    pv_turn['resp'] = turn[
                        'resp'] if cfg.use_true_prev_resp else decoded['resp']
                    pv_turn['bspn'] = turn[
                        'bspn'] if cfg.use_true_prev_bspn else decoded['bspn']
                    pv_turn[
                        'db'] = turn['db'] if cfg.use_true_curr_bspn else db
                    pv_turn['aspn'] = turn[
                        'aspn'] if cfg.use_true_prev_aspn else decoded['aspn']

                result_collection.update(
                    self.reader.inverse_transpose_turn(dialog))

        logging.info("inference time: {:.2f} min".format(
            (time.time() - btm) / 60))
        # score
        btm = time.time()
        results, _ = self.reader.wrap_result_lm(result_collection)
        bleu, success, match = self.evaluator.validation_metric(results)
        logging.info("Scoring time: {:.2f} min".format(
            (time.time() - btm) / 60))
        score = 0.5 * (success + match) + bleu
        valid_loss = 130 - score
        logging.info(
            'validation [CTR] match: %2.2f  success: %2.2f  bleu: %2.2f    score: %.2f'
            % (match, success, bleu, score))
        eval_results = {}
        eval_results['bleu'] = bleu
        eval_results['success'] = success
        eval_results['match'] = match
        eval_results['score'] = score
        eval_results[
            'result'] = 'validation [CTR] match: %2.2f  success: %2.2f  bleu: %2.2f    score: %.2f' % (
                match, success, bleu, score)

        model_setting, epoch_setting = cfg.eval_load_path.split(
            '/')[1], cfg.eval_load_path.split('/')[2]
        eval_on = '-'.join(cfg.exp_domains)
        if data == 'test':
            eval_on += '_test'
        if not os.path.exists(cfg.log_path):
            os.mkdir(cfg.log_path)
        log_file_name = os.path.join(cfg.log_path,
                                     model_setting + '-' + eval_on + '.json')
        if os.path.exists(log_file_name):
            eval_to_json = json.load(open(log_file_name, 'r'))
            eval_to_json[epoch_setting] = eval_results
            json.dump(eval_to_json, open(log_file_name, 'w'), indent=2)
        else:
            eval_to_json = {}
            eval_to_json[epoch_setting] = eval_results
            json.dump(eval_to_json, open(log_file_name, 'w'), indent=2)
        logging.info('update eval results to {}'.format(log_file_name))
        return eval_results

    def decode_generated_act_resp(self, generated):
        """Split a generated act+response token-id sequence into spans.

        Args:
            generated: flat list of token ids emitted by the LM after the
                context (an '<eos_a>'-terminated act span followed by a
                response span, unless the true act is used).

        Returns:
            dict with 'resp' (and 'aspn' when cfg.use_true_curr_aspn is
            False), each a list of token ids including its eos marker.

        Raises:
            ValueError: via list.index when '<eos_a>' is expected but
                missing from the generation (callers catch this).
        """
        # NOTE: removed an unused eos_b id lookup and the stale
        # commented-out bspn decoding path that referenced it.
        decoded = {}
        eos_a_id = self.tokenizer.encode(['<eos_a>'])[0]
        eos_r_id = self.tokenizer.encode(['<eos_r>'])[0]

        # eos_r may not exist if gpt2 degenerated into repetitive words;
        # fall back to treating the whole sequence as the response.
        if eos_r_id in generated:
            eos_r_idx = generated.index(eos_r_id)
        else:
            eos_r_idx = len(generated) - 1
            logging.info('eos_r not in generated: ' +
                         self.tokenizer.decode(generated))

        if cfg.use_true_curr_aspn:  # only predict resp
            decoded['resp'] = generated[:eos_r_idx + 1]
        else:  # predicted aspn, then resp
            eos_a_idx = generated.index(eos_a_id)
            decoded['aspn'] = generated[:eos_a_idx + 1]
            decoded['resp'] = generated[eos_a_idx + 1:eos_r_idx + 1]
        return decoded

    def decode_generated_bspn(self, generated):
        """Truncate a generated belief-span id list right after '<eos_b>'.

        If the marker never appears, the whole sequence is returned
        unchanged.
        """
        end_id = self.tokenizer.encode(['<eos_b>'])[0]
        try:
            cut = generated.index(end_id)
        except ValueError:
            cut = len(generated) - 1
        return generated[:cut + 1]
Ejemplo n.º 4
0
class Model(object):
    """Training/evaluation driver for the DAMD multi-domain dialog model.

    Owns the data reader, the (optionally DataParallel) network, the
    MultiWOZ evaluator and the Adam optimizer, and implements the epoch
    loop with early stopping and learning-rate decay, plus loss- or
    metric-based validation and full test-set evaluation.
    """

    def __init__(self):
        self.reader = MultiWozReader()
        if len(cfg.cuda_device) == 1:
            self.m = DAMD(self.reader)
        else:
            # Several GPU ids configured: replicate the model via DataParallel.
            m = DAMD(self.reader)
            self.m = torch.nn.DataParallel(m, device_ids=cfg.cuda_device)
            # print(self.m.module)
        self.evaluator = MultiWozEvaluator(self.reader)  # evaluator class
        if cfg.cuda: self.m = self.m.cuda()  #cfg.cuda_device[0]
        self.optim = Adam(lr=cfg.lr,
                          params=filter(lambda x: x.requires_grad,
                                        self.m.parameters()),
                          weight_decay=5e-5)
        self.base_epoch = -1  # epoch of a restored checkpoint; -1 = fresh run

        # Pre-build the vocabulary-mask tensors once so that decoding does
        # not recreate them on every step.
        if cfg.limit_bspn_vocab:
            self.reader.bspn_masks_tensor = {}
            for key, values in self.reader.bspn_masks.items():
                v_ = cuda_(torch.Tensor(values).long())
                self.reader.bspn_masks_tensor[key] = v_
        if cfg.limit_aspn_vocab:
            self.reader.aspn_masks_tensor = {}
            for key, values in self.reader.aspn_masks.items():
                v_ = cuda_(torch.Tensor(values).long())
                self.reader.aspn_masks_tensor[key] = v_

    def add_torch_input(self, inputs, mode='train', first_turn=False):
        """Turn the reader's numpy batch into cuda tensors, in place.

        For every enabled span type this adds: the <unk>-replaced id
        tensor, a '_nounk' variant keeping OOV ids, previous-turn ('pv_')
        tensors when present, '_4loss' training targets (OOV labels that
        cannot be copied from any source become <unk>), and one-hot
        encodings used by the copy mechanism.
        """
        need_onehot = [
            'user', 'usdx', 'bspn', 'aspn', 'pv_resp', 'pv_bspn', 'pv_aspn',
            'dspn', 'pv_dspn', 'bsdx', 'pv_bsdx'
        ]
        inputs['db'] = cuda_(torch.from_numpy(inputs['db_np']).float())
        for item in ['user', 'usdx', 'resp', 'bspn', 'aspn', 'bsdx', 'dspn']:
            # Skip span types that are disabled in the configuration.
            if not cfg.enable_aspn and item == 'aspn':
                continue
            if not cfg.enable_bspn and item == 'bspn':
                continue
            if not cfg.enable_dspn and item == 'dspn':
                continue
            inputs[item] = cuda_(
                torch.from_numpy(
                    inputs[item + '_unk_np']).long())  # replace oov to <unk>
            if item in ['user', 'usdx', 'resp', 'bspn']:
                inputs[item + '_nounk'] = cuda_(
                    torch.from_numpy(
                        inputs[item +
                               '_np']).long())  # don't replace oov to <unk>
            else:
                inputs[item + '_nounk'] = inputs[item]
            # print(item, inputs[item].size())
            if item in ['resp', 'bspn', 'aspn', 'bsdx', 'dspn']:
                if 'pv_' + item + '_unk_np' not in inputs:
                    continue
                inputs['pv_' + item] = cuda_(
                    torch.from_numpy(inputs['pv_' + item + '_unk_np']).long())
                if item in ['user', 'usdx', 'bspn']:
                    inputs['pv_' + item + '_nounk'] = cuda_(
                        torch.from_numpy(inputs['pv_' + item + '_np']).long())
                    inputs[item + '_4loss'] = self.index_for_loss(item, inputs)
                else:
                    inputs['pv_' + item + '_nounk'] = inputs['pv_' + item]
                    inputs[item + '_4loss'] = inputs[item]
                if 'pv_' + item in need_onehot:
                    inputs['pv_' + item + '_onehot'] = get_one_hot_input(
                        inputs['pv_' + item + '_unk_np'])
            if item in need_onehot:
                inputs[item + '_onehot'] = get_one_hot_input(inputs[item +
                                                                    '_unk_np'])

        if cfg.multi_acts_training and 'aspn_aug_unk_np' in inputs:
            inputs['aspn_aug'] = cuda_(
                torch.from_numpy(inputs['aspn_aug_unk_np']).long())
            inputs['aspn_aug_4loss'] = inputs['aspn_aug']

        return inputs

    def index_for_loss(self, item, inputs):
        """Build loss targets for *item*.

        OOV labels (id >= vocab size) that do not appear in any of the
        item's copy sources are replaced with <unk> (id 2), since the copy
        mechanism cannot produce them. Returns None for unknown items.
        """
        raw_labels = inputs[item + '_np']
        if item == 'bspn':
            copy_sources = [
                inputs['user_np'], inputs['pv_resp_np'], inputs['pv_bspn_np']
            ]
        elif item == 'bsdx':
            copy_sources = [
                inputs['usdx_np'], inputs['pv_resp_np'], inputs['pv_bsdx_np']
            ]
        elif item == 'aspn':
            copy_sources = []
            if cfg.use_pvaspn:
                copy_sources.append(inputs['pv_aspn_np'])
            if cfg.enable_bspn:
                copy_sources.append(inputs[cfg.bspn_mode + '_np'])
        elif item == 'dspn':
            copy_sources = [inputs['pv_dspn_np']]
        elif item == 'resp':
            copy_sources = [inputs['usdx_np']]
            if cfg.enable_bspn:
                copy_sources.append(inputs[cfg.bspn_mode + '_np'])
            if cfg.enable_aspn:
                copy_sources.append(inputs['aspn_np'])
        else:
            return
        new_labels = np.copy(raw_labels)
        if copy_sources:
            bidx, tidx = np.where(raw_labels >= self.reader.vocab_size)
            copy_sources = np.concatenate(copy_sources, axis=1)
            # FIX: np.where returns *paired* coordinate arrays; iterate them
            # in lockstep. The previous nested loop walked the cartesian
            # product and could clobber in-vocab labels at non-OOV positions.
            for b, t in zip(bidx, tidx):
                oov_idx = raw_labels[b, t]
                if len(np.where(copy_sources[b, :] == oov_idx)[0]) == 0:
                    new_labels[b, t] = 2
        return cuda_(torch.from_numpy(new_labels).long())

    def train(self):
        """Run the full training loop.

        Per epoch: iterate turn-major batches, backprop the (mean) total
        loss with gradient clipping, then validate. Validation loss drives
        checkpointing, learning-rate decay and early stopping; on stop the
        best checkpoint is reloaded and evaluated.
        """
        lr = cfg.lr
        prev_min_loss, early_stop_count = 1 << 30, cfg.early_stop_count
        weight_decay_count = cfg.weight_decay_count
        train_time = 0
        sw = time.time()

        for epoch in range(cfg.epoch_num):
            # Skip epochs already covered by a restored checkpoint.
            if epoch <= self.base_epoch:
                continue
            self.training_adjust(epoch)
            sup_loss = 0
            sup_cnt = 0
            optim = self.optim
            # data_iterator generatation size: (batch num, turn num, batch size)
            btm = time.time()
            data_iterator = self.reader.get_batches('train')
            for iter_num, dial_batch in enumerate(data_iterator):
                hidden_states = {}
                # Previous-turn spans fed back into the next turn's input.
                py_prev = {
                    'pv_resp': None,
                    'pv_bspn': None,
                    'pv_aspn': None,
                    'pv_dspn': None,
                    'pv_bsdx': None
                }
                bgt = time.time()
                for turn_num, turn_batch in enumerate(dial_batch):
                    # print('turn %d'%turn_num)
                    # print(len(turn_batch['dial_id']))
                    optim.zero_grad()
                    first_turn = (turn_num == 0)
                    inputs = self.reader.convert_batch(turn_batch,
                                                       py_prev,
                                                       first_turn=first_turn)
                    inputs = self.add_torch_input(inputs,
                                                  first_turn=first_turn)
                    # total_loss, losses, hidden_states = self.m(inputs, hidden_states, first_turn, mode='train')
                    total_loss, losses = self.m(inputs,
                                                hidden_states,
                                                first_turn,
                                                mode='train')
                    # print('forward completed')
                    # Ground-truth spans become the next turn's 'previous'.
                    py_prev['pv_resp'] = turn_batch['resp']
                    if cfg.enable_bspn:
                        py_prev['pv_bspn'] = turn_batch['bspn']
                        py_prev['pv_bsdx'] = turn_batch['bsdx']
                    if cfg.enable_aspn:
                        py_prev['pv_aspn'] = turn_batch['aspn']
                    if cfg.enable_dspn:
                        py_prev['pv_dspn'] = turn_batch['dspn']

                    total_loss = total_loss.mean()
                    # print('forward time:%f'%(time.time()-test_begin))
                    # test_begin = time.time()
                    total_loss.backward(retain_graph=False)
                    # total_loss.backward(retain_graph=turn_num != len(dial_batch) - 1)
                    # print('backward time:%f'%(time.time()-test_begin))
                    grad = torch.nn.utils.clip_grad_norm_(
                        self.m.parameters(), 5.0)
                    optim.step()
                    sup_loss += float(total_loss)
                    sup_cnt += 1
                    torch.cuda.empty_cache()

                if (iter_num + 1) % cfg.report_interval == 0:
                    logging.info(
                            'iter:{} [total|bspn|aspn|resp|ptr|gate] loss: {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} {:.2f} grad:{:.2f} time: {:.1f} turn:{} '.format(iter_num+1, \
                                float(total_loss),float(losses['bspn']),float(losses['aspn']),float(losses['resp']), float(losses["trade_ptr"]), \
                                float(losses["trade_gating"]),grad, time.time()-btm, turn_num+1))
                    if cfg.enable_dst and cfg.bspn_mode == 'bsdx':
                        logging.info('bspn-dst:{:.3f}'.format(
                            float(losses['bspn'])))
                    if cfg.multi_acts_training:
                        logging.info('aspn-aug:{:.3f}'.format(
                            float(losses['aspn_aug'])))

                # btm = time.time()
                # if (iter_num+1)%40==0:
                #     print('validation checking ... ')
                #     valid_sup_loss, valid_unsup_loss = self.validate(do_test=False)
                #     logging.info('validation loss in epoch %d sup:%f unsup:%f' % (epoch, valid_sup_loss, valid_unsup_loss))

            epoch_sup_loss = sup_loss / (sup_cnt + 1e-8)
            # do_test = True if (epoch+1)%5==0 else False
            do_test = False
            valid_loss = self.validate(do_test=do_test)
            logging.info(
                'epoch: %d, train loss: %.3f, valid loss: %.3f, total time: %.1fmin'
                % (epoch + 1, epoch_sup_loss, valid_loss,
                   (time.time() - sw) / 60))
            # self.save_model(epoch)
            if valid_loss <= prev_min_loss:
                # Improvement: save and reset both countdowns.
                early_stop_count = cfg.early_stop_count
                weight_decay_count = cfg.weight_decay_count
                prev_min_loss = valid_loss
                self.save_model(epoch)
            else:
                early_stop_count -= 1
                weight_decay_count -= 1
                logging.info('epoch: %d early stop countdown %d' %
                             (epoch + 1, early_stop_count))
                if not early_stop_count:
                    # Early stop: reload the best model and run a final eval,
                    # mirroring logs to a per-seed file.
                    self.load_model()
                    print('result preview...')
                    file_handler = logging.FileHandler(
                        os.path.join(cfg.exp_path,
                                     'eval_log%s.json' % cfg.seed))
                    logging.getLogger('').addHandler(file_handler)
                    logging.info(str(cfg))
                    self.eval()
                    return
                if not weight_decay_count:
                    # Decay the learning rate by rebuilding the optimizer.
                    lr *= cfg.lr_decay
                    self.optim = Adam(lr=lr,
                                      params=filter(lambda x: x.requires_grad,
                                                    self.m.parameters()),
                                      weight_decay=5e-5)
                    weight_decay_count = cfg.weight_decay_count
                    logging.info('learning rate decay, learning rate: %f' %
                                 (lr))
        # All epochs done: evaluate the best saved checkpoint.
        self.load_model()
        print('result preview...')
        file_handler = logging.FileHandler(
            os.path.join(cfg.exp_path, 'eval_log%s.json' % cfg.seed))
        logging.getLogger('').addHandler(file_handler)
        logging.info(str(cfg))
        self.eval()

    def validate(self, data='dev', do_test=False):
        """Validate on *data* and return a scalar to minimize.

        Depending on cfg.valid_loss this is either an averaged training
        loss (teacher-forced forward passes) or ``200 - score`` computed
        from decoded outputs via the evaluator's BLEU/match/success.
        """
        self.m.eval()
        valid_loss, count = 0, 0
        data_iterator = self.reader.get_batches(data)
        result_collection = {}
        for batch_num, dial_batch in enumerate(data_iterator):
            hidden_states = {}
            py_prev = {
                'pv_resp': None,
                'pv_bspn': None,
                'pv_aspn': None,
                'pv_dspn': None,
                'pv_bsdx': None
            }
            for turn_num, turn_batch in enumerate(dial_batch):
                first_turn = (turn_num == 0)
                inputs = self.reader.convert_batch(turn_batch,
                                                   py_prev,
                                                   first_turn=first_turn)
                inputs = self.add_torch_input(inputs, first_turn=first_turn)
                # total_loss, losses, hidden_states = self.m(inputs, hidden_states, first_turn, mode='train')
                if cfg.valid_loss not in ['score', 'match', 'success', 'bleu']:
                    # Loss-based validation: teacher-forced forward pass.
                    total_loss, losses = self.m(inputs,
                                                hidden_states,
                                                first_turn,
                                                mode='train')
                    py_prev['pv_resp'] = turn_batch['resp']
                    if cfg.enable_bspn:
                        py_prev['pv_bspn'] = turn_batch['bspn']
                        py_prev['pv_bsdx'] = turn_batch['bsdx']
                    if cfg.enable_aspn:
                        py_prev['pv_aspn'] = turn_batch['aspn']
                    if cfg.enable_dspn:
                        py_prev['pv_dspn'] = turn_batch['dspn']

                    if cfg.valid_loss == 'total_loss':
                        valid_loss += float(total_loss)
                    elif cfg.valid_loss == 'bspn_loss':
                        valid_loss += float(losses[cfg.bspn_mode])
                    elif cfg.valid_loss == 'aspn_loss':
                        valid_loss += float(losses['aspn'])
                    elif cfg.valid_loss == 'resp_loss':
                        # FIX: was losses['reps'] (typo) which raised KeyError.
                        valid_loss += float(losses['resp'])
                    else:
                        raise ValueError('Invalid validation loss type!')
                else:
                    # Metric-based validation: decode and collect results.
                    decoded = self.m(inputs,
                                     hidden_states,
                                     first_turn,
                                     mode='test')
                    turn_batch['resp_gen'] = decoded['resp']
                    if cfg.bspn_mode == 'bspn' or cfg.enable_dst:
                        turn_batch['bspn_gen'] = decoded['bspn']
                    py_prev['pv_resp'] = turn_batch[
                        'resp'] if cfg.use_true_pv_resp else decoded['resp']
                    if cfg.enable_bspn:
                        py_prev['pv_' + cfg.bspn_mode] = turn_batch[
                            cfg.
                            bspn_mode] if cfg.use_true_prev_bspn else decoded[
                                cfg.bspn_mode]
                        py_prev['pv_bspn'] = turn_batch[
                            'bspn'] if cfg.use_true_prev_bspn or 'bspn' not in decoded else decoded[
                                'bspn']
                    if cfg.enable_aspn:
                        py_prev['pv_aspn'] = turn_batch[
                            'aspn'] if cfg.use_true_prev_aspn else decoded[
                                'aspn']
                    if cfg.enable_dspn:
                        py_prev['pv_dspn'] = turn_batch[
                            'dspn'] if cfg.use_true_prev_dspn else decoded[
                                'dspn']

                    # TRADE
                    if cfg.enable_trade:
                        turn_batch["trade_ptr"] = decoded["trade_ptr"]
                        turn_batch["trade_gate"] = decoded["trade_gate"]

                count += 1
                torch.cuda.empty_cache()

            if cfg.valid_loss in ['score', 'match', 'success', 'bleu']:
                result_collection.update(
                    self.reader.inverse_transpose_batch(dial_batch))

        if cfg.valid_loss not in ['score', 'match', 'success', 'bleu']:
            valid_loss /= (count + 1e-8)
        else:
            results, _ = self.reader.wrap_result(
                result_collection)  # decode to sentence
            bleu, success, match = self.evaluator.validation_metric(results)
            score = 0.5 * (success + match) + bleu
            valid_loss = 200 - score
            logging.info(
                'validation [CTR] match: %2.1f  success: %2.1f  bleu: %2.1f' %
                (match, success, bleu))
        self.m.train()
        if do_test:
            print('result preview...')
            self.eval()
        return valid_loss

    def eval(self, data='test'):
        """Decode the *data* split, compute metrics and write result files.

        Always returns None; all outputs go through the reader's
        save_result helpers and the evaluator.
        """
        self.m.eval()
        self.reader.result_file = None
        result_collection = {}
        data_iterator = self.reader.get_batches(data)
        for batch_num, dial_batch in tqdm.tqdm(enumerate(data_iterator)):
            # quit()
            # if batch_num > 0:
            #     continue
            hidden_states = {}
            py_prev = {
                'pv_resp': None,
                'pv_bspn': None,
                'pv_aspn': None,
                'pv_dspn': None,
                'pv_bsdx': None
            }
            print('batch_size:', len(dial_batch[0]['resp']))
            for turn_num, turn_batch in enumerate(dial_batch):
                # print('turn %d'%turn_num)
                # if turn_num!=0 and turn_num<4:
                #     continue
                first_turn = (turn_num == 0)
                inputs = self.reader.convert_batch(turn_batch,
                                                   py_prev,
                                                   first_turn=first_turn)
                inputs = self.add_torch_input(inputs, first_turn=first_turn)
                decoded = self.m(inputs,
                                 hidden_states,
                                 first_turn,
                                 mode='test')
                turn_batch['resp_gen'] = decoded['resp']
                # Disabled span types get [[0]] placeholders per example.
                if cfg.bspn_mode == 'bsdx':
                    turn_batch['bsdx_gen'] = decoded[
                        'bsdx'] if cfg.enable_bspn else [[0]] * len(
                            decoded['resp'])
                if cfg.bspn_mode == 'bspn' or cfg.enable_dst:
                    turn_batch['bspn_gen'] = decoded[
                        'bspn'] if cfg.enable_bspn else [[0]] * len(
                            decoded['resp'])
                turn_batch['aspn_gen'] = decoded[
                    'aspn'] if cfg.enable_aspn else [[0]] * len(
                        decoded['resp'])
                turn_batch['dspn_gen'] = decoded[
                    'dspn'] if cfg.enable_dspn else [[0]] * len(
                        decoded['resp'])

                if self.reader.multi_acts_record is not None:
                    turn_batch['multi_act_gen'] = self.reader.multi_acts_record
                if cfg.record_mode:
                    turn_batch['multi_act'] = self.reader.aspn_collect
                    turn_batch['multi_resp'] = self.reader.resp_collect
                # print(turn_batch['user'])
                # print('user:'******'user'][0] , eos='<eos_u>', indicate_oov=True))
                # print('resp:', self.reader.vocab.sentence_decode(decoded['resp'][0] , eos='<eos_r>', indicate_oov=True))
                # print('bspn:', self.reader.vocab.sentence_decode(decoded['bspn'][0] , eos='<eos_b>', indicate_oov=True))
                # for b in range(len(decoded['resp'])):
                #     for i in range(5):
                #         print('aspn:', self.reader.vocab.sentence_decode(decoded['aspn'][i][b] , eos='<eos_a>', indicate_oov=True))

                # Feed generated (or true, per cfg) spans into the next turn.
                py_prev['pv_resp'] = turn_batch[
                    'resp'] if cfg.use_true_pv_resp else decoded['resp']
                if cfg.enable_bspn:
                    py_prev['pv_' + cfg.bspn_mode] = turn_batch[
                        cfg.bspn_mode] if cfg.use_true_prev_bspn else decoded[
                            cfg.bspn_mode]
                    py_prev['pv_bspn'] = turn_batch[
                        'bspn'] if cfg.use_true_prev_bspn or 'bspn' not in decoded else decoded[
                            'bspn']
                if cfg.enable_aspn:
                    py_prev['pv_aspn'] = turn_batch[
                        'aspn'] if cfg.use_true_prev_aspn else decoded['aspn']
                if cfg.enable_dspn:
                    py_prev['pv_dspn'] = turn_batch[
                        'dspn'] if cfg.use_true_prev_dspn else decoded['dspn']

                # TRADE
                if cfg.enable_trade:
                    turn_batch["trade_ptr"] = decoded["trade_ptr"]
                    turn_batch["trade_gate"] = decoded["trade_gate"]

                torch.cuda.empty_cache()
                # prev_z = turn_batch['bspan']
            # print('test iter %d'%(batch_num+1))
            result_collection.update(
                self.reader.inverse_transpose_batch(dial_batch))

        # self.reader.result_file.close()
        if cfg.record_mode:
            self.reader.record_utterance(result_collection)
            quit()
        results, field = self.reader.wrap_result(result_collection)
        self.reader.save_result('w', results, field)

        metric_results = self.evaluator.run_metrics(results)
        metric_field = list(metric_results[0].keys())
        req_slots_acc = metric_results[0]['req_slots_acc']
        info_slots_acc = metric_results[0]['info_slots_acc']

        self.reader.save_result('w',
                                metric_results,
                                metric_field,
                                write_title='EVALUATION RESULTS:')
        self.reader.save_result('a', [info_slots_acc],
                                list(info_slots_acc.keys()),
                                write_title='INFORM ACCURACY OF EACH SLOTS:')
        self.reader.save_result('a', [req_slots_acc],
                                list(req_slots_acc.keys()),
                                write_title='REQUEST SUCCESS RESULTS:')
        self.reader.save_result('a',
                                results,
                                field +
                                ['wrong_domain', 'wrong_act', 'wrong_inform'],
                                write_title='DECODED RESULTS:')
        self.reader.save_result_report(metric_results)
        # self.reader.metric_record(metric_results)
        self.m.train()
        return None

    def save_model(self, epoch, path=None, critical=False):
        """Save model weights, config and epoch; no-op unless cfg.save_log."""
        if not cfg.save_log:
            return
        if not path:
            path = cfg.model_path
        if critical:
            path += '.final'
        all_state = {
            'lstd': self.m.state_dict(),
            'config': cfg.__dict__,
            'epoch': epoch
        }
        torch.save(all_state, path)
        logging.info('Model saved')

    def load_model(self, path=None):
        """Load a checkpoint (cfg.model_path by default) onto CPU first,
        restore the weights and remember the stored epoch in base_epoch."""
        if not path:
            path = cfg.model_path
        all_state = torch.load(path, map_location='cpu')
        self.m.load_state_dict(all_state['lstd'])
        self.base_epoch = all_state.get('epoch', 0)
        logging.info('Model loaded')

    def training_adjust(self, epoch):
        """Hook for per-epoch schedule changes; intentionally a no-op."""
        return

    def freeze_module(self, module):
        """Disable gradients for every parameter of *module*."""
        for param in module.parameters():
            param.requires_grad = False

    def unfreeze_module(self, module):
        """Re-enable gradients for every parameter of *module*."""
        for param in module.parameters():
            param.requires_grad = True

    def load_glove_embedding(self, freeze=False):
        """Initialize the embedding layer from GloVe vectors.

        Handles both plain and DataParallel-wrapped models; *freeze* is
        accepted for interface compatibility but not acted upon here.
        """
        if not cfg.multi_gpu:
            initial_arr = self.m.embedding.weight.data.cpu().numpy()
            emb = torch.from_numpy(
                utils.get_glove_matrix(cfg.glove_path, self.reader.vocab,
                                       initial_arr))
            self.m.embedding.weight.data.copy_(emb)
        else:
            initial_arr = self.m.module.embedding.weight.data.cpu().numpy()
            emb = torch.from_numpy(
                utils.get_glove_matrix(cfg.glove_path, self.reader.vocab,
                                       initial_arr))
            self.m.module.embedding.weight.data.copy_(emb)

    def count_params(self):
        """Print and return the number of trainable parameters."""
        module_parameters = filter(lambda p: p.requires_grad,
                                   self.m.parameters())
        param_cnt = int(sum([np.prod(p.size()) for p in module_parameters]))

        print('total trainable params: %d' % param_cnt)
        return param_cnt
Ejemplo n.º 5
0
def dialog_turn_state_analysis(mode='train'):
    """Cluster MultiWOZ turns by a coarse "turn state" and collect, per state,
    the distinct system-act alternatives observed in the data.

    The mapping dialog -> turn -> act-type -> [alternative act spans] is
    written to 'data/multi-woz-processed/multi_act_mapping_<mode>.json'.

    Args:
        mode: 'train' or 'test'; selects which dialogs (relative to
            reader.test_files) contribute to the state/act statistics.
    """
    # Preprocessed DAMD data plus the raw ConvLab annotation (stored zipped);
    # both are lower-cased on load.
    data_path = 'data/multi-woz-processed/data_for_damd.json'
    conv_data = 'data/multi-woz/annotated_user_da_with_span_full.json'
    archive = zipfile.ZipFile(conv_data + '.zip', 'r')
    convlab_data = json.loads(archive.open(conv_data.split('/')[-1], 'r').read().lower())
    reader = MultiWozReader()
    data = json.loads(open(data_path, 'r', encoding='utf-8').read().lower())

    # turn_state_record: state string -> {act-state string -> occurrence info}
    # turn_state_count:  state string -> frequency
    # golden_acts:       state string -> first (reference) act seen for it
    turn_state_record, turn_state_count, golden_acts = {}, {}, {}
    act_state_collect = []
    act_state_detail = {}
    state_valid_acts = {}
    dial_count = 0
    turn_count = 0

    # Pass 1: derive a turn-state string and an act-state string for every
    # eligible turn, and record which act types co-occur with each state.
    for fn, dial in data.items():
        dial_count += 1
        state_valid_acts[fn] = {}
        for turn_no, turn in enumerate(dial['log']):
            turn_state = {}
            turn_domain = turn['turn_domain'].split()
            cons_delex = turn['cons_delex'].split()
            sys_act = turn['sys_act']
            # ConvLab logs interleave user/system turns, hence index turn_no*2
            # for the user side of this turn.
            usr_act = convlab_data[fn]['log'][turn_no * 2]['dialog_act']
            db_ptr = [int(i) for i in turn['pointer'].split(',')]
            match = turn['match']
            # Only single-domain, non-general turns that have a system act
            # are analyzed.
            if len(turn_domain) != 1 or turn_domain[0] == '[general]' or not sys_act:
                continue
            state_valid_acts[fn][turn_no] = {}
            turn_count += 1

            # Slots mentioned in the delexicalized constraint span of the
            # current domain: tokens following the domain tag up to the next
            # bracketed tag.
            slot_mentioned = []
            for idx, tk in enumerate(cons_delex[:-1]):
                if tk in turn_domain:
                    i = idx+1
                    while i < len(cons_delex):
                        if '[' not in cons_delex[i]:
                            slot_mentioned.append(cons_delex[i])
                        else:
                            break
                        i = i+1
            slot_mentioned.sort()
            # turn_state['slot_mentioned'] = len(slot_mentioned)
            turn_state['domain'] = turn_domain
            turn_state['slot_mentioned'] = slot_mentioned
            # Bucket the DB match count into '', '0', '1', '2-3', '>3'.
            if match == '':
                turn_state['match']=''
            elif match == '0':
                turn_state['match']='0'
            elif match == '1':
                turn_state['match'] = '1'
            elif match == '2' or match == '3':
                turn_state['match'] = '2-3'
            else:
                turn_state['match']='>3'
            # The last two pointer bits encode booking status:
            # [0,0] -> none, [1,0] -> 'no', otherwise -> 'yes'.
            if db_ptr[-2:] == [0,0]:
                turn_state['book'] = ''
            elif db_ptr[-2:] == [1,0]:
                turn_state['book'] = 'no'
            else:
                turn_state['book'] = 'yes'

            # Summarize the user act types; 'request' acts carry their sorted
            # requested-slot list, e.g. "request(area,phone)".
            turn_state['usract'] = []
            for act in usr_act:
                d, a = act.split('-')
                if a not in turn_state['usract']:
                    slot_list = []
                    if a == 'request':
                        for slot_value in usr_act[act]:

                            slot = slot_value[0]

                            if slot == 'none':
                                continue
                            elif slot not in slot_list:
                                slot = ontology.da_abbr_to_slot_name.get(slot, slot)
                                slot_list.append(slot)
                    if not slot_list:
                        turn_state['usract'].append(a)
                    else:
                        slot_list.sort()
                        turn_state['usract'].append(a+'('+','.join(slot_list)+')')
            turn_state['usract'].sort()

            # Flatten turn_state into a canonical "key(value);key(value)" string.
            turn_state_str = ''
            for k,v in turn_state.items():
                if isinstance(v, list):
                    v_ = ','.join(v)
                elif isinstance(v, int):
                    v_ = str(v)
                else:
                    v_ = v
                turn_state_str += '%s(%s);'%(k, v_)
            turn_state_str = turn_state_str[:-1]
            state_valid_acts[fn][turn_no]['usdx'] = turn['user_delex']
            state_valid_acts[fn][turn_no]['state'] = turn_state_str


            if sys_act not in act_state_detail:
                act_state_detail[sys_act] = 1
            # Abstract the system act span into domain vs. general act names.
            act_list = reader.aspan_to_act_list(sys_act)
            act_state = {'domain': {}, 'general': {}}
            for act in act_list:
                d, a, s = act.split('-')
                if d == 'general':
                    act_state['general'][a] = ''
                else:
                    # NOTE(review): all three branches below assign '' — the
                    # s != 'none' distinction currently has no effect.
                    if a not in act_state['domain']:
                        if s != 'none':
                            act_state['domain'][a] = ''
                        else:
                            act_state['domain'][a] = ''
                    else:
                        act_state['domain'][a] = ''

            # Order-insensitive representation of the act list, used for
            # de-duplicating alternatives.
            no_order_act = {}
            for a in act_list:
                no_order_act[a] = 1

            # Flatten act_state the same way as turn_state.
            act_state_str = ''
            for k,v in act_state.items():
                if isinstance(v, dict):
                    v_ = ''
                    for kk, vv in v.items():
                        v_ += kk+'(%s),'%str(vv)
                    if v_.endswith(','):
                        v_ = v_[:-1]
                elif isinstance(v, int):
                    v_ = str(v)
                else:
                    v_ = v
                if v_ != '':
                    act_state_str += '%s(%s);'%(k, v_)
            act_state_str = act_state_str[:-1]
            state_valid_acts[fn][turn_no]['gold'] = {}
            state_valid_acts[fn][turn_no]['gold'][act_state_str] = {}
            state_valid_acts[fn][turn_no]['gold'][act_state_str]['resp'] = turn['resp']
            state_valid_acts[fn][turn_no]['gold'][act_state_str]['act'] = sys_act

            # Statistics below only use the split selected by `mode`.
            if mode == 'test' and fn not in reader.test_files:
                continue
            if mode == 'train' and fn in reader.test_files:
                continue
            if act_state not in act_state_collect:
                act_state_collect.append(act_state)
            new_state = True if turn_state_str not in turn_state_record else False
            # raw_sys_rec format is '<fn>-<turn_no>:<act span>'; split(':')
            # later separates the id from the span.
            raw_sys_rec  = fn+'-'+str(turn_no)+':'+sys_act
            if new_state:
                # First time this state is seen: its act becomes the golden one.
                turn_state_record[turn_state_str] = {act_state_str: {'num': 1, 'raw_acts': [raw_sys_rec], 'no_order_act': [no_order_act],
                                                                         'user': [turn['user']], 'resp': [turn['resp']]}}
                golden_acts[turn_state_str] = {'act_span': raw_sys_rec, 'no_order_act': no_order_act}
                turn_state_count[turn_state_str] = 1
            else:
                turn_state_count[turn_state_str] += 1
                if act_state_str in turn_state_record[turn_state_str]:
                    # Skip acts equal (ignoring order) to the golden act or to
                    # an alternative already recorded for this act type.
                    if no_order_act == golden_acts[turn_state_str]['no_order_act']:
                        continue
                    if no_order_act in turn_state_record[turn_state_str][act_state_str]['no_order_act']:
                        continue
                    turn_state_record[turn_state_str][act_state_str]['num'] +=1
                    turn_state_record[turn_state_str][act_state_str]['raw_acts'].append(raw_sys_rec)
                    turn_state_record[turn_state_str][act_state_str]['user'].append(turn['user'])
                    turn_state_record[turn_state_str][act_state_str]['resp'].append(turn['resp'])
                    turn_state_record[turn_state_str][act_state_str]['no_order_act'].append(no_order_act)
                else:
                    turn_state_record[turn_state_str][act_state_str] = {'num': 1, 'raw_acts': [raw_sys_rec], 'no_order_act': [no_order_act],
                                                                                                    'user': [turn['user']], 'resp': [turn['resp']]}
    # Sort each state's act types by descending occurrence count.
    for state, acts in turn_state_record.items():
        turn_state_record[state] = OrderedDict(sorted(acts.items(), key=lambda i:i[1]['num'], reverse=True))

    # print(mode)
    print('dialog count:', dial_count, 'turn count: ',turn_count)
    print('state count:', len(turn_state_record))
    print('raw act span count:', len(act_state_detail))
    print('act state count:', len(act_state_collect))


    # Pass 2: for every training-split turn, attach the non-gold act
    # alternatives observed for its state.
    # NOTE(review): dial_count keeps incrementing here but is not printed
    # again in this span.
    for fn, dial in data.items():
        if fn in reader.dev_files or fn in reader.test_files:
            continue
        dial_count += 1
        for turn_no, turn in enumerate(dial['log']):
            if turn_no not in state_valid_acts[fn]:
                continue
            state = state_valid_acts[fn][turn_no]['state']
            gold_act_type = list(state_valid_acts[fn][turn_no]['gold'].keys())[0]
            state_valid_acts[fn][turn_no]['other'] = {}
            if state in turn_state_record:
                for act_type in turn_state_record[state]:
                    if act_type == gold_act_type:
                        continue
                    state_valid_acts[fn][turn_no]['other'][act_type] = []
                    for idx, a in enumerate(turn_state_record[state][act_type]['raw_acts']):
                        m = {'act': a}
                        m['resp'] = turn_state_record[state][act_type]['resp'][idx]
                        state_valid_acts[fn][turn_no]['other'][act_type].append(m)

    # sub_state_valid_acts = {}
    # count = 0
    # for fn, dial in state_valid_acts.items():
    #     if 'mul' in fn and fn not in reader.test_files and count<=100:
    #         sub_state_valid_acts[fn] = dial
    #         count += 1
    #     if count >100:
    #         break
    # with open('data/multi-woz-processed/example_multi_act_dialogs.json', 'w') as f:
    #     json.dump(sub_state_valid_acts, f, indent=2)

    # Pass 3: flatten the alternatives to plain act spans for saving.
    # 'raw_acts' entries are '<fn>-<turn_no>:<act span>', so split(':') yields
    # the source id (index 0) and the span (index 1).
    idx_save = {}
    act_span_save = {}
    hist = []
    for fn, dial in state_valid_acts.items():
        if fn in reader.dev_files or fn in reader.test_files:
            continue
        act_span_save[fn] = {}
        idx_save[fn] = {}
        for turn_num, turn in dial.items():
            act_span_save[fn][turn_num] = {}
            idx_save[fn][turn_num] = []
            for act_type, acts in turn['other'].items():
                hist.append(len(acts)+1)
                act_span_save[fn][turn_num][act_type] = [a['act'].split(':')[1] for a in acts]
                idx_save[fn][turn_num].append([a['act'].split(':')[0] for a in acts])


    with open('data/multi-woz-processed/multi_act_mapping_%s.json'%mode, 'w') as f:
        json.dump(act_span_save, f, indent=2)