Пример #1
0
 def __init__(self, opt, state_dict=None, num_train_step=-1):
     """Build the network, optimizer and loss tables from ``opt``.

     Args:
         opt: configuration mapping; kept as ``self.config``.
         state_dict: optional checkpoint mapping. Its ``'state'`` entry is
             loaded into the network (non-strict) and its ``'updates'``
             entry, when present, seeds the update counter.
         num_train_step: forwarded unchanged to ``_setup_optim``.
     """
     self.config = opt
     if state_dict and 'updates' in state_dict:
         self.updates = state_dict['updates']
     else:
         self.updates = 0
     self.local_updates = 0
     self.train_loss = AverageMeter()

     # The network is told whether a checkpoint was supplied.
     self.initial_from_local = bool(state_dict)
     self.network = SANBertNetwork(
         opt, initial_from_local=self.initial_from_local)
     if state_dict:
         _missing, _unexpected = self.network.load_state_dict(
             state_dict['state'], strict=False)

     if opt['multi_gpu_on']:
         self.mnetwork = nn.DataParallel(self.network)
     else:
         self.mnetwork = self.network

     # Count trainable parameters only.
     self.total_param = sum(
         param.nelement() for param in self.network.parameters()
         if param.requires_grad)
     if opt['cuda']:
         self.network.cuda()

     param_groups = self._get_param_groups()
     self._setup_optim(param_groups, state_dict, num_train_step)
     self.para_swapped = False
     self.optimizer.zero_grad()

     self._setup_lossmap(self.config)
     self._setup_kd_lossmap(self.config)
Пример #2
0
    def __init__(self, opt, device=None, state_dict=None, num_train_step=-1):
        """Build the network, optimizer, parallel wrapper and loss tables.

        Args:
            opt: configuration mapping; kept as ``self.config``.
            device: device the model is moved to when ``opt['cuda']`` is set.
            state_dict: optional checkpoint mapping. Its ``'state'`` entry is
                loaded into the network (non-strict) and its ``'updates'``
                entry, when present, seeds the update counter.
            num_train_step: forwarded unchanged to ``_setup_optim``.
        """
        self.config = opt
        if state_dict and 'updates' in state_dict:
            self.updates = state_dict['updates']
        else:
            self.updates = 0
        self.local_updates = 0
        self.device = device
        self.train_loss = AverageMeter()
        self.initial_from_local = bool(state_dict)

        model = SANBertNetwork(opt, initial_from_local=self.initial_from_local)
        # Count trainable parameters only.
        self.total_param = sum(
            p.nelement() for p in model.parameters() if p.requires_grad)
        if opt['cuda']:
            # Distributed and single-process runs both use the caller-supplied
            # device, so no branch on local_rank is needed here.
            model = model.to(self.device)
        self.network = model
        if state_dict:
            self.network.load_state_dict(state_dict['state'], strict=False)

        optimizer_parameters = self._get_param_groups()
        self._setup_optim(optimizer_parameters, state_dict, num_train_step)
        self.optimizer.zero_grad()

        # NOTE: a torch.distributed.barrier() for ranks other than -1/0 was
        # previously disabled at this point.

        local_rank = self.config['local_rank']
        if local_rank != -1:
            # One process per GPU: wrap for distributed data parallelism.
            self.mnetwork = torch.nn.parallel.DistributedDataParallel(
                self.network,
                device_ids=[local_rank],
                output_device=local_rank,
                find_unused_parameters=True)
        elif self.config['multi_gpu_on']:
            self.mnetwork = nn.DataParallel(self.network)
        else:
            self.mnetwork = self.network

        self._setup_lossmap(self.config)
        self._setup_kd_lossmap(self.config)
        self._setup_adv_lossmap(self.config)
        self._setup_adv_training(self.config)
class MTDNNModel(object):
    def __init__(self, opt, state_dict=None, num_train_step=-1):
        """Build the network, optimizer and per-task loss criteria.

        Args:
            opt: configuration mapping; kept as ``self.config``.
            state_dict: optional checkpoint mapping. Its ``'state'`` entry is
                loaded into the network (non-strict) and its ``'updates'``
                entry, when present, seeds the update counter.
            num_train_step: forwarded unchanged to ``_setup_optim``.
        """
        self.config = opt
        if state_dict and 'updates' in state_dict:
            self.updates = state_dict['updates']
        else:
            self.updates = 0
        self.local_updates = 0
        self.train_loss = AverageMeter()

        self.network = SANBertNetwork(opt)
        if state_dict:
            self.network.load_state_dict(state_dict['state'], strict=False)
        if opt['multi_gpu_on']:
            self.mnetwork = nn.DataParallel(self.network)
        else:
            self.mnetwork = self.network

        # Count trainable parameters only.
        self.total_param = sum(
            param.nelement() for param in self.network.parameters()
            if param.requires_grad)
        if opt['cuda']:
            self.network.cuda()

        param_groups = self._get_param_groups()
        self._setup_optim(param_groups, state_dict, num_train_step)
        self.para_swapped = False
        self.optimizer.zero_grad()
        self._setup_lossmap(self.config)
        # No task sampler is attached at construction time.
        self.sampler = None

    def _get_param_groups(self):
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_parameters = [
            {'params': [p for n, p in self.network.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in self.network.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        return optimizer_parameters

    def _setup_optim(self, optimizer_parameters, state_dict=None, num_train_step=-1):
        if self.config['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(optimizer_parameters, self.config['learning_rate'],
                                       weight_decay=self.config['weight_decay'])

        elif self.config['optimizer'] == 'adamax':
            self.optimizer = Adamax(optimizer_parameters,
                                    self.config['learning_rate'],
                                    warmup=self.config['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=self.config['grad_clipping'],
                                    schedule=self.config['warmup_schedule'],
                                    weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
        elif self.config['optimizer'] == 'radam':
            self.optimizer = RAdam(optimizer_parameters,
                                    self.config['learning_rate'],
                                    warmup=self.config['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=self.config['grad_clipping'],
                                    schedule=self.config['warmup_schedule'],
                                    eps=self.config['adam_eps'],
                                    weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
            # The current radam does not support FP16.
            self.config['fp16'] = False
        elif self.config['optimizer'] == 'adam':
            self.optimizer = Adam(optimizer_parameters,
                                  lr=self.config['learning_rate'],
                                  warmup=self.config['warmup'],
                                  t_total=num_train_step,
                                  max_grad_norm=self.config['grad_clipping'],
                                  schedule=self.config['warmup_schedule'],
                                  weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if self.config['fp16']:
            try:
                from apex import amp
                global amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(self.network, self.optimizer, opt_level=self.config['fp16_opt_level'])
            self.network = model
            self.optimizer = optimizer

        if self.config.get('have_lr_scheduler', False):
            if self.config.get('scheduler_type', 'rop') == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=self.config['lr_gamma'], patience=3)
            elif self.config.get('scheduler_type', 'rop') == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=self.config.get('lr_gamma', 0.95))
            else:
                milestones = [int(step) for step in self.config.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones, gamma=self.config.get('lr_gamma'))
        else:
            self.scheduler = None

    def _setup_lossmap(self, config):
        loss_types = config['loss_types']
        self.task_loss_criterion = []
        for idx, cs in enumerate(loss_types):
            assert cs is not None
            lc = LOSS_REGISTRY[cs](name='Loss func of task {}: {}'.format(idx, cs))
            self.task_loss_criterion.append(lc)

    def _setup_kd_lossmap(self, config):
        loss_types = config['kd_loss_types']
        self.kd_task_loss_criterion = []
        if config.get('mkd_opt', 0) > 0:
            for idx, cs in enumerate(loss_types):
                assert cs is not None
                lc = LOSS_REGISTRY[cs](name='Loss func of task {}: {}'.format(idx, cs))
                self.kd_task_loss_criterion.append(lc)

    def train(self):
        if self.para_swapped:
            self.para_swapped = False

    def _to_cuda(self, tensor):
        if tensor is None: return tensor

        if isinstance(tensor, list) or isinstance(tensor, tuple):
            y = [e.cuda(non_blocking=True) for e in tensor]
            for e in y:
                e.requires_grad = False
        else:
            y = tensor.cuda(non_blocking=True)
            y.requires_grad = False
        return y

    def update(self, batch_meta, batch_data):
        self.network.train()
        y = batch_data[batch_meta['label']]
        soft_labels = None

        task_type = batch_meta['task_type']
        y = self._to_cuda(y) if self.config['cuda'] else y

        task_id = batch_meta['task_id']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        weight = None
        if self.config.get('weighted_on', False):
            if self.config['cuda']:
                weight = batch_data[batch_meta['factor']].cuda(non_blocking=True)
            else:
                weight = batch_data[batch_meta['factor']]
        logits = self.mnetwork(*inputs)

        # compute loss
        loss = 0
        if self.task_loss_criterion[task_id] and (y is not None):
            loss = self.task_loss_criterion[task_id](logits, y, weight, ignore_index=-1)

        # compute kd loss
        if self.config.get('mkd_opt', 0) > 0 and ('soft_label' in batch_meta):
            soft_labels = batch_meta['soft_label']
            soft_labels = self._to_cuda(soft_labels) if self.config['cuda'] else soft_labels
            kd_lc = self.kd_task_loss_criterion[task_id]
            kd_loss = kd_lc(logits, soft_labels, weight, ignore_index=-1) if kd_lc else 0
            loss = loss + kd_loss

        self.train_loss.update(loss.item(), batch_data[batch_meta['token_id']].size(0))
        # scale loss
        loss = loss / self.config.get('grad_accumulation_step', 1)
        if self.sampler:
            self.sampler.update_loss(task_id, loss.item())
        if self.config['fp16']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.local_updates += 1
        if self.local_updates % self.config.get('grad_accumulation_step', 1) == 0:
            if self.config['global_grad_clipping'] > 0:
                if self.config['fp16']:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                                                   self.config['global_grad_clipping'])
                else:
                    torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                                  self.config['global_grad_clipping'])
            self.updates += 1
            # reset number of the grad accumulation
            self.optimizer.step()
            self.optimizer.zero_grad()

    def update_all(self, batch_metas, batch_datas):
        self.network.train()
        task_losses = []
        max_loss = 0
        max_task = 0
        with torch.no_grad():
            for i, (batch_meta, batch_data) in enumerate(zip(batch_metas, batch_datas)):
                y = batch_data[batch_meta['label']]
                soft_labels = None

                task_type = batch_meta['task_type']
                y = self._to_cuda(y) if self.config['cuda'] else y

                task_id = batch_meta['task_id']
                assert task_id == i
                inputs = batch_data[:batch_meta['input_len']]
                if len(inputs) == 3:
                    inputs.append(None)
                    inputs.append(None)
                inputs.append(task_id)
                weight = None
                if self.config.get('weighted_on', False):
                    if self.config['cuda']:
                        weight = batch_data[batch_meta['factor']].cuda(non_blocking=True)
                    else:
                        weight = batch_data[batch_meta['factor']]
                logits = self.mnetwork(*inputs)

                # compute loss
                loss = 0
                if self.task_loss_criterion[task_id] and (y is not None):
                    loss = self.task_loss_criterion[task_id](logits, y, weight, ignore_index=-1)

                # compute kd loss
                if self.config.get('mkd_opt', 0) > 0 and ('soft_label' in batch_meta):
                    soft_labels = batch_meta['soft_label']
                    soft_labels = self._to_cuda(soft_labels) if self.config['cuda'] else soft_labels
                    kd_lc = self.kd_task_loss_criterion[task_id]
                    kd_loss = kd_lc(logits, soft_labels, weight, ignore_index=-1) if kd_lc else 0
                    loss = loss + kd_loss
                
                if loss > max_loss:
                    max_loss = loss
                    max_task = i
                task_losses.append(loss.item())
        # print(task_losses)
        p = np.array(task_losses)
        p /= p.sum()
        sampled_id = 0
        if random.random() < 0.3:
            sampled_id = max_task
        else:
            sampled_id = np.random.choice(list(range(len(task_losses))), p=p, replace=False)
        # print("task ", sampled_id, " is selected")
        self.loss_list = task_losses
        self.update(batch_metas[sampled_id], batch_datas[sampled_id])

    def calculate_loss(self, batch_meta, batch_data):
        self.network.train()
        with torch.no_grad():
            y = batch_data[batch_meta['label']]
            soft_labels = None

            task_type = batch_meta['task_type']
            y = self._to_cuda(y) if self.config['cuda'] else y

            task_id = batch_meta['task_id']
            
            inputs = batch_data[:batch_meta['input_len']]
            if len(inputs) == 3:
                inputs.append(None)
                inputs.append(None)
            inputs.append(task_id)
            weight = None
            if self.config.get('weighted_on', False):
                if self.config['cuda']:
                    weight = batch_data[batch_meta['factor']].cuda(non_blocking=True)
                else:
                    weight = batch_data[batch_meta['factor']]
            logits = self.mnetwork(*inputs)

            # compute loss
            loss = 0
            if self.task_loss_criterion[task_id] and (y is not None):
                loss = self.task_loss_criterion[task_id](logits, y, weight, ignore_index=-1)

            # compute kd loss
            if self.config.get('mkd_opt', 0) > 0 and ('soft_label' in batch_meta):
                soft_labels = batch_meta['soft_label']
                soft_labels = self._to_cuda(soft_labels) if self.config['cuda'] else soft_labels
                kd_lc = self.kd_task_loss_criterion[task_id]
                kd_loss = kd_lc(logits, soft_labels, weight, ignore_index=-1) if kd_lc else 0
                loss = loss + kd_loss
        return loss   


    def predict(self, batch_meta, batch_data):
        """Run the network in eval mode and decode predictions for one batch.

        The decoding depends on ``batch_meta['task_type']``:

        * Ranking: scores are reshaped to ``(-1, pairwise_size)`` and
          soft-maxed; the prediction is a flattened one-hot of the row-wise
          argmax. Gold values come from ``batch_meta['true_label']``.
        * Sequence labeling: the argmax is reshaped to the mask's shape and
          each row is truncated to its mask sum.
        * Span: start/end scores are decoded by
          ``squad_utils.extract_answer`` when the encoder is BERT; otherwise
          empty score and prediction lists are returned. Gold values come
          from ``batch_meta['answer']``.
        * Anything else: row-wise argmax, with a softmax first for
          classification.

        Args:
            batch_meta: dict describing the batch.
            batch_data: list of batch tensors.

        Returns:
            A ``(scores, predictions, gold)`` tuple.
        """
        self.network.eval()
        task_id = batch_meta['task_id']
        task_type = batch_meta['task_type']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            # Pad the two absent optional inputs before the task id.
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        score = self.mnetwork(*inputs)

        if task_type == TaskType.Ranking:
            score = score.contiguous().view(-1, batch_meta['pairwise_size'])
            score = F.softmax(score, dim=1)
            score = score.data.cpu().numpy()
            # One-hot encode the best candidate in every row.
            predicted = np.zeros(score.shape, dtype=int)
            positive = np.argmax(score, axis=1)
            predicted[np.arange(score.shape[0]), positive] = 1
            predicted = predicted.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predicted, batch_meta['true_label']

        if task_type == TaskType.SeqenceLabeling:
            mask = batch_data[batch_meta['mask']]
            score = score.contiguous().data.cpu().numpy()
            predicted = np.argmax(score, axis=1).reshape(mask.size()).tolist()
            valid_lengths = mask.sum(1).tolist()
            # Drop the positions beyond each row's mask sum.
            final_predict = [
                row[:length] for row, length in zip(predicted, valid_lengths)
            ]
            score = score.reshape(-1).tolist()
            return score, final_predict, batch_meta['label']

        if task_type == TaskType.Span:
            start, end = score
            # Both default to empty so a non-BERT encoder returns cleanly
            # instead of hitting an unbound ``scores``.
            scores, predictions = [], []
            if self.config['encoder_type'] == EncoderModelType.BERT:
                import experiments.squad.squad_utils as mrc_utils
                scores, predictions = mrc_utils.extract_answer(
                    batch_meta, batch_data, start, end,
                    self.config.get('max_answer_len', 5))
            return scores, predictions, batch_meta['answer']

        if task_type == TaskType.Classification:
            score = F.softmax(score, dim=1)
        score = score.data.cpu().numpy()
        predicted = np.argmax(score, axis=1).tolist()
        score = score.reshape(-1).tolist()
        return score, predicted, batch_meta['label']

    def extract(self, batch_meta, batch_data):
        self.network.eval()
        # 'token_id': 0; 'segment_id': 1; 'mask': 2
        inputs = batch_data[:3]
        all_encoder_layers, pooled_output = self.mnetwork.bert(*inputs)
        return all_encoder_layers, pooled_output

    def save(self, filename):
        """Write the network weights, optimizer state and config to ``filename``.

        Network tensors are copied with ``.cpu()`` before saving; the
        checkpoint is a dict with ``'state'``, ``'optimizer'`` and
        ``'config'`` entries.
        """
        cpu_weights = {
            name: tensor.cpu()
            for name, tensor in self.network.state_dict().items()
        }
        checkpoint = {
            'state': cpu_weights,
            'optimizer': self.optimizer.state_dict(),
            'config': self.config,
        }
        torch.save(checkpoint, filename)
        logger.info('model saved to {}'.format(filename))

    def load(self, checkpoint):
        model_state_dict = torch.load(checkpoint)
        self.network.load_state_dict(model_state_dict['state'], strict=False)
        self.optimizer.load_state_dict(model_state_dict['optimizer'])
        self.config.update(model_state_dict['config'])

    def cuda(self):
        self.network.cuda()

        
Пример #4
0
    def __init__(self, opt, state_dict=None, num_train_step=-1):
        """Build the network, optimizer, optional LR scheduler and optional EMA.

        Args:
            opt: configuration mapping; kept as ``self.config``. Some entries
                are overwritten here (``have_lr_scheduler`` and, for radam,
                ``fp16``).
            state_dict: optional checkpoint mapping; its ``'state'``,
                ``'optimizer'`` and ``'updates'`` entries are used when
                present.
            num_train_step: passed as ``t_total`` to the adamax/radam/adam
                optimizers.

        Raises:
            RuntimeError: if ``opt['optimizer']`` is not a supported name.
            ImportError: if ``opt['fp16']`` is set and apex cannot be
                imported.
        """
        # Makes the apex import further down bind the module-level ``amp``.
        global amp

        self.config = opt
        if state_dict and 'updates' in state_dict:
            self.updates = state_dict['updates']
        else:
            self.updates = 0
        self.local_updates = 0
        self.train_loss = AverageMeter()
        self.network = SANBertNetwork(opt)

        if state_dict:
            self.network.load_state_dict(state_dict['state'], strict=False)
        if opt['multi_gpu_on']:
            self.mnetwork = nn.DataParallel(self.network)
        else:
            self.mnetwork = self.network
        # Count trainable parameters only.
        self.total_param = sum(
            param.nelement() for param in self.network.parameters()
            if param.requires_grad)
        if opt['cuda']:
            self.network.cuda()

        # Parameters whose name contains one of these markers get no decay.
        exempt_markers = ('bias', 'gamma', 'beta', 'LayerNorm.bias',
                          'LayerNorm.weight')

        def skips_decay(param_name):
            return any(marker in param_name for marker in exempt_markers)

        decayed = [
            param for name, param in self.network.named_parameters()
            if not skips_decay(name)
        ]
        undecayed = [
            param for name, param in self.network.named_parameters()
            if skips_decay(name)
        ]
        optimizer_parameters = [
            {'params': decayed, 'weight_decay': 0.01},
            {'params': undecayed, 'weight_decay': 0.0},
        ]

        # note that adamax are modified based on the BERT code
        optimizer_name = opt['optimizer']
        if optimizer_name == 'sgd':
            self.optimizer = optim.SGD(optimizer_parameters,
                                       opt['learning_rate'],
                                       weight_decay=opt['weight_decay'])
        elif optimizer_name == 'adadelta':
            self.optimizer = optim.Adadelta(optimizer_parameters,
                                            opt['learning_rate'],
                                            rho=0.95)
        elif optimizer_name in ('adamax', 'radam', 'adam'):
            # Arguments shared by the three warmup-aware optimizers.
            schedule_kwargs = dict(warmup=opt['warmup'],
                                   t_total=num_train_step,
                                   max_grad_norm=opt['grad_clipping'],
                                   schedule=opt['warmup_schedule'],
                                   weight_decay=opt['weight_decay'])
            if optimizer_name == 'adamax':
                self.optimizer = Adamax(optimizer_parameters,
                                        opt['learning_rate'],
                                        **schedule_kwargs)
            elif optimizer_name == 'radam':
                self.optimizer = RAdam(optimizer_parameters,
                                       opt['learning_rate'],
                                       eps=opt['adam_eps'],
                                       **schedule_kwargs)
                # The current radam does not support FP16.
                opt['fp16'] = False
            else:
                self.optimizer = Adam(optimizer_parameters,
                                      lr=opt['learning_rate'],
                                      **schedule_kwargs)
            # An external LR scheduler is switched off for these optimizers.
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % optimizer_name)

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if opt['fp16']:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            self.network, self.optimizer = amp.initialize(
                self.network,
                self.optimizer,
                opt_level=opt['fp16_opt_level'])

        if opt.get('have_lr_scheduler', False):
            scheduler_type = opt.get('scheduler_type', 'rop')
            if scheduler_type == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer,
                                                   mode='max',
                                                   factor=opt['lr_gamma'],
                                                   patience=3)
            elif scheduler_type == 'exp':
                self.scheduler = ExponentialLR(self.optimizer,
                                               gamma=opt.get('lr_gamma', 0.95))
            else:
                # Any other type: step decay at the comma-separated
                # milestones in ``multi_step_lr``.
                milestone_spec = opt.get('multi_step_lr', '10,20,30')
                milestones = [int(step) for step in milestone_spec.split(',')]
                self.scheduler = MultiStepLR(self.optimizer,
                                             milestones=milestones,
                                             gamma=opt.get('lr_gamma'))
        else:
            self.scheduler = None

        # The EMA helper exists only when ``ema_opt`` is positive.
        self.ema = None
        if opt['ema_opt'] > 0:
            self.ema = EMA(self.config['ema_gamma'], self.network)
            if opt['cuda']:
                self.ema.cuda()

        self.para_swapped = False
        # zero optimizer grad
        self.optimizer.zero_grad()
Пример #5
0
class MTDNNModel(object):
    def __init__(self, opt, state_dict=None, num_train_step=-1):
        """Build the network, optimizer, optional LR scheduler and optional EMA.

        Args:
            opt: configuration mapping; kept as ``self.config``. Some entries
                are overwritten here (``have_lr_scheduler`` and, for radam,
                ``fp16``).
            state_dict: optional checkpoint mapping; its ``'state'``,
                ``'optimizer'`` and ``'updates'`` entries are used when
                present.
            num_train_step: passed as ``t_total`` to the adamax/radam/adam
                optimizers.

        Raises:
            RuntimeError: if ``opt['optimizer']`` is not a supported name.
            ImportError: if ``opt['fp16']`` is set and apex cannot be
                imported.
        """
        # Makes the apex import further down bind the module-level ``amp``.
        global amp

        self.config = opt
        if state_dict and 'updates' in state_dict:
            self.updates = state_dict['updates']
        else:
            self.updates = 0
        self.local_updates = 0
        self.train_loss = AverageMeter()
        self.network = SANBertNetwork(opt)

        if state_dict:
            self.network.load_state_dict(state_dict['state'], strict=False)
        if opt['multi_gpu_on']:
            self.mnetwork = nn.DataParallel(self.network)
        else:
            self.mnetwork = self.network
        # Count trainable parameters only.
        self.total_param = sum(
            param.nelement() for param in self.network.parameters()
            if param.requires_grad)
        if opt['cuda']:
            self.network.cuda()

        # Parameters whose name contains one of these markers get no decay.
        exempt_markers = ('bias', 'gamma', 'beta', 'LayerNorm.bias',
                          'LayerNorm.weight')

        def skips_decay(param_name):
            return any(marker in param_name for marker in exempt_markers)

        decayed = [
            param for name, param in self.network.named_parameters()
            if not skips_decay(name)
        ]
        undecayed = [
            param for name, param in self.network.named_parameters()
            if skips_decay(name)
        ]
        optimizer_parameters = [
            {'params': decayed, 'weight_decay': 0.01},
            {'params': undecayed, 'weight_decay': 0.0},
        ]

        # note that adamax are modified based on the BERT code
        optimizer_name = opt['optimizer']
        if optimizer_name == 'sgd':
            self.optimizer = optim.SGD(optimizer_parameters,
                                       opt['learning_rate'],
                                       weight_decay=opt['weight_decay'])
        elif optimizer_name == 'adadelta':
            self.optimizer = optim.Adadelta(optimizer_parameters,
                                            opt['learning_rate'],
                                            rho=0.95)
        elif optimizer_name in ('adamax', 'radam', 'adam'):
            # Arguments shared by the three warmup-aware optimizers.
            schedule_kwargs = dict(warmup=opt['warmup'],
                                   t_total=num_train_step,
                                   max_grad_norm=opt['grad_clipping'],
                                   schedule=opt['warmup_schedule'],
                                   weight_decay=opt['weight_decay'])
            if optimizer_name == 'adamax':
                self.optimizer = Adamax(optimizer_parameters,
                                        opt['learning_rate'],
                                        **schedule_kwargs)
            elif optimizer_name == 'radam':
                self.optimizer = RAdam(optimizer_parameters,
                                       opt['learning_rate'],
                                       eps=opt['adam_eps'],
                                       **schedule_kwargs)
                # The current radam does not support FP16.
                opt['fp16'] = False
            else:
                self.optimizer = Adam(optimizer_parameters,
                                      lr=opt['learning_rate'],
                                      **schedule_kwargs)
            # An external LR scheduler is switched off for these optimizers.
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % optimizer_name)

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if opt['fp16']:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            self.network, self.optimizer = amp.initialize(
                self.network,
                self.optimizer,
                opt_level=opt['fp16_opt_level'])

        if opt.get('have_lr_scheduler', False):
            scheduler_type = opt.get('scheduler_type', 'rop')
            if scheduler_type == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer,
                                                   mode='max',
                                                   factor=opt['lr_gamma'],
                                                   patience=3)
            elif scheduler_type == 'exp':
                self.scheduler = ExponentialLR(self.optimizer,
                                               gamma=opt.get('lr_gamma', 0.95))
            else:
                # Any other type: step decay at the comma-separated
                # milestones in ``multi_step_lr``.
                milestone_spec = opt.get('multi_step_lr', '10,20,30')
                milestones = [int(step) for step in milestone_spec.split(',')]
                self.scheduler = MultiStepLR(self.optimizer,
                                             milestones=milestones,
                                             gamma=opt.get('lr_gamma'))
        else:
            self.scheduler = None

        # The EMA helper exists only when ``ema_opt`` is positive.
        self.ema = None
        if opt['ema_opt'] > 0:
            self.ema = EMA(self.config['ema_gamma'], self.network)
            if opt['cuda']:
                self.ema.cuda()

        self.para_swapped = False
        # zero optimizer grad
        self.optimizer.zero_grad()

    def setup_ema(self):
        """Call ``setup()`` on the EMA helper when ``ema_opt`` is enabled."""
        if not self.config['ema_opt']:
            return
        self.ema.setup()

    def update_ema(self):
        """Call ``update()`` on the EMA helper when ``ema_opt`` is enabled."""
        if not self.config['ema_opt']:
            return
        self.ema.update()

    def eval(self):
        """Swap in the EMA parameters when ``ema_opt`` is enabled.

        Sets ``self.para_swapped`` so that ``train`` can swap them back.
        """
        if not self.config['ema_opt']:
            return
        self.ema.swap_parameters()
        self.para_swapped = True

    def train(self):
        """Swap the EMA parameters back if ``eval`` swapped them in."""
        if not self.para_swapped:
            return
        self.ema.swap_parameters()
        self.para_swapped = False

    def update(self, batch_meta, batch_data):
        """Run one training step (forward, loss, backward) on a single batch.

        Args:
            batch_meta: dict describing the batch. Keys read here: 'label',
                'task_type', 'task_id', 'input_len', and depending on the
                task/config 'start', 'end', 'pairwise_size', 'factor' and
                'soft_label'.
            batch_data: list of tensors; ``batch_meta`` entries such as
                'label' or 'factor' are indices into it.

        Side effects:
            Updates ``self.train_loss``, accumulates gradients, and every
            ``grad_accumulation_step`` calls clips gradients (optional),
            steps the optimizer, zeroes the gradients and updates the EMA.
        """
        self.network.train()
        use_cuda = self.config['cuda']
        weighted = self.config.get('weighted_on', False)

        labels = batch_data[batch_meta['label']]
        soft_labels = None
        if self.config.get('mkd_opt', 0) > 0 and 'soft_label' in batch_meta:
            soft_labels = batch_meta['soft_label']

        task_type = batch_meta['task_type']
        is_span = task_type == TaskType.Span
        if is_span:
            start = batch_data[batch_meta['start']]
            end = batch_data[batch_meta['end']]
            if use_cuda:
                start = start.cuda(non_blocking=True)
                end = end.cuda(non_blocking=True)
            start.requires_grad = False
            end.requires_grad = False
        else:
            y = labels
            if task_type == TaskType.Ranking:
                # Keep only the first label of every pairwise group.
                y = y.contiguous().view(-1, batch_meta['pairwise_size'])[:, 0]
            if use_cuda:
                y = y.cuda(non_blocking=True)
            y.requires_grad = False

        task_id = batch_meta['task_id']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            # Pad the two optional inputs the network signature expects.
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)

        if weighted:
            weight = batch_data[batch_meta['factor']]
            if use_cuda:
                weight = weight.cuda(non_blocking=True)

        if is_span:
            start_logits, end_logits = self.mnetwork(*inputs)
            # Positions outside the sequence are clamped to this index, which
            # is out of range for the logits and therefore must be ignored.
            ignored_index = start_logits.size(1)
            start.clamp_(0, ignored_index)
            end.clamp_(0, ignored_index)
            if weighted:
                # ignore_index is required here too: without it a clamped
                # target equal to ``ignored_index`` is out of bounds.
                start_loss = F.cross_entropy(start_logits, start,
                                             ignore_index=ignored_index,
                                             reduction='none')
                end_loss = F.cross_entropy(end_logits, end,
                                           ignore_index=ignored_index,
                                           reduction='none')
                loss = torch.mean(start_loss * weight) + \
                    torch.mean(end_loss * weight)
            else:
                loss = F.cross_entropy(start_logits, start, ignore_index=ignored_index) + \
                    F.cross_entropy(end_logits, end, ignore_index=ignored_index)
            loss = loss / 2
            # The span branch has no ``logits`` tensor; take the batch size
            # from the start logits instead.
            batch_size = start_logits.size(0)
        else:
            logits = self.mnetwork(*inputs)
            if task_type == TaskType.Ranking:
                logits = logits.view(-1, batch_meta['pairwise_size'])
            if task_type == TaskType.Regression:
                if weighted:
                    loss = torch.mean(
                        F.mse_loss(logits.squeeze(), y, reduction='none') * weight)
                else:
                    loss = F.mse_loss(logits.squeeze(), y)
            else:
                if weighted:
                    loss = torch.mean(
                        F.cross_entropy(logits, y, reduction='none') * weight)
                else:
                    loss = F.cross_entropy(logits, y)
                if soft_labels is not None:
                    # Knowledge-distillation term: KL divergence between the
                    # model's log-probabilities and the provided soft labels.
                    label_size = soft_labels.size(1)
                    kd_loss = F.kl_div(
                        F.log_softmax(logits.view(-1, label_size).float(), 1),
                        soft_labels,
                        reduction='batchmean')
                    loss = loss + kd_loss
            batch_size = logits.size(0)

        self.train_loss.update(loss.item(), batch_size)

        # Scale the loss so accumulated gradients average over the steps.
        accumulation = self.config.get('grad_accumulation_step', 1)
        loss = loss / accumulation
        if self.config['fp16']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.local_updates += 1
        if self.local_updates % accumulation == 0:
            max_norm = self.config['global_grad_clipping']
            if max_norm > 0:
                if self.config['fp16']:
                    params = amp.master_params(self.optimizer)
                else:
                    params = self.network.parameters()
                torch.nn.utils.clip_grad_norm_(params, max_norm)

            self.updates += 1
            # Apply the accumulated gradients and reset them.
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.update_ema()

    def predict(self, batch_meta, batch_data):
        """Score one batch and turn the scores into predictions.

        Returns a ``(score, predict, gold)`` triple of Python lists, where
        ``score`` is the flattened score array. For ranking tasks ``predict``
        is a flattened one-hot matrix marking the arg-max of every pairwise
        group and ``gold`` is ``batch_meta['true_label']``; otherwise
        ``predict`` holds the arg-max class per row and ``gold`` is
        ``batch_meta['label']``.
        """
        self.network.eval()
        task_type = batch_meta['task_type']
        model_inputs = batch_data[:batch_meta['input_len']]
        if len(model_inputs) == 3:
            # Pad the two optional inputs the network signature expects.
            model_inputs.extend([None, None])
        model_inputs.append(batch_meta['task_id'])
        score = self.mnetwork(*model_inputs)

        if task_type == TaskType.Ranking:
            grouped = score.contiguous().view(-1, batch_meta['pairwise_size'])
            probs = F.softmax(grouped, dim=1).data.cpu().numpy()
            one_hot = np.zeros(probs.shape, dtype=int)
            # Mark the highest-scoring candidate in every group.
            one_hot[np.arange(probs.shape[0]), np.argmax(probs, axis=1)] = 1
            return (probs.reshape(-1).tolist(),
                    one_hot.reshape(-1).tolist(),
                    batch_meta['true_label'])

        if task_type == TaskType.Classification:
            score = F.softmax(score, dim=1)
        values = score.data.cpu().numpy()
        predict = np.argmax(values, axis=1).tolist()
        return values.reshape(-1).tolist(), predict, batch_meta['label']

    def extract(self, batch_meta, batch_data):
        """Return ``(all_encoder_layers, pooled_output)`` from ``self.mnetwork.bert``.

        Only the first three entries of ``batch_data`` are passed on
        (token ids, segment ids and mask, in that order); ``batch_meta`` is
        not used.
        """
        self.network.eval()
        token_ids, segment_ids, mask = 0, 1, 2
        encoder_inputs = batch_data[token_ids:mask + 1]
        layers, pooled = self.mnetwork.bert(*encoder_inputs)
        return layers, pooled

    def save(self, filename):
        """Write network, optimizer, EMA and config state to ``filename``.

        Network and EMA tensors are moved to CPU before saving. When no EMA
        helper exists an empty dict is stored under 'ema'.
        """
        network_state = {
            name: tensor.cpu()
            for name, tensor in self.network.state_dict().items()
        }
        if self.ema is not None:
            ema_state = {
                name: tensor.cpu()
                for name, tensor in self.ema.model.state_dict().items()
            }
        else:
            ema_state = {}
        checkpoint = {
            'state': network_state,
            'optimizer': self.optimizer.state_dict(),
            'ema': ema_state,
            'config': self.config,
        }
        torch.save(checkpoint, filename)
        logger.info('model saved to {}'.format(filename))

    def load(self, checkpoint):
        """Restore network, optimizer, config and EMA state from a checkpoint.

        Args:
            checkpoint: path (or file object) accepted by ``torch.load``. The
                stored object must provide the keys 'config', 'state' and
                'optimizer', plus 'ema' when an EMA helper is in use.

        The process exits if the checkpoint was trained from a different
        ``init_checkpoint`` file name than the one in the current config.
        """
        model_state_dict = torch.load(checkpoint)

        # Compare only the file names. ``[-1]`` (instead of ``[1]``) also
        # works for paths that contain no '/' at all, which previously raised
        # IndexError.
        saved_name = model_state_dict['config']['init_checkpoint'].rsplit('/', 1)[-1]
        current_name = self.config['init_checkpoint'].rsplit('/', 1)[-1]
        if saved_name != current_name:
            logger.error(
                '*** SANBert network is pretrained on a different Bert Model. Please use that to fine-tune for other tasks. ***'
            )
            sys.exit()

        self.network.load_state_dict(model_state_dict['state'], strict=False)
        self.optimizer.load_state_dict(model_state_dict['optimizer'])
        # NOTE: this replaces the live config with the one from the checkpoint.
        self.config = model_state_dict['config']
        if self.ema:
            self.ema.model.load_state_dict(model_state_dict['ema'])

    def cuda(self):
        """Call ``.cuda()`` on the network and, if EMA is enabled, on the EMA helper."""
        self.network.cuda()
        if not self.config['ema_opt']:
            return
        self.ema.cuda()
Пример #6
0
    def __init__(self, opt, state_dict=None, num_train_step=-1):
        """Build the network, optimizer, optional LR scheduler and optional EMA.

        Args:
            opt: configuration dict. Keys read here include 'multi_gpu_on',
                'optimizer', 'learning_rate', 'ema_opt' and, depending on the
                optimizer/scheduler, 'weight_decay', 'warmup',
                'grad_clipping', 'warmup_schedule', 'have_lr_scheduler',
                'scheduler_type', 'lr_gamma', 'multi_step_lr', 'ema_gamma'.
            state_dict: optional checkpoint dict with 'state' and optionally
                'updates' and 'optimizer'. Its 'state' entry is modified in
                place to match the keys of the freshly built network.
            num_train_step: total number of training steps, passed to the
                warmup-aware optimizers as ``t_total``.

        Raises:
            RuntimeError: if ``opt['optimizer']`` is not supported.
        """
        self.config = opt
        self.updates = state_dict['updates'] if state_dict and 'updates' in state_dict else 0
        self.train_loss = AverageMeter()
        self.network = SANBertNetwork(opt)

        if state_dict:
            # Reconcile the checkpoint with the current architecture: drop
            # keys the network does not have and fill in the ones it lacks.
            current_state = self.network.state_dict()
            saved_state = state_dict['state']
            for k in list(saved_state.keys()):
                if k not in current_state:
                    del saved_state[k]
            for k, v in list(current_state.items()):
                if k not in saved_state:
                    saved_state[k] = v
            self.network.load_state_dict(saved_state)
        self.mnetwork = nn.DataParallel(self.network) if opt['multi_gpu_on'] else self.network
        self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])

        # Parameter names are dotted paths, so the no-decay markers have to be
        # matched as substrings; an exact ``n in no_decay`` test never matches
        # and would leave the second group empty.
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight']
        named_params = list(self.network.named_parameters())
        optimizer_parameters = [
            {'params': [p for n, p in named_params if not any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.01},
            {'params': [p for n, p in named_params if any(nd in n for nd in no_decay)],
             'weight_decay_rate': 0.0}
            ]
        # note that adamax are modified based on the BERT code
        if opt['optimizer'] == 'sgd':
            # The class is ``optim.SGD``; ``optim.sgd`` is not a constructor.
            self.optimizer = optim.SGD(optimizer_parameters, opt['learning_rate'],
                                       weight_decay=opt['weight_decay'])

        elif opt['optimizer'] == 'adamax':
            self.optimizer = Adamax(optimizer_parameters,
                                    opt['learning_rate'],
                                    warmup=opt['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=opt['grad_clipping'],
                                    schedule=opt['warmup_schedule'])
            # This optimizer is given its own warmup schedule, so the external
            # LR scheduler is switched off.
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
        elif opt['optimizer'] == 'adadelta':
            self.optimizer = optim.Adadelta(optimizer_parameters,
                                            opt['learning_rate'],
                                            rho=0.95)
        elif opt['optimizer'] == 'adam':
            self.optimizer = Adam(optimizer_parameters,
                                  lr=opt['learning_rate'],
                                  warmup=opt['warmup'],
                                  t_total=num_train_step,
                                  max_grad_norm=opt['grad_clipping'],
                                  schedule=opt['warmup_schedule'])
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if opt.get('have_lr_scheduler', False):
            scheduler_type = opt.get('scheduler_type', 'rop')
            if scheduler_type == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=opt['lr_gamma'], patience=3)
            elif scheduler_type == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=opt.get('lr_gamma', 0.95))
            else:
                milestones = [int(step) for step in opt.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones, gamma=opt.get('lr_gamma'))
        else:
            self.scheduler = None
        self.ema = None
        if opt['ema_opt'] > 0:
            self.ema = EMA(self.config['ema_gamma'], self.network)
        self.para_swapped = False
Пример #7
0
class MTDNNModel(object):
    """Training/inference wrapper around a ``SANBertNetwork``.

    Owns the network, its optimizer, an optional LR scheduler and an optional
    EMA helper, and exposes ``update`` / ``predict`` / ``save`` for one batch
    or one checkpoint at a time.
    """

    def __init__(self, opt, state_dict=None, num_train_step=-1):
        """Build the network, optimizer, optional scheduler and optional EMA.

        Args:
            opt: configuration dict (see the keys read below).
            state_dict: optional checkpoint dict with 'state' and optionally
                'updates' and 'optimizer'. Its 'state' entry is modified in
                place to match the keys of the freshly built network.
            num_train_step: total number of training steps, passed to the
                warmup-aware optimizers as ``t_total``.

        Raises:
            RuntimeError: if ``opt['optimizer']`` is not supported.
        """
        self.config = opt
        self.updates = state_dict['updates'] if state_dict and 'updates' in state_dict else 0
        self.train_loss = AverageMeter()
        self.network = SANBertNetwork(opt)

        if state_dict:
            # Reconcile the checkpoint with the current architecture: drop
            # keys the network does not have and fill in the ones it lacks.
            current_state = self.network.state_dict()
            saved_state = state_dict['state']
            for k in list(saved_state.keys()):
                if k not in current_state:
                    del saved_state[k]
            for k, v in list(current_state.items()):
                if k not in saved_state:
                    saved_state[k] = v
            self.network.load_state_dict(saved_state)
        self.mnetwork = nn.DataParallel(self.network) if opt['multi_gpu_on'] else self.network
        self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])

        optimizer_parameters = self._get_param_groups()
        # note that adamax are modified based on the BERT code
        if opt['optimizer'] == 'sgd':
            # The class is ``optim.SGD``; ``optim.sgd`` is not a constructor.
            self.optimizer = optim.SGD(optimizer_parameters, opt['learning_rate'],
                                       weight_decay=opt['weight_decay'])

        elif opt['optimizer'] == 'adamax':
            self.optimizer = Adamax(optimizer_parameters,
                                    opt['learning_rate'],
                                    warmup=opt['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=opt['grad_clipping'],
                                    schedule=opt['warmup_schedule'])
            # This optimizer is given its own warmup schedule, so the external
            # LR scheduler is switched off.
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
        elif opt['optimizer'] == 'adadelta':
            self.optimizer = optim.Adadelta(optimizer_parameters,
                                            opt['learning_rate'],
                                            rho=0.95)
        elif opt['optimizer'] == 'adam':
            self.optimizer = Adam(optimizer_parameters,
                                  lr=opt['learning_rate'],
                                  warmup=opt['warmup'],
                                  t_total=num_train_step,
                                  max_grad_norm=opt['grad_clipping'],
                                  schedule=opt['warmup_schedule'])
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if opt.get('have_lr_scheduler', False):
            scheduler_type = opt.get('scheduler_type', 'rop')
            if scheduler_type == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=opt['lr_gamma'], patience=3)
            elif scheduler_type == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=opt.get('lr_gamma', 0.95))
            else:
                milestones = [int(step) for step in opt.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones, gamma=opt.get('lr_gamma'))
        else:
            self.scheduler = None
        self.ema = None
        if opt['ema_opt'] > 0:
            self.ema = EMA(self.config['ema_gamma'], self.network)
        self.para_swapped = False

    def _get_param_groups(self):
        """Split the network parameters into a decay and a no-decay group.

        Parameter names are dotted paths, so the no-decay markers are matched
        as substrings; an exact ``name in no_decay`` test never matches and
        would leave the second group empty.
        """
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight']
        decay_params, no_decay_params = [], []
        for n, p in self.network.named_parameters():
            if any(nd in n for nd in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)
        return [
            {'params': decay_params, 'weight_decay_rate': 0.01},
            {'params': no_decay_params, 'weight_decay_rate': 0.0},
        ]

    def setup_ema(self):
        """Call ``self.ema.setup()`` when ``config['ema_opt']`` is truthy."""
        if self.config['ema_opt']:
            self.ema.setup()

    def update_ema(self):
        """Call ``self.ema.update()`` when ``config['ema_opt']`` is truthy."""
        if self.config['ema_opt']:
            self.ema.update()

    def eval(self):
        """Swap parameters via the EMA helper and remember that it happened."""
        if self.config['ema_opt']:
            self.ema.swap_parameters()
            self.para_swapped = True

    def train(self):
        """Undo the parameter swap performed by ``eval()``, if any."""
        if self.para_swapped:
            self.ema.swap_parameters()
            self.para_swapped = False

    def update(self, batch_meta, batch_data):
        """Run one optimisation step on a single batch.

        Args:
            batch_meta: dict; keys read here are 'label', 'pairwise',
                'pairwise_size', 'task_id', 'task_type', 'input_len' and,
                when ``weighted_on`` is configured, 'factor'.
            batch_data: list of tensors indexed by the ``batch_meta`` entries.
        """
        self.network.train()
        use_cuda = self.config['cuda']
        labels = batch_data[batch_meta['label']]
        if batch_meta['pairwise']:
            # Keep only the first label of every pairwise group.
            labels = labels.contiguous().view(-1, batch_meta['pairwise_size'])[:, 0]
        # ``async`` is a reserved word since Python 3.7; ``non_blocking`` is
        # the supported spelling of the same option.
        y = labels.cuda(non_blocking=True) if use_cuda else labels
        y = y.detach()

        task_id = batch_meta['task_id']
        task_type = batch_meta['task_type']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            # Pad the two optional inputs the network signature expects.
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        logits = self.mnetwork(*inputs)
        if batch_meta['pairwise']:
            logits = logits.view(-1, batch_meta['pairwise_size'])

        # ``task_type > 0`` selects the MSE (regression-style) loss;
        # everything else uses cross entropy.
        use_mse = task_type > 0
        if self.config.get('weighted_on', False):
            weight = batch_data[batch_meta['factor']]
            if use_cuda:
                weight = weight.cuda(non_blocking=True)
            if use_mse:
                per_sample = F.mse_loss(logits.squeeze(), y, reduction='none')
            else:
                per_sample = F.cross_entropy(logits, y, reduction='none')
            loss = torch.mean(per_sample * weight)
        elif use_mse:
            loss = F.mse_loss(logits.squeeze(), y)
        else:
            loss = F.cross_entropy(logits, y)

        self.train_loss.update(loss.item(), logits.size(0))
        self.optimizer.zero_grad()

        loss.backward()
        if self.config['global_grad_clipping'] > 0:
            torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                           self.config['global_grad_clipping'])
        self.optimizer.step()
        self.updates += 1
        self.update_ema()

    def predict(self, batch_meta, batch_data):
        """Score one batch and turn the scores into predictions.

        Returns a ``(score, predict, gold)`` triple of Python lists. For
        pairwise batches ``predict`` is a flattened one-hot matrix marking the
        arg-max of every group and ``gold`` is ``batch_meta['true_label']``;
        otherwise ``predict`` is the arg-max per row and ``gold`` is
        ``batch_meta['label']``. Softmax is applied when ``task_type < 1``.
        """
        self.network.eval()
        task_id = batch_meta['task_id']
        task_type = batch_meta['task_type']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        score = self.mnetwork(*inputs)

        pairwise = batch_meta['pairwise']
        if pairwise:
            score = score.contiguous().view(-1, batch_meta['pairwise_size'])
        if task_type < 1:
            score = F.softmax(score, dim=1)
        score = score.data.cpu().numpy()

        if pairwise:
            predict = np.zeros(score.shape, dtype=int)
            # Mark the highest-scoring candidate in every group.
            predict[np.arange(score.shape[0]), np.argmax(score, axis=1)] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta['true_label']

        predict = np.argmax(score, axis=1).tolist()
        score = score.reshape(-1).tolist()
        return score, predict, batch_meta['label']

    def save(self, filename):
        """Write network, optimizer, EMA and config state to ``filename``.

        Network and EMA tensors are moved to CPU before saving. When no EMA
        helper exists an empty dict is stored under 'ema'.
        """
        network_state = {k: v.cpu() for k, v in self.network.state_dict().items()}
        if self.ema is not None:
            ema_state = {k: v.cpu() for k, v in self.ema.model.state_dict().items()}
        else:
            ema_state = {}
        params = {
            'state': network_state,
            'optimizer': self.optimizer.state_dict(),
            'ema': ema_state,
            'config': self.config,
        }
        torch.save(params, filename)
        logger.info('model saved to {}'.format(filename))

    def cuda(self):
        """Call ``.cuda()`` on the network and, if EMA is enabled, on the EMA helper."""
        self.network.cuda()
        if self.config['ema_opt']:
            self.ema.cuda()
Пример #8
0
class MTDNNModel(object):
    """Training/inference wrapper around a ``SANBertNetwork``.

    Owns the network, its optimizer and optional LR scheduler, the per-task
    loss criteria (task, knowledge-distillation and adversarial) and an
    optional adversarial-training helper.
    """

    def __init__(self, opt, state_dict=None, num_train_step=-1):
        """Build the network, optimizer and loss criteria.

        Args:
            opt: configuration dict; stored as ``self.config``.
            state_dict: optional checkpoint dict with 'state' and optionally
                'updates' and 'optimizer'.
            num_train_step: total number of training steps, passed to the
                optimizers as ``t_total``.
        """
        self.config = opt
        self.updates = state_dict['updates'] if state_dict and 'updates' in state_dict else 0
        self.local_updates = 0
        self.train_loss = AverageMeter()
        self.initial_from_local = True if state_dict else False
        self.network = SANBertNetwork(opt, initial_from_local=self.initial_from_local)
        if state_dict:
            # strict=False: tolerate missing/unexpected keys in the checkpoint.
            self.network.load_state_dict(state_dict['state'], strict=False)
        self.mnetwork = nn.DataParallel(self.network) if opt['multi_gpu_on'] else self.network
        self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])
        if opt['cuda']:
            self.network.cuda()
        optimizer_parameters = self._get_param_groups()
        self._setup_optim(optimizer_parameters, state_dict, num_train_step)
        self.para_swapped = False
        self.optimizer.zero_grad()
        self._setup_lossmap(self.config)
        self._setup_kd_lossmap(self.config)
        self._setup_adv_lossmap(self.config)
        self._setup_adv_training(self.config)

    def _setup_adv_training(self, config):
        """Create the ``SmartPerturbation`` helper when ``adv_train`` is set."""
        self.adv_teacher = None
        if config.get('adv_train', False):
            self.adv_teacher = SmartPerturbation(config['adv_epsilon'],
                                                 config['multi_gpu_on'],
                                                 config['adv_step_size'],
                                                 config['adv_noise_var'],
                                                 config['adv_p_norm'],
                                                 config['adv_k'],
                                                 config['fp16'],
                                                 config['encoder_type'],
                                                 loss_map=self.adv_task_loss_criterion)

    def _get_param_groups(self):
        """Split the network parameters into a decay and a no-decay group.

        A parameter goes into the no-decay group when its name contains any of
        the ``no_decay`` markers (bias and LayerNorm parameters).
        """
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight']
        decay_params, no_decay_params = [], []
        for n, p in self.network.named_parameters():
            if any(nd in n for nd in no_decay):
                no_decay_params.append(p)
            else:
                decay_params.append(p)
        return [
            {'params': decay_params, 'weight_decay': 0.01},
            {'params': no_decay_params, 'weight_decay': 0.0},
        ]

    def _setup_optim(self, optimizer_parameters, state_dict=None, num_train_step=-1):
        """Create ``self.optimizer`` and ``self.scheduler`` from the config.

        Args:
            optimizer_parameters: parameter groups for the optimizer.
            state_dict: optional checkpoint; its 'optimizer' entry, when
                present, is loaded into the new optimizer.
            num_train_step: passed to the optimizers as ``t_total``.

        Raises:
            RuntimeError: if ``config['optimizer']`` is not supported.
            ImportError: if fp16 is requested but apex is not installed.
        """
        # Declared up front so the apex import below binds the module-level
        # name that ``update`` uses.
        global amp

        optimizer_name = self.config['optimizer']
        if optimizer_name == 'sgd':
            self.optimizer = optim.SGD(optimizer_parameters, self.config['learning_rate'],
                                       weight_decay=self.config['weight_decay'])

        elif optimizer_name == 'adamax':
            self.optimizer = Adamax(optimizer_parameters,
                                    self.config['learning_rate'],
                                    warmup=self.config['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=self.config['grad_clipping'],
                                    schedule=self.config['warmup_schedule'],
                                    weight_decay=self.config['weight_decay'])
            # This optimizer is given its own warmup schedule, so the external
            # LR scheduler is switched off.
            if self.config.get('have_lr_scheduler', False):
                self.config['have_lr_scheduler'] = False
        elif optimizer_name == 'radam':
            self.optimizer = RAdam(optimizer_parameters,
                                   self.config['learning_rate'],
                                   warmup=self.config['warmup'],
                                   t_total=num_train_step,
                                   max_grad_norm=self.config['grad_clipping'],
                                   schedule=self.config['warmup_schedule'],
                                   eps=self.config['adam_eps'],
                                   weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False):
                self.config['have_lr_scheduler'] = False
            # The current radam does not support FP16.
            self.config['fp16'] = False
        elif optimizer_name == 'adam':
            self.optimizer = Adam(optimizer_parameters,
                                  lr=self.config['learning_rate'],
                                  warmup=self.config['warmup'],
                                  t_total=num_train_step,
                                  max_grad_norm=self.config['grad_clipping'],
                                  schedule=self.config['warmup_schedule'],
                                  weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False):
                self.config['have_lr_scheduler'] = False
        else:
            # ``opt`` is not defined in this method; read the name from the
            # stored config so the intended RuntimeError is raised.
            raise RuntimeError('Unsupported optimizer: %s' % optimizer_name)

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if self.config['fp16']:
            try:
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(self.network, self.optimizer, opt_level=self.config['fp16_opt_level'])
            self.network = model
            self.optimizer = optimizer

        if self.config.get('have_lr_scheduler', False):
            scheduler_type = self.config.get('scheduler_type', 'rop')
            if scheduler_type == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=self.config['lr_gamma'], patience=3)
            elif scheduler_type == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=self.config.get('lr_gamma', 0.95))
            else:
                milestones = [int(step) for step in self.config.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones, gamma=self.config.get('lr_gamma'))
        else:
            self.scheduler = None

    def _setup_lossmap(self, config):
        """Instantiate one task loss criterion per entry of ``task_def_list``."""
        task_def_list: List[TaskDef] = config['task_def_list']
        self.task_loss_criterion = []
        for idx, task_def in enumerate(task_def_list):
            cs = task_def.loss
            lc = LOSS_REGISTRY[cs](name='Loss func of task {}: {}'.format(idx, cs))
            self.task_loss_criterion.append(lc)

    def _setup_kd_lossmap(self, config):
        """Instantiate one KD loss criterion per task when ``mkd_opt`` > 0."""
        task_def_list: List[TaskDef] = config['task_def_list']
        self.kd_task_loss_criterion = []
        if config.get('mkd_opt', 0) > 0:
            for idx, task_def in enumerate(task_def_list):
                cs = task_def.kd_loss
                assert cs is not None
                lc = LOSS_REGISTRY[cs](name='KD Loss func of task {}: {}'.format(idx, cs))
                self.kd_task_loss_criterion.append(lc)

    def _setup_adv_lossmap(self, config):
        """Instantiate one adversarial loss criterion per task when ``adv_train`` is set."""
        task_def_list: List[TaskDef] = config['task_def_list']
        self.adv_task_loss_criterion = []
        if config.get('adv_train', False):
            for idx, task_def in enumerate(task_def_list):
                cs = task_def.adv_loss
                assert cs is not None
                lc = LOSS_REGISTRY[cs](name='Adv Loss func of task {}: {}'.format(idx, cs))
                self.adv_task_loss_criterion.append(lc)

    def train(self):
        """Clear the ``para_swapped`` flag if it is set."""
        if self.para_swapped:
            self.para_swapped = False

    def _to_cuda(self, tensor):
        """Move a tensor, or a list/tuple of tensors, to the GPU.

        Returns ``None`` unchanged. The moved tensors have ``requires_grad``
        set to False; a list is returned for list/tuple input.
        """
        if tensor is None:
            return tensor

        if isinstance(tensor, (list, tuple)):
            y = [e.cuda(non_blocking=True) for e in tensor]
            for e in y:
                e.requires_grad = False
        else:
            y = tensor.cuda(non_blocking=True)
            y.requires_grad = False
        return y

    def update(self, batch_meta, batch_data, weight_alpha):
        """Run one training step (forward, losses, backward) on a single batch.

        Args:
            batch_meta: dict; keys read here include 'label', 'task_id',
                'input_len', 'token_id' and, depending on the config,
                'weight', 'pairwise_size', 'soft_label' and 'task_def'.
            batch_data: list of tensors indexed by the ``batch_meta`` entries.
            weight_alpha: factor multiplied into ``batch_meta['weight']`` when
                ``config['itw_on']`` is set.
        """
        self.network.train()
        use_cuda = self.config['cuda']
        y = batch_data[batch_meta['label']]
        y = self._to_cuda(y) if use_cuda else y

        task_id = batch_meta['task_id']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            # Pad the two optional inputs the network signature expects.
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)

        weight = None
        if self.config['itw_on']:
            if use_cuda:
                weight = torch.FloatTensor([batch_meta['weight']]).cuda(non_blocking=True) * weight_alpha
            else:
                # NOTE(review): this branch yields a plain number while the
                # CUDA branch yields a tensor - confirm the criteria accept both.
                weight = batch_meta['weight'] * weight_alpha

        # fw to get logits
        logits = self.mnetwork(*inputs)

        # compute loss
        loss = 0
        loss_criterion = self.task_loss_criterion[task_id]
        if loss_criterion and (y is not None):
            if isinstance(loss_criterion, RankCeCriterion) and batch_meta['pairwise_size'] > 1:
                # The ranking criterion additionally receives the group size.
                loss = loss_criterion(logits, y, weight, ignore_index=-1,
                                      pairwise_size=batch_meta['pairwise_size'])
            else:
                loss = loss_criterion(logits, y, weight, ignore_index=-1)

        # compute kd loss
        if self.config.get('mkd_opt', 0) > 0 and ('soft_label' in batch_meta):
            soft_labels = batch_meta['soft_label']
            soft_labels = self._to_cuda(soft_labels) if use_cuda else soft_labels
            kd_lc = self.kd_task_loss_criterion[task_id]
            kd_loss = kd_lc(logits, soft_labels, weight, ignore_index=-1) if kd_lc else 0
            loss = loss + kd_loss

        # adv training
        if self.config.get('adv_train', False) and self.adv_teacher:
            # task info
            task_type = batch_meta['task_def']['task_type']
            adv_inputs = [self.mnetwork, logits] + inputs + [task_type, batch_meta.get('pairwise_size', 1)]
            adv_loss = self.adv_teacher.forward(*adv_inputs)
            loss = loss + self.config['adv_alpha'] * adv_loss

        self.train_loss.update(loss.item(), batch_data[batch_meta['token_id']].size(0))

        # Scale the loss so accumulated gradients average over the steps.
        accumulation = self.config.get('grad_accumulation_step', 1)
        loss = loss / accumulation
        if self.config['fp16']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.local_updates += 1
        if self.local_updates % accumulation == 0:
            max_norm = self.config['global_grad_clipping']
            if max_norm > 0:
                if self.config['fp16']:
                    params = amp.master_params(self.optimizer)
                else:
                    params = self.network.parameters()
                torch.nn.utils.clip_grad_norm_(params, max_norm)
            self.updates += 1
            # Apply the accumulated gradients and reset them.
            self.optimizer.step()
            self.optimizer.zero_grad()

    def encode(self, batch_meta, batch_data):
        """Return the first output of ``self.network.encode`` for the batch.

        Only the first three entries of ``batch_data`` are passed on.
        """
        self.network.eval()
        inputs = batch_data[:3]
        sequence_output = self.network.encode(*inputs)[0]
        return sequence_output

    # TODO: similar as function extract, preserve since it is used by extractor.py
    # will remove after migrating to transformers package
    def extract(self, batch_meta, batch_data):
        """Return ``(all_encoder_layers, pooled_output)`` from the BERT encoder.

        The encoder is reached through ``self.network`` rather than
        ``self.mnetwork``: when multi-GPU is on, ``mnetwork`` is an
        ``nn.DataParallel`` wrapper, which does not expose the wrapped
        module's ``bert`` attribute.
        """
        self.network.eval()
        # 'token_id': 0; 'segment_id': 1; 'mask': 2
        inputs = batch_data[:3]
        all_encoder_layers, pooled_output = self.network.bert(*inputs)
        return all_encoder_layers, pooled_output

    def predict(self, batch_meta, batch_data):
        """Score one batch and convert the scores into task-specific predictions.

        Returns:
            A ``(score, predict, gold)`` triple. ``gold`` is taken from
            ``batch_meta`` ('true_label' for ranking, 'answer' for span,
            'label' otherwise).

        Raises:
            ValueError: if no task object exists and the task type is not
                handled here.
        """
        self.network.eval()
        task_id = batch_meta['task_id']
        task_def = TaskDef.from_dict(batch_meta['task_def'])
        task_type = task_def.task_type
        task_obj = tasks.get_task_obj(task_def)
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        score = self.mnetwork(*inputs)

        if task_obj is not None:
            # A registered task object takes precedence over the built-ins.
            score, predict = task_obj.test_predict(score)
            return score, predict, batch_meta['label']

        if task_type == TaskType.Ranking:
            score = score.contiguous().view(-1, batch_meta['pairwise_size'])
            score = F.softmax(score, dim=1)
            score = score.data.cpu().numpy()
            predict = np.zeros(score.shape, dtype=int)
            # Mark the highest-scoring candidate in every group.
            predict[np.arange(score.shape[0]), np.argmax(score, axis=1)] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta['true_label']

        if task_type == TaskType.SeqenceLabeling:
            mask = batch_data[batch_meta['mask']]
            score = score.contiguous().data.cpu().numpy()
            predict = np.argmax(score, axis=1).reshape(mask.size()).tolist()
            valid_lengths = mask.sum(1).tolist()
            # Cut every predicted sequence to the length given by its mask.
            final_predict = [p[:valid_lengths[idx]] for idx, p in enumerate(predict)]
            score = score.reshape(-1).tolist()
            return score, final_predict, batch_meta['label']

        if task_type == TaskType.Span:
            start, end = score
            # Both default to empty so the return below is defined for
            # encoders other than BERT as well.
            scores = []
            predictions = []
            if self.config['encoder_type'] == EncoderModelType.BERT:
                import experiments.squad.squad_utils as mrc_utils
                scores, predictions = mrc_utils.extract_answer(
                    batch_meta, batch_data, start, end,
                    self.config.get('max_answer_len', 5),
                    do_lower_case=self.config.get('do_lower_case', False))
            return scores, predictions, batch_meta['answer']

        raise ValueError("Unknown task_type: %s" % task_type)

    def save(self, filename):
        """Write network (moved to CPU), optimizer and config state to ``filename``."""
        network_state = {k: v.cpu() for k, v in self.network.state_dict().items()}
        params = {
            'state': network_state,
            'optimizer': self.optimizer.state_dict(),
            'config': self.config,
        }
        torch.save(params, filename)
        logger.info('model saved to {}'.format(filename))

    def load(self, checkpoint):
        """Restore whichever of 'state', 'optimizer' and 'config' the checkpoint holds.

        The stored config is merged into the current one with ``dict.update``.
        """
        model_state_dict = torch.load(checkpoint)
        if 'state' in model_state_dict:
            self.network.load_state_dict(model_state_dict['state'], strict=False)
        if 'optimizer' in model_state_dict:
            self.optimizer.load_state_dict(model_state_dict['optimizer'])
        if 'config' in model_state_dict:
            self.config.update(model_state_dict['config'])

    def cuda(self):
        """Call ``.cuda()`` on the network."""
        self.network.cuda()
Пример #9
0
class MTDNNModel(object):
    def __init__(self, opt, device=None, state_dict=None, num_train_step=-1):
        """Build the network, optimizer, parallel wrapper and loss tables.

        Args:
            opt: configuration mapping; stored as ``self.config`` and read by
                the ``_setup_*`` helpers.
            device: target passed to ``model.to`` when ``opt["cuda"]`` is set;
                also used by ``_to_cuda``.
            state_dict: optional checkpoint mapping. ``"updates"`` seeds the
                update counter, ``"state"`` is loaded into the network and the
                mapping is forwarded to ``_setup_optim``.
            num_train_step: forwarded to ``_setup_optim`` for the LR schedule.
        """
        self.config = opt
        self.updates = (state_dict["updates"]
                        if state_dict and "updates" in state_dict else 0)
        self.local_updates = 0
        self.device = device

        # Running averages updated by ``update``.
        self.train_loss = AverageMeter()
        self.adv_loss = AverageMeter()
        self.emb_val = AverageMeter()
        self.eff_perturb = AverageMeter()

        self.initial_from_local = bool(state_dict)
        model = SANBertNetwork(opt, initial_from_local=self.initial_from_local)
        self.total_param = sum(
            p.nelement() for p in model.parameters() if p.requires_grad)

        # The distributed and non-distributed cases move the model the same
        # way, so no branch on ``local_rank`` is needed here.
        if opt["cuda"]:
            model = model.to(self.device)
        self.network = model

        if state_dict:
            # strict=False tolerates key mismatches. The return value is not
            # used, so it is not unpacked.
            self.network.load_state_dict(state_dict["state"], strict=False)

        optimizer_parameters = self._get_param_groups()
        self._setup_optim(optimizer_parameters, state_dict, num_train_step)
        self.optimizer.zero_grad()

        # ``self.mnetwork`` is what gets called for forward passes;
        # ``self.network`` stays the unwrapped module.
        if self.config["local_rank"] != -1:
            self.mnetwork = torch.nn.parallel.DistributedDataParallel(
                self.network,
                device_ids=[self.config["local_rank"]],
                output_device=self.config["local_rank"],
                find_unused_parameters=True,
            )
        elif self.config["multi_gpu_on"]:
            self.mnetwork = nn.DataParallel(self.network)
        else:
            self.mnetwork = self.network

        self._setup_lossmap(self.config)
        self._setup_kd_lossmap(self.config)
        # The adversarial loss table must exist before the adversarial
        # teacher is built, because the teacher receives it as ``loss_map``.
        self._setup_adv_lossmap(self.config)
        self._setup_adv_training(self.config)
        self._setup_tokenizer()

    def _setup_adv_training(self, config):
        self.adv_teacher = None
        if config.get("adv_train", False):
            self.adv_teacher = SmartPerturbation(
                config["adv_epsilon"],
                config["multi_gpu_on"],
                config["adv_step_size"],
                config["adv_noise_var"],
                config["adv_p_norm"],
                config["adv_k"],
                config["fp16"],
                config["encoder_type"],
                loss_map=self.adv_task_loss_criterion,
                norm_level=config["adv_norm_level"],
            )

    def _get_param_groups(self):
        no_decay = [
            "bias", "gamma", "beta", "LayerNorm.bias", "LayerNorm.weight"
        ]
        optimizer_parameters = [
            {
                "params": [
                    p for n, p in self.network.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.01,
            },
            {
                "params": [
                    p for n, p in self.network.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
        return optimizer_parameters

    def _setup_optim(self,
                     optimizer_parameters,
                     state_dict=None,
                     num_train_step=-1):
        if self.config['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(
                optimizer_parameters,
                self.config['learning_rate'],
                weight_decay=self.config['weight_decay'])
        elif self.config['optimizer'] == 'adamax':
            self.optimizer = AdamaxW(optimizer_parameters,
                                     lr=self.config['learning_rate'],
                                     weight_decay=self.config['weight_decay'])
        elif self.config['optimizer'] == 'adam':
            self.optimizer = optim.AdamW(
                optimizer_parameters,
                lr=self.config['learning_rate'],
                weight_decay=self.config['weight_decay'])
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if state_dict and "optimizer" in state_dict:
            self.optimizer.load_state_dict(state_dict["optimizer"])

        if self.config["fp16"]:
            try:
                from apex import amp
                global amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                self.network,
                self.optimizer,
                opt_level=self.config["fp16_opt_level"])
            self.network = model
            self.optimizer = optimizer

        # # set up scheduler
        self.scheduler = None
        scheduler_type = self.config['scheduler_type']
        warmup_steps = self.config['warmup'] * num_train_step
        if scheduler_type == 3:
            from transformers import get_polynomial_decay_schedule_with_warmup
            self.scheduler = get_polynomial_decay_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=num_train_step)
        if scheduler_type == 2:
            from transformers import get_constant_schedule_with_warmup
            self.scheduler = get_constant_schedule_with_warmup(
                self.optimizer, num_warmup_steps=warmup_steps)
        elif scheduler_type == 1:
            from transformers import get_cosine_schedule_with_warmup
            self.scheduler = get_cosine_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=num_train_step)
        else:
            from transformers import get_linear_schedule_with_warmup
            self.scheduler = get_linear_schedule_with_warmup(
                self.optimizer,
                num_warmup_steps=warmup_steps,
                num_training_steps=num_train_step)

    def _setup_lossmap(self, config):
        task_def_list: List[TaskDef] = config["task_def_list"]
        self.task_loss_criterion = []
        for idx, task_def in enumerate(task_def_list):
            cs = task_def.loss
            lc = LOSS_REGISTRY[cs](
                name="Loss func of task {}: {}".format(idx, cs))
            self.task_loss_criterion.append(lc)

    def _setup_kd_lossmap(self, config):
        task_def_list: List[TaskDef] = config["task_def_list"]
        self.kd_task_loss_criterion = []
        if config.get("mkd_opt", 0) > 0:
            for idx, task_def in enumerate(task_def_list):
                cs = task_def.kd_loss
                assert cs is not None
                lc = LOSS_REGISTRY[cs](
                    name="KD Loss func of task {}: {}".format(idx, cs))
                self.kd_task_loss_criterion.append(lc)

    def _setup_adv_lossmap(self, config):
        task_def_list: List[TaskDef] = config["task_def_list"]
        self.adv_task_loss_criterion = []
        if config.get("adv_train", False):
            for idx, task_def in enumerate(task_def_list):
                cs = task_def.adv_loss
                assert cs is not None
                lc = LOSS_REGISTRY[cs](
                    name="Adv Loss func of task {}: {}".format(idx, cs))
                self.adv_task_loss_criterion.append(lc)

    def _setup_tokenizer(self):
        try:
            from transformers import AutoTokenizer

            self.tokenizer = AutoTokenizer.from_pretrained(
                self.config["init_checkpoint"],
                cache_dir=self.config["transformer_cache"],
            )
        except:
            self.tokenizer = None

    def _to_cuda(self, tensor):
        if tensor is None:
            return tensor

        if isinstance(tensor, list) or isinstance(tensor, tuple):
            # y = [e.cuda(non_blocking=True) for e in tensor]
            y = [e.to(self.device) for e in tensor]
            for e in y:
                e.requires_grad = False
        else:
            # y = tensor.cuda(non_blocking=True)
            y = tensor.to(self.device)
            y.requires_grad = False
        return y

    def update(self, batch_meta, batch_data):
        """Run one training step on a single batch.

        Performs the forward pass, sums the task loss with the optional
        knowledge-distillation and adversarial terms, records running
        averages, back-propagates, and steps the optimizer once every
        ``grad_accumulation_step`` calls.

        Args:
            batch_meta: mapping describing the batch. Read here: "label",
                "task_def" (its "task_type"), "task_id", "input_len",
                "token_id", and when present or enabled "y_token_id",
                "factor", "pairwise_size" and "soft_label".
            batch_data: sequence indexed by the positions stored in
                ``batch_meta``; its first ``input_len`` entries are the
                network inputs (presumably a list, since the slice is
                appended to; confirm against the batcher).
        """
        self.network.train()
        # Labels are looked up by position and moved to ``self.device`` when
        # CUDA is enabled.
        y = batch_data[batch_meta["label"]]
        y = self._to_cuda(y) if self.config["cuda"] else y
        if batch_meta["task_def"]["task_type"] == TaskType.SeqenceGeneration:
            # Keep the sequence length for the per-token weights computed
            # below, then flatten the labels.
            seq_length = y.size(1)
            y = y.view(-1)

        task_id = batch_meta["task_id"]
        inputs = batch_data[:batch_meta["input_len"]]
        if len(inputs) == 3:
            # Pad with two ``None`` placeholders so ``task_id`` is always the
            # sixth positional argument.
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        if "y_token_id" in batch_meta:
            inputs.append(batch_data[batch_meta["y_token_id"]])
        # Optional per-sample weights, only used when "weighted_on" is set.
        weight = None
        if self.config.get("weighted_on", False):
            if self.config["cuda"]:
                # NOTE(review): uses ``.cuda()`` rather than ``self.device``
                # like ``_to_cuda`` does; confirm this is intended.
                weight = batch_data[batch_meta["factor"]].cuda(
                    non_blocking=True)
            else:
                weight = batch_data[batch_meta["factor"]]

        # fw to get logits
        logits = self.mnetwork(*inputs)

        # compute loss
        # NOTE(review): if none of the loss terms below is added, ``loss``
        # stays the int 0 and the ``.item()`` / ``.backward()`` calls further
        # down would fail. Confirm callers always provide labels.
        loss = 0
        if self.task_loss_criterion[task_id] and (y is not None):
            loss_criterion = self.task_loss_criterion[task_id]
            if (isinstance(loss_criterion, RankCeCriterion)
                    and batch_meta["pairwise_size"] > 1):
                # reshape the logits for ranking.
                loss = self.task_loss_criterion[task_id](
                    logits,
                    y,
                    weight,
                    ignore_index=-1,
                    pairwise_size=batch_meta["pairwise_size"],
                )
            elif batch_meta["task_def"][
                    "task_type"] == TaskType.SeqenceGeneration:
                # Per-token weight: 1 / (number of labels > -1 in the row),
                # repeated for every position of that row. This replaces any
                # weight loaded above.
                weight = ((1.0 / torch.sum(
                    (y > -1).float().view(-1, seq_length), 1,
                    keepdim=True)).repeat(1, seq_length).view(-1))
                loss = self.task_loss_criterion[task_id](logits,
                                                         y,
                                                         weight,
                                                         ignore_index=-1)
            else:
                loss = self.task_loss_criterion[task_id](logits,
                                                         y,
                                                         weight,
                                                         ignore_index=-1)

        # compute kd loss
        # Only when "mkd_opt" > 0 and the batch carries soft labels.
        if self.config.get("mkd_opt", 0) > 0 and ("soft_label" in batch_meta):
            soft_labels = batch_meta["soft_label"]
            soft_labels = (self._to_cuda(soft_labels)
                           if self.config["cuda"] else soft_labels)
            kd_lc = self.kd_task_loss_criterion[task_id]
            kd_loss = (kd_lc(logits, soft_labels, weight, ignore_index=-1)
                       if kd_lc else 0)
            loss = loss + kd_loss

        # adv training
        if self.config.get("adv_train", False) and self.adv_teacher:
            # task info
            task_type = batch_meta["task_def"]["task_type"]
            # Positional arguments for the teacher: the forward module, the
            # clean logits, the network inputs, then task type and pairwise
            # size.
            adv_inputs = (
                [self.mnetwork, logits] + inputs +
                [task_type, batch_meta.get("pairwise_size", 1)])
            adv_loss, emb_val, eff_perturb = self.adv_teacher.forward(
                *adv_inputs)
            loss = loss + self.config["adv_alpha"] * adv_loss

        batch_size = batch_data[batch_meta["token_id"]].size(0)
        # rescale loss as dynamic batching
        if self.config["bin_on"]:
            loss = loss * (1.0 * batch_size / self.config["batch_size"])
        if self.config["local_rank"] != -1:
            # print('Rank ', self.config['local_rank'], ' loss ', loss)
            # Distributed run: all-reduce a detached copy and divide by
            # ``world_size``; the copy only feeds the running average.
            copied_loss = copy.deepcopy(loss.data)
            torch.distributed.all_reduce(copied_loss)
            copied_loss = copied_loss / self.config["world_size"]
            self.train_loss.update(copied_loss.item(), batch_size)
        else:
            self.train_loss.update(loss.item(), batch_size)

        # Same bookkeeping for the three adversarial statistics.
        if self.config.get("adv_train", False) and self.adv_teacher:
            if self.config["local_rank"] != -1:
                copied_adv_loss = copy.deepcopy(adv_loss.data)
                torch.distributed.all_reduce(copied_adv_loss)
                copied_adv_loss = copied_adv_loss / self.config["world_size"]
                self.adv_loss.update(copied_adv_loss.item(), batch_size)

                copied_emb_val = copy.deepcopy(emb_val.data)
                torch.distributed.all_reduce(copied_emb_val)
                copied_emb_val = copied_emb_val / self.config["world_size"]
                self.emb_val.update(copied_emb_val.item(), batch_size)

                copied_eff_perturb = copy.deepcopy(eff_perturb.data)
                torch.distributed.all_reduce(copied_eff_perturb)
                copied_eff_perturb = copied_eff_perturb / self.config[
                    "world_size"]
                self.eff_perturb.update(copied_eff_perturb.item(), batch_size)
            else:
                self.adv_loss.update(adv_loss.item(), batch_size)
                self.emb_val.update(emb_val.item(), batch_size)
                self.eff_perturb.update(eff_perturb.item(), batch_size)

        # scale loss
        # Divide by the accumulation factor so accumulated gradients match
        # one full-size step.
        loss = loss / self.config.get("grad_accumulation_step", 1)
        if self.config["fp16"]:
            # ``amp`` is the module-level name bound by ``_setup_optim`` when
            # fp16 is enabled.
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.local_updates += 1
        # Step the optimizer only on every ``grad_accumulation_step``-th call.
        if self.local_updates % self.config.get("grad_accumulation_step",
                                                1) == 0:
            if self.config["global_grad_clipping"] > 0:
                if self.config["fp16"]:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer),
                        self.config["global_grad_clipping"],
                    )
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.network.parameters(),
                        self.config["global_grad_clipping"])
            self.updates += 1
            # reset number of the grad accumulation
            self.optimizer.step()
            self.optimizer.zero_grad()
            if self.scheduler:
                self.scheduler.step()

    def encode(self, batch_meta, batch_data):
        self.network.eval()
        inputs = batch_data[:3]
        sequence_output = self.network.encode(*inputs)[0]
        return sequence_output

    # TODO: similar as function extract, preserve since it is used by extractor.py
    # will remove after migrating to transformers package
    def extract(self, batch_meta, batch_data):
        self.network.eval()
        # 'token_id': 0; 'segment_id': 1; 'mask': 2
        inputs = batch_data[:3]
        all_encoder_layers, pooled_output = self.mnetwork.bert(*inputs)
        return all_encoder_layers, pooled_output

    def predict(self, batch_meta, batch_data):
        """Run the network in eval mode on one batch and decode its output.

        Args:
            batch_meta: mapping describing the batch. Always read: "task_id",
                "task_def", "input_len". Read depending on the task type:
                "label", "true_label", "pairwise_size", "mask",
                "offset_mapping", "token_is_max_context", "uids", "context",
                "answer", "null_ans_index", "choice".
            batch_data: list indexed by the positions stored in
                ``batch_meta``; its first ``input_len`` entries are the
                network inputs.

        Returns:
            A 3-tuple whose layout depends on the task:

            * task object available: ``(score, predict, batch_meta["label"])``
              with score/predict as returned by ``task_obj.test_predict``;
            * Ranking: flat softmax scores, flat one-hot predictions,
              ``batch_meta["true_label"]``;
            * SeqenceLabeling: flat scores, per-row label ids cut to the
              mask length, ``batch_meta["label"]``;
            * Span / SpanYN: ``(start, end)`` logits as lists, an empty
              prediction list, and a list of per-sample feature dicts;
            * SeqenceGeneration: generated ids as lists, ``{uid: text}``
              predictions, ``{uid: answer}`` golds;
            * ClozeChoice: flat scores, ``{uid: chosen}`` predictions,
              ``{uid: answer}`` golds.

        Raises:
            ValueError: if there is no task object and the task type is not
                one of the types handled here.
        """
        self.network.eval()
        task_id = batch_meta["task_id"]
        task_def = TaskDef.from_dict(batch_meta["task_def"])
        task_type = task_def.task_type
        task_obj = tasks.get_task_obj(task_def)

        inputs = batch_data[:batch_meta["input_len"]]
        if len(inputs) == 3:
            # Pad with two ``None`` placeholders so ``task_id`` is always the
            # sixth positional argument.
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        if task_type == TaskType.SeqenceGeneration:
            # y_idx, #3 -> gen
            inputs.append(None)
            inputs.append(3)

        score = self.mnetwork(*inputs)

        if task_obj is not None:
            score, predict = task_obj.test_predict(score)
            return score, predict, batch_meta["label"]

        if task_type == TaskType.Ranking:
            score = score.contiguous().view(-1, batch_meta["pairwise_size"])
            score = F.softmax(score, dim=1)
            score = score.data.cpu().numpy()
            # One-hot mark the highest-scoring candidate of every row.
            predict = np.zeros(score.shape, dtype=int)
            for idx, pos in enumerate(np.argmax(score, axis=1)):
                predict[idx, pos] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta["true_label"]

        if task_type == TaskType.SeqenceLabeling:
            mask = batch_data[batch_meta["mask"]]
            score = score.contiguous().data.cpu().numpy()
            predict = np.argmax(score, axis=1).reshape(mask.size()).tolist()
            # Keep only as many predictions per row as the mask sums to.
            valid_lengths = mask.sum(1).tolist()
            final_predict = [
                row[:valid_lengths[idx]] for idx, row in enumerate(predict)
            ]
            score = score.reshape(-1).tolist()
            return score, final_predict, batch_meta["label"]

        if task_type == TaskType.Span or task_type == TaskType.SpanYN:
            has_label = "label" in batch_meta
            features = []
            for idx, offset in enumerate(batch_meta["offset_mapping"]):
                token_is_max_context = (
                    batch_meta["token_is_max_context"][idx] if batch_meta.get(
                        "token_is_max_context", None) else None)
                feature = {
                    "offset_mapping": offset,
                    "token_is_max_context": token_is_max_context,
                    "uid": batch_meta["uids"][idx],
                    "context": batch_meta["context"][idx],
                    "answer": batch_meta["answer"][idx],
                }
                if has_label:
                    feature["label"] = batch_meta["label"][idx]
                if "null_ans_index" in batch_meta:
                    feature["null_ans_index"] = batch_meta["null_ans_index"]
                features.append(feature)
            start, end = score
            start = start.contiguous().data.cpu().numpy().tolist()
            end = end.contiguous().data.cpu().numpy().tolist()
            # No answers are decoded here: the caller receives the raw
            # start/end logits plus the features needed to decode them.
            predictions = []
            return (start, end), predictions, features

        if task_type == TaskType.SeqenceGeneration:
            # NOTE(review): ``self.tokenizer`` is ``None`` when
            # ``_setup_tokenizer`` failed; this call would then raise.
            predicts = self.tokenizer.batch_decode(score,
                                                   skip_special_tokens=True)
            predictions = {}
            golds = {}
            for idx, predict in enumerate(predicts):
                sample_id = batch_meta["uids"][idx]
                predict = predict.strip()
                if predict == DUMPY_STRING_FOR_EMPTY_ANS:
                    predict = ""
                predictions[sample_id] = predict
                golds[sample_id] = batch_meta["answer"][idx]
            score = score.contiguous().data.cpu().numpy().tolist()
            return score, predictions, golds

        if task_type == TaskType.ClozeChoice:
            flat_score = score.contiguous().view(-1).data.cpu().numpy()
            answers = batch_meta["answer"]
            choices = batch_meta["choice"]
            uids = batch_meta["uids"]
            predictions = {}
            golds = {}
            # ``pairwise_size`` lists the number of consecutive rows that
            # belong to each question; ``offset`` is the first row of the
            # current question. The uid is read at the same offset as the
            # answer and the choices, so each question gets its own entry
            # instead of all being written under ``uids[0]``.
            # Assumes ``uids`` is row-aligned with ``answer``/``choice``;
            # TODO confirm against the batcher.
            offset = 0
            for chunk in batch_meta["pairwise_size"]:
                uid = uids[offset]
                # NOTE(security): ``eval`` executes arbitrary expressions. It
                # is kept because the stored format is not visible here; if
                # these strings are plain literals, ``ast.literal_eval`` is
                # the safe replacement.
                answer = eval(answers[offset])
                choice = eval(choices[offset])
                positive = np.argmax(flat_score[offset:offset + chunk])
                predictions[uid] = choice[positive]
                golds[uid] = answer
                offset += chunk
            return flat_score.tolist(), predictions, golds

        raise ValueError("Unknown task_type: %s" % task_type)

    def save(self, filename):
        """Serialize network weights, optimizer state and config to ``filename``.

        When ``self.mnetwork`` is a ``DistributedDataParallel`` wrapper the
        weights are taken from its ``module``; otherwise from
        ``self.network``.

        Args:
            filename: destination handed to ``torch.save``.
        """
        is_ddp = isinstance(self.mnetwork,
                            torch.nn.parallel.DistributedDataParallel)
        source = self.mnetwork.module if is_ddp else self.network
        # Each weight tensor is copied to CPU before it is written out.
        cpu_weights = {
            name: tensor.cpu()
            for name, tensor in source.state_dict().items()
        }
        checkpoint = dict(
            state=cpu_weights,
            optimizer=self.optimizer.state_dict(),
            config=self.config,
        )
        torch.save(checkpoint, filename)
        logger.info("model saved to {}".format(filename))

    def load(self, checkpoint):
        model_state_dict = torch.load(checkpoint)
        if "state" in model_state_dict:
            self.network.load_state_dict(model_state_dict["state"],
                                         strict=False)
        if "optimizer" in model_state_dict:
            self.optimizer.load_state_dict(model_state_dict["optimizer"])
        if "config" in model_state_dict:
            self.config.update(model_state_dict["config"])

    def cuda(self):
        """Call ``cuda()`` on the wrapped network (``self.network``)."""
        self.network.cuda()