Example #1
class MTDNNModel(object):
    def __init__(self, opt, state_dict=None, num_train_step=-1):
        self.config = opt
        self.updates = state_dict[
            'updates'] if state_dict and 'updates' in state_dict else 0
        self.local_updates = 0
        self.train_loss = AverageMeter()
        self.network = SANBertNetwork(opt)

        if state_dict:
            self.network.load_state_dict(state_dict['state'], strict=False)
        self.mnetwork = nn.DataParallel(
            self.network) if opt['multi_gpu_on'] else self.network
        self.total_param = sum([
            p.nelement() for p in self.network.parameters() if p.requires_grad
        ])
        if opt['cuda']:
            self.network.cuda()

        no_decay = [
            'bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight'
        ]

        optimizer_parameters = [{
            'params': [
                p for n, p in self.network.named_parameters()
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params': [
                p for n, p in self.network.named_parameters()
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.0
        }]

        # note that Adamax is modified based on the BERT code
        if opt['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(optimizer_parameters,
                                       opt['learning_rate'],
                                       weight_decay=opt['weight_decay'])

        elif opt['optimizer'] == 'adamax':
            self.optimizer = Adamax(optimizer_parameters,
                                    opt['learning_rate'],
                                    warmup=opt['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=opt['grad_clipping'],
                                    schedule=opt['warmup_schedule'],
                                    weight_decay=opt['weight_decay'])
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
        elif opt['optimizer'] == 'radam':
            self.optimizer = RAdam(optimizer_parameters,
                                   opt['learning_rate'],
                                   warmup=opt['warmup'],
                                   t_total=num_train_step,
                                   max_grad_norm=opt['grad_clipping'],
                                   schedule=opt['warmup_schedule'],
                                   eps=opt['adam_eps'],
                                   weight_decay=opt['weight_decay'])
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
            # The current radam does not support FP16.
            opt['fp16'] = False
        elif opt['optimizer'] == 'adadelta':
            self.optimizer = optim.Adadelta(optimizer_parameters,
                                            opt['learning_rate'],
                                            rho=0.95)
        elif opt['optimizer'] == 'adam':
            self.optimizer = Adam(optimizer_parameters,
                                  lr=opt['learning_rate'],
                                  warmup=opt['warmup'],
                                  t_total=num_train_step,
                                  max_grad_norm=opt['grad_clipping'],
                                  schedule=opt['warmup_schedule'],
                                  weight_decay=opt['weight_decay'])
            if opt.get('have_lr_scheduler', False):
                opt['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if opt['fp16']:
            try:
                global amp
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(self.network,
                                              self.optimizer,
                                              opt_level=opt['fp16_opt_level'])
            self.network = model
            self.optimizer = optimizer

        if opt.get('have_lr_scheduler', False):
            if opt.get('scheduler_type', 'rop') == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer,
                                                   mode='max',
                                                   factor=opt['lr_gamma'],
                                                   patience=3)
            elif opt.get('scheduler_type', 'rop') == 'exp':
                self.scheduler = ExponentialLR(self.optimizer,
                                               gamma=opt.get('lr_gamma', 0.95))
            else:
                milestones = [
                    int(step)
                    for step in opt.get('multi_step_lr', '10,20,30').split(',')
                ]
                self.scheduler = MultiStepLR(self.optimizer,
                                             milestones=milestones,
                                             gamma=opt.get('lr_gamma'))
        else:
            self.scheduler = None

        self.ema = None
        if opt['ema_opt'] > 0:
            self.ema = EMA(self.config['ema_gamma'], self.network)
            if opt['cuda']:
                self.ema.cuda()

        self.para_swapped = False
        # zero optimizer grad
        self.optimizer.zero_grad()

    def setup_ema(self):
        if self.config['ema_opt']:
            self.ema.setup()

    def update_ema(self):
        if self.config['ema_opt']:
            self.ema.update()

    def eval(self):
        if self.config['ema_opt']:
            self.ema.swap_parameters()
            self.para_swapped = True

    def train(self):
        if self.para_swapped:
            self.ema.swap_parameters()
            self.para_swapped = False

    def update(self, batch_meta, batch_data):
        self.network.train()
        labels = batch_data[batch_meta['label']]
        soft_labels = None
        if self.config.get('mkd_opt', 0) > 0 and ('soft_label' in batch_meta):
            soft_labels = batch_meta['soft_label']

        task_type = batch_meta['task_type']
        if task_type == TaskType.Span:
            start = batch_data[batch_meta['start']]
            end = batch_data[batch_meta['end']]
            if self.config["cuda"]:
                start = start.cuda(non_blocking=True)
                end = end.cuda(non_blocking=True)
            start.requires_grad = False
            end.requires_grad = False
        else:
            y = labels
            if task_type == TaskType.Ranking:
                y = y.contiguous().view(-1, batch_meta['pairwise_size'])[:, 0]
            if self.config['cuda']:
                y = y.cuda(non_blocking=True)
            y.requires_grad = False

        task_id = batch_meta['task_id']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)

        if self.config.get('weighted_on', False):
            if self.config['cuda']:
                weight = batch_data[batch_meta['factor']].cuda(
                    non_blocking=True)
            else:
                weight = batch_data[batch_meta['factor']]

        if task_type == TaskType.Span:
            start_logits, end_logits = self.mnetwork(*inputs)
            ignored_index = start_logits.size(1)
            start.clamp_(0, ignored_index)
            end.clamp_(0, ignored_index)
            if self.config.get('weighted_on', False):
                loss = torch.mean(F.cross_entropy(start_logits, start, reduce=False) * weight) + \
                    torch.mean(F.cross_entropy(end_logits, end, reduce=False) * weight)
            else:
                loss = F.cross_entropy(start_logits, start, ignore_index=ignored_index) + \
                    F.cross_entropy(end_logits, end, ignore_index=ignored_index)
            loss = loss / 2
        else:
            logits = self.mnetwork(*inputs)
            if task_type == TaskType.Ranking:
                logits = logits.view(-1, batch_meta['pairwise_size'])
            if self.config.get('weighted_on', False):
                if task_type == TaskType.Regression:
                    loss = torch.mean(
                        F.mse_loss(logits.squeeze(), y, reduce=False) * weight)
                else:
                    loss = torch.mean(
                        F.cross_entropy(logits, y, reduce=False) * weight)
                    if soft_labels is not None:
                        # compute KL
                        label_size = soft_labels.size(1)
                        kd_loss = F.kl_div(F.log_softmax(
                            logits.view(-1, label_size).float(), 1),
                                           soft_labels,
                                           reduction='batchmean')
                        loss = loss + kd_loss
            else:
                if task_type == TaskType.Regression:
                    loss = F.mse_loss(logits.squeeze(), y)
                else:
                    loss = F.cross_entropy(logits, y)
                    if soft_labels is not None:
                        # compute KL
                        label_size = soft_labels.size(1)
                        kd_loss = F.kl_div(F.log_softmax(
                            logits.view(-1, label_size).float(), 1),
                                           soft_labels,
                                           reduction='batchmean')
                        loss = loss + kd_loss

        # logits is not defined for span tasks, so take the batch size from start_logits there
        batch_size = start_logits.size(0) if task_type == TaskType.Span else logits.size(0)
        self.train_loss.update(loss.item(), batch_size)
        # scale loss
        loss = loss / self.config.get('grad_accumulation_step', 1)
        if self.config['fp16']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.local_updates += 1
        if self.local_updates % self.config.get('grad_accumulation_step',
                                                1) == 0:
            if self.config['global_grad_clipping'] > 0:
                if self.config['fp16']:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer),
                        self.config['global_grad_clipping'])
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.network.parameters(),
                        self.config['global_grad_clipping'])

            self.updates += 1
            # reset number of the grad accumulation
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.update_ema()

    def predict(self, batch_meta, batch_data):
        self.network.eval()
        task_id = batch_meta['task_id']
        task_type = batch_meta['task_type']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        score = self.mnetwork(*inputs)
        if task_type == TaskType.Ranking:
            score = score.contiguous().view(-1, batch_meta['pairwise_size'])
            assert task_type == TaskType.Ranking
            score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.zeros(score.shape, dtype=int)
            positive = np.argmax(score, axis=1)
            for idx, pos in enumerate(positive):
                predict[idx, pos] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta['true_label']
        else:
            if task_type == TaskType.Classification:
                score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).tolist()
            score = score.reshape(-1).tolist()
        return score, predict, batch_meta['label']

    def extract(self, batch_meta, batch_data):
        self.network.eval()
        # 'token_id': 0; 'segment_id': 1; 'mask': 2
        inputs = batch_data[:3]
        all_encoder_layers, pooled_output = self.mnetwork.bert(*inputs)
        return all_encoder_layers, pooled_output

    def save(self, filename):
        network_state = dict([(k, v.cpu())
                              for k, v in self.network.state_dict().items()])
        ema_state = dict([
            (k, v.cpu()) for k, v in self.ema.model.state_dict().items()
        ]) if self.ema is not None else dict()
        params = {
            'state': network_state,
            'optimizer': self.optimizer.state_dict(),
            'ema': ema_state,
            'config': self.config,
        }
        torch.save(params, filename)
        logger.info('model saved to {}'.format(filename))

    def load(self, checkpoint):

        model_state_dict = torch.load(checkpoint)
        if model_state_dict['config']['init_checkpoint'].rsplit('/', 1)[1] != \
                self.config['init_checkpoint'].rsplit('/', 1)[1]:
            logger.error(
                '*** SANBert network is pretrained on a different Bert Model. Please use that to fine-tune for other tasks. ***'
            )
            sys.exit()

        self.network.load_state_dict(model_state_dict['state'], strict=False)
        self.optimizer.load_state_dict(model_state_dict['optimizer'])
        self.config = model_state_dict['config']
        if self.ema:
            self.ema.model.load_state_dict(model_state_dict['ema'])

    def cuda(self):
        self.network.cuda()
        if self.config['ema_opt']:
            self.ema.cuda()
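A minimal construction sketch for the class above. The option keys mirror the ones read in __init__; the concrete values, and the availability of SANBertNetwork and the other MT-DNN helpers in scope, are illustrative assumptions only.

import torch

opt = {
    'optimizer': 'adadelta',       # avoids the warmup/schedule keys the BERT-style optimizers need
    'learning_rate': 1e-3,
    'multi_gpu_on': False,
    'cuda': torch.cuda.is_available(),
    'fp16': False,
    'ema_opt': 0,                  # keep EMA off so 'ema_gamma' is not required
    'have_lr_scheduler': False,
    'global_grad_clipping': 1.0,
    'grad_accumulation_step': 1,
}
model = MTDNNModel(opt)
print('trainable parameters:', model.total_param)
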
class MTDNNModel(object):
    def __init__(self, opt, state_dict=None, num_train_step=-1):
        self.config = opt
        self.updates = state_dict['updates'] if state_dict and 'updates' in state_dict else 0
        self.local_updates = 0
        self.train_loss = AverageMeter()
        self.network = SANBertNetwork(opt)
        if state_dict:
            self.network.load_state_dict(state_dict['state'], strict=False)
        self.mnetwork = nn.DataParallel(self.network) if opt['multi_gpu_on'] else self.network
        self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])
        if opt['cuda']:
            self.network.cuda()
        optimizer_parameters = self._get_param_groups()
        self._setup_optim(optimizer_parameters, state_dict, num_train_step)
        self.para_swapped = False
        self.optimizer.zero_grad()
        self._setup_lossmap(self.config)
        self.sampler = None

    def _get_param_groups(self):
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_parameters = [
            {'params': [p for n, p in self.network.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in self.network.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        return optimizer_parameters

    def _setup_optim(self, optimizer_parameters, state_dict=None, num_train_step=-1):
        if self.config['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(optimizer_parameters, self.config['learning_rate'],
                                       weight_decay=self.config['weight_decay'])

        elif self.config['optimizer'] == 'adamax':
            self.optimizer = Adamax(optimizer_parameters,
                                    self.config['learning_rate'],
                                    warmup=self.config['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=self.config['grad_clipping'],
                                    schedule=self.config['warmup_schedule'],
                                    weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
        elif self.config['optimizer'] == 'radam':
            self.optimizer = RAdam(optimizer_parameters,
                                    self.config['learning_rate'],
                                    warmup=self.config['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=self.config['grad_clipping'],
                                    schedule=self.config['warmup_schedule'],
                                    eps=self.config['adam_eps'],
                                    weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
            # The current radam does not support FP16.
            self.config['fp16'] = False
        elif self.config['optimizer'] == 'adam':
            self.optimizer = Adam(optimizer_parameters,
                                  lr=self.config['learning_rate'],
                                  warmup=self.config['warmup'],
                                  t_total=num_train_step,
                                  max_grad_norm=self.config['grad_clipping'],
                                  schedule=self.config['warmup_schedule'],
                                  weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % self.config['optimizer'])

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if self.config['fp16']:
            try:
                global amp
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(self.network, self.optimizer, opt_level=self.config['fp16_opt_level'])
            self.network = model
            self.optimizer = optimizer

        if self.config.get('have_lr_scheduler', False):
            if self.config.get('scheduler_type', 'rop') == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=self.config['lr_gamma'], patience=3)
            elif self.config.get('scheduler_type', 'rop') == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=self.config.get('lr_gamma', 0.95))
            else:
                milestones = [int(step) for step in self.config.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones, gamma=self.config.get('lr_gamma'))
        else:
            self.scheduler = None

    def _setup_lossmap(self, config):
        loss_types = config['loss_types']
        self.task_loss_criterion = []
        for idx, cs in enumerate(loss_types):
            assert cs is not None
            lc = LOSS_REGISTRY[cs](name='Loss func of task {}: {}'.format(idx, cs))
            self.task_loss_criterion.append(lc)

    def _setup_kd_lossmap(self, config):
        loss_types = config['kd_loss_types']
        self.kd_task_loss_criterion = []
        if config.get('mkd_opt', 0) > 0:
            for idx, cs in enumerate(loss_types):
                assert cs is not None
                lc = LOSS_REGISTRY[cs](name='Loss func of task {}: {}'.format(idx, cs))
                self.kd_task_loss_criterion.append(lc)

    def train(self):
        if self.para_swapped:
            self.para_swapped = False

    def _to_cuda(self, tensor):
        if tensor is None: return tensor

        if isinstance(tensor, list) or isinstance(tensor, tuple):
            y = [e.cuda(non_blocking=True) for e in tensor]
            for e in y:
                e.requires_grad = False
        else:
            y = tensor.cuda(non_blocking=True)
            y.requires_grad = False
        return y

    def update(self, batch_meta, batch_data):
        self.network.train()
        y = batch_data[batch_meta['label']]
        soft_labels = None

        task_type = batch_meta['task_type']
        y = self._to_cuda(y) if self.config['cuda'] else y

        task_id = batch_meta['task_id']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        weight = None
        if self.config.get('weighted_on', False):
            if self.config['cuda']:
                weight = batch_data[batch_meta['factor']].cuda(non_blocking=True)
            else:
                weight = batch_data[batch_meta['factor']]
        logits = self.mnetwork(*inputs)

        # compute loss
        loss = 0
        if self.task_loss_criterion[task_id] and (y is not None):
            loss = self.task_loss_criterion[task_id](logits, y, weight, ignore_index=-1)

        # compute kd loss
        if self.config.get('mkd_opt', 0) > 0 and ('soft_label' in batch_meta):
            soft_labels = batch_meta['soft_label']
            soft_labels = self._to_cuda(soft_labels) if self.config['cuda'] else soft_labels
            kd_lc = self.kd_task_loss_criterion[task_id]
            kd_loss = kd_lc(logits, soft_labels, weight, ignore_index=-1) if kd_lc else 0
            loss = loss + kd_loss

        self.train_loss.update(loss.item(), batch_data[batch_meta['token_id']].size(0))
        # scale loss
        loss = loss / self.config.get('grad_accumulation_step', 1)
        if self.sampler:
            self.sampler.update_loss(task_id, loss.item())
        if self.config['fp16']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.local_updates += 1
        if self.local_updates % self.config.get('grad_accumulation_step', 1) == 0:
            if self.config['global_grad_clipping'] > 0:
                if self.config['fp16']:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                                                   self.config['global_grad_clipping'])
                else:
                    torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                                  self.config['global_grad_clipping'])
            self.updates += 1
            # reset number of the grad accumulation
            self.optimizer.step()
            self.optimizer.zero_grad()

    def update_all(self, batch_metas, batch_datas):
        self.network.train()
        task_losses = []
        max_loss = 0
        max_task = 0
        with torch.no_grad():
            for i, (batch_meta, batch_data) in enumerate(zip(batch_metas, batch_datas)):
                y = batch_data[batch_meta['label']]
                soft_labels = None

                task_type = batch_meta['task_type']
                y = self._to_cuda(y) if self.config['cuda'] else y

                task_id = batch_meta['task_id']
                assert task_id == i
                inputs = batch_data[:batch_meta['input_len']]
                if len(inputs) == 3:
                    inputs.append(None)
                    inputs.append(None)
                inputs.append(task_id)
                weight = None
                if self.config.get('weighted_on', False):
                    if self.config['cuda']:
                        weight = batch_data[batch_meta['factor']].cuda(non_blocking=True)
                    else:
                        weight = batch_data[batch_meta['factor']]
                logits = self.mnetwork(*inputs)

                # compute loss
                loss = 0
                if self.task_loss_criterion[task_id] and (y is not None):
                    loss = self.task_loss_criterion[task_id](logits, y, weight, ignore_index=-1)

                # compute kd loss
                if self.config.get('mkd_opt', 0) > 0 and ('soft_label' in batch_meta):
                    soft_labels = batch_meta['soft_label']
                    soft_labels = self._to_cuda(soft_labels) if self.config['cuda'] else soft_labels
                    kd_lc = self.kd_task_loss_criterion[task_id]
                    kd_loss = kd_lc(logits, soft_labels, weight, ignore_index=-1) if kd_lc else 0
                    loss = loss + kd_loss
                
                if loss > max_loss:
                    max_loss = loss
                    max_task = i
                task_losses.append(loss.item())
        # print(task_losses)
        p = np.array(task_losses)
        p /= p.sum()
        sampled_id = 0
        if random.random() < 0.3:
            sampled_id = max_task
        else:
            sampled_id = np.random.choice(list(range(len(task_losses))), p=p, replace=False)
        # print("task ", sampled_id, " is selected")
        self.loss_list = task_losses
        self.update(batch_metas[sampled_id], batch_datas[sampled_id])

    def calculate_loss(self, batch_meta, batch_data):
        self.network.train()
        with torch.no_grad():
            y = batch_data[batch_meta['label']]
            soft_labels = None

            task_type = batch_meta['task_type']
            y = self._to_cuda(y) if self.config['cuda'] else y

            task_id = batch_meta['task_id']
            
            inputs = batch_data[:batch_meta['input_len']]
            if len(inputs) == 3:
                inputs.append(None)
                inputs.append(None)
            inputs.append(task_id)
            weight = None
            if self.config.get('weighted_on', False):
                if self.config['cuda']:
                    weight = batch_data[batch_meta['factor']].cuda(non_blocking=True)
                else:
                    weight = batch_data[batch_meta['factor']]
            logits = self.mnetwork(*inputs)

            # compute loss
            loss = 0
            if self.task_loss_criterion[task_id] and (y is not None):
                loss = self.task_loss_criterion[task_id](logits, y, weight, ignore_index=-1)

            # compute kd loss
            if self.config.get('mkd_opt', 0) > 0 and ('soft_label' in batch_meta):
                soft_labels = batch_meta['soft_label']
                soft_labels = self._to_cuda(soft_labels) if self.config['cuda'] else soft_labels
                kd_lc = self.kd_task_loss_criterion[task_id]
                kd_loss = kd_lc(logits, soft_labels, weight, ignore_index=-1) if kd_lc else 0
                loss = loss + kd_loss
        return loss   


    def predict(self, batch_meta, batch_data):
        self.network.eval()
        task_id = batch_meta['task_id']
        task_type = batch_meta['task_type']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        score = self.mnetwork(*inputs)
        if task_type == TaskType.Ranking:
            score = score.contiguous().view(-1, batch_meta['pairwise_size'])
            assert task_type == TaskType.Ranking
            score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.zeros(score.shape, dtype=int)
            positive = np.argmax(score, axis=1)
            for idx, pos in enumerate(positive):
                predict[idx, pos] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta['true_label']
        elif task_type == TaskType.SeqenceLabeling:
            mask = batch_data[batch_meta['mask']]
            score = score.contiguous()
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).reshape(mask.size()).tolist()
            valid_lengths = mask.sum(1).tolist()
            final_predict = []
            for idx, p in enumerate(predict):
                final_predict.append(p[:valid_lengths[idx]])
            score = score.reshape(-1).tolist()
            return score, final_predict, batch_meta['label']
        elif task_type == TaskType.Span:
            start, end = score
            predictions = []
            if self.config['encoder_type'] == EncoderModelType.BERT:
                import experiments.squad.squad_utils as mrc_utils
                scores, predictions = mrc_utils.extract_answer(batch_meta, batch_data, start, end, self.config.get('max_answer_len', 5))
            return scores, predictions, batch_meta['answer']
        else:
            if task_type == TaskType.Classification:
                score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).tolist()
            score = score.reshape(-1).tolist()
        return score, predict, batch_meta['label']

    def extract(self, batch_meta, batch_data):
        self.network.eval()
        # 'token_id': 0; 'segment_id': 1; 'mask': 2
        inputs = batch_data[:3]
        all_encoder_layers, pooled_output = self.mnetwork.bert(*inputs)
        return all_encoder_layers, pooled_output

    def save(self, filename):
        network_state = dict([(k, v.cpu()) for k, v in self.network.state_dict().items()])
        params = {
            'state': network_state,
            'optimizer': self.optimizer.state_dict(),
            'config': self.config,
        }
        torch.save(params, filename)
        logger.info('model saved to {}'.format(filename))

    def load(self, checkpoint):
        model_state_dict = torch.load(checkpoint)
        self.network.load_state_dict(model_state_dict['state'], strict=False)
        self.optimizer.load_state_dict(model_state_dict['optimizer'])
        self.config.update(model_state_dict['config'])

    def cuda(self):
        self.network.cuda()
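The update_all method above picks one task per step: with probability 0.3 it takes the task with the largest current loss, otherwise it samples a task id with probability proportional to its loss. A small standalone sketch of that sampling rule (the 0.3 threshold comes from the code above; the function name and example losses are made up):

import random
import numpy as np

def sample_task(task_losses, greedy_prob=0.3):
    # greedy pick of the hardest task with probability greedy_prob,
    # otherwise loss-proportional sampling, mirroring update_all above
    p = np.array(task_losses, dtype=np.float64)
    p /= p.sum()
    if random.random() < greedy_prob:
        return int(np.argmax(task_losses))
    return int(np.random.choice(len(task_losses), p=p))

print(sample_task([0.2, 1.5, 0.7]))  # most often returns 1, the highest-loss task
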

        
Example #3
def main():
    train_df = pd.read_csv(TRAIN_PATH)
    fold_df = pd.read_csv(FOLD_PATH)
    n_train_df = len(train_df)

    old_folds = pd.read_csv(FOLD_PATH_JIGSAW)

    old_df = pd.read_csv(OLD_PATH)
    old_df["target"] = old_df[["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]].sum(axis=1)
    old_df["target"] = (old_df["target"] >= 1).astype("int8")
    old_df = old_df[old_folds.fold_id != fold_id]
    train_df = train_df.append(old_df).reset_index(drop=True)
    del old_folds, old_df
    gc.collect()

    # y = np.where(train_df['target'] >= 0.5, 1, 0)
    y = train_df['target'].values

    identity_columns_new = []
    for column in identity_columns + ['target']:
        train_df[column + "_bin"] = np.where(train_df[column] >= 0.5, True, False)
        if column != "target":
            identity_columns_new.append(column + "_bin")

    # Overall
    #weights = np.ones((len(train_df),)) / 4
    # Subgroup
    #weights += (train_df[identity_columns].fillna(0).values >= 0.5).sum(axis=1).astype(bool).astype(np.int) / 4
    # Background Positive, Subgroup Negative
    #weights += (((train_df["target"].values >= 0.5).astype(bool).astype(np.int) +
    #             (1 - (train_df[identity_columns].fillna(0).values >= 0.5).sum(axis=1).astype(bool).astype(
    #                 np.int))) > 1).astype(bool).astype(np.int) / 4
    # Background Negative, Subgroup Positive
    #weights += (((train_df["target"].values < 0.5).astype(bool).astype(np.int) +
    #             (train_df[identity_columns].fillna(0).values >= 0.5).sum(axis=1).astype(bool).astype(
    #                 np.int)) > 1).astype(bool).astype(np.int) / 4
    #loss_weight = 0.5

    with timer('preprocessing text'):
        # df["comment_text"] = [analyzer_embed(text) for text in df["comment_text"]]
        train_df['comment_text'] = train_df['comment_text'].astype(str)
        train_df = train_df.fillna(0)

    with timer('load embedding'):
        tokenizer = BertTokenizer.from_pretrained(BERT_MODEL_PATH, cache_dir=None, do_lower_case=True)
        X_text, train_lengths = convert_lines(train_df["comment_text"].fillna("DUMMY_VALUE"), max_len, tokenizer)
        del train_lengths, tokenizer
        gc.collect()

    LOGGER.info(f"X_text {X_text.shape}")

    X_old = X_text[n_train_df:].astype("int32")
    X_text = X_text[:n_train_df].astype("int32")
    #w_trans = weights[n_train_df:].astype("float32")
    #weights = weights[:n_train_df].astype("float32")
    y_old = y[n_train_df:].astype("float32")
    y = y[:n_train_df].astype("float32")
    train_df = train_df[:n_train_df]

    with timer('train'):
        train_index = fold_df.fold_id != fold_id
        valid_index = fold_df.fold_id == fold_id
        X_train, y_train = X_text[train_index].astype("int32"), y[train_index].astype("float32")
        X_val, y_val = X_text[valid_index].astype("int32"), y[valid_index].astype("float32")
        test_df = train_df[valid_index]
        del X_text, y, train_index, valid_index, train_df
        gc.collect()

        model = BertForSequenceClassification.from_pretrained(WORK_DIR, cache_dir=None, num_labels=n_labels)
        model.zero_grad()
        model = model.to(device)

        X_train = np.concatenate([X_train, X_old], axis=0)
        y_train = np.concatenate([y_train, y_old], axis=0)
        train_size = len(X_train)
        del X_old, y_old
        gc.collect()

        train_dataset = torch.utils.data.TensorDataset(torch.tensor(X_train, dtype=torch.long),
                                                       torch.tensor(y_train, dtype=torch.float32))
        valid = torch.utils.data.TensorDataset(torch.tensor(X_val, dtype=torch.long),
                                               torch.tensor(y_val, dtype=torch.float32))
        ran_sampler = torch.utils.data.RandomSampler(train_dataset)
        len_sampler = LenMatchBatchSampler(ran_sampler, batch_size=batch_size, drop_last=False)
        train_loader = torch.utils.data.DataLoader(train_dataset, batch_sampler=len_sampler)
        valid_loader = torch.utils.data.DataLoader(valid, batch_size=batch_size * 2, shuffle=False)
        del X_train, y_train, X_val, y_val
        gc.collect()
        LOGGER.info(f"done data loader setup")

        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

        num_train_optimization_steps = int(epochs * train_size / batch_size / accumulation_steps)
        total_step = int(epochs * train_size / batch_size)

        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=base_lr,
                             warmup=0.005,
                             t_total=num_train_optimization_steps)
        LOGGER.info(f"done optimizer loader setup")

        model, optimizer = amp.initialize(model, optimizer, opt_level="O1", verbosity=0)
        criterion = torch.nn.BCEWithLogitsLoss().to(device)
        #criterion = CustomLoss(loss_weight).to(device)
        LOGGER.info(f"done amp setup")

        for epoch in range(epochs):
            LOGGER.info(f"Starting {epoch} epoch...")
            LOGGER.info(f"length {train_size} train...")
            if epoch == 1:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = base_lr * gammas[1]
            tr_loss, train_losses = train_one_epoch(model, train_loader, criterion, optimizer, device,
                                                    accumulation_steps, total_step, n_labels, base_lr,
                                                    gamma=gammas[2 * epoch])
            LOGGER.info(f'Mean train loss: {round(tr_loss,5)}')

            torch.save(model.state_dict(), '{}_dic_epoch{}'.format(exp, epoch))
            torch.save(optimizer.state_dict(), '{}_optimizer_epoch{}.pth'.format(exp, epoch))

            valid_loss, oof_pred = validate(model, valid_loader, criterion, device, n_labels)
            LOGGER.info(f'Mean valid loss: {round(valid_loss,5)}')

            if epochs > 1:
                test_df_cp = test_df.copy()
                test_df_cp["pred"] = oof_pred[:, 0]
                test_df_cp = convert_dataframe_to_bool(test_df_cp)
                bias_metrics_df = compute_bias_metrics_for_model(test_df_cp, identity_columns)
                LOGGER.info(bias_metrics_df)

                score = get_final_metric(bias_metrics_df, calculate_overall_auc(test_df_cp))
                LOGGER.info(f'score is {score}')

        del model
        gc.collect()
        torch.cuda.empty_cache()

    test_df["pred"] = oof_pred[:, 0]
    test_df = convert_dataframe_to_bool(test_df)
    bias_metrics_df = compute_bias_metrics_for_model(test_df, identity_columns)
    LOGGER.info(bias_metrics_df)

    score = get_final_metric(bias_metrics_df, calculate_overall_auc(test_df))
    LOGGER.info(f'final score is {score}')

    test_df.to_csv("oof.csv", index=False)

    xs = list(range(1, len(train_losses) + 1))
    plt.plot(xs, train_losses, label='Train loss')
    plt.legend()
    plt.xticks(xs)
    plt.xlabel('Iter')
    plt.savefig("loss.png")
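Both this script and the MT-DNN examples split the parameters into two optimizer groups so that biases and LayerNorm weights are excluded from weight decay. A self-contained illustration of how the substring match behaves on BERT-style parameter names (the Tiny module is made up for demonstration):

import torch.nn as nn

class Tiny(nn.Module):
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(4, 4)
        self.LayerNorm = nn.LayerNorm(4)

no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
tiny = Tiny()
decay = [n for n, _ in tiny.named_parameters() if not any(nd in n for nd in no_decay)]
skip = [n for n, _ in tiny.named_parameters() if any(nd in n for nd in no_decay)]
print(decay)  # ['dense.weight']
print(skip)   # ['dense.bias', 'LayerNorm.weight', 'LayerNorm.bias']
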
class MTDNNModel(object):
    def __init__(self, opt, state_dict=None, num_train_step=-1):
        self.config = opt
        self.updates = state_dict['updates'] if state_dict and 'updates' in state_dict else 0
        self.train_loss = AverageMeter()
        self.network = SANBertNetwork(opt)

        if state_dict:
            new_state = set(self.network.state_dict().keys())
            for k in list(state_dict['state'].keys()):
                if k not in new_state:
                    del state_dict['state'][k]
            for k, v in list(self.network.state_dict().items()):
                if k not in state_dict['state']:
                    state_dict['state'][k] = v
            self.network.load_state_dict(state_dict['state'])
        self.mnetwork = nn.DataParallel(self.network) if opt['multi_gpu_on'] else self.network
        self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])

        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_parameters = [
            {'params': [p for n, p in self.network.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.01},
            {'params': [p for n, p in self.network.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay_rate': 0.0}
        ]
        # note that Adamax is modified based on the BERT code
        if opt['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(optimizer_parameters, opt['learning_rate'],
                                       weight_decay=opt['weight_decay'])

        elif opt['optimizer'] == 'adamax':
            self.optimizer = Adamax(optimizer_parameters,
                                        opt['learning_rate'],
                                        warmup=opt['warmup'],
                                        t_total=num_train_step,
                                        max_grad_norm=opt['grad_clipping'],
                                        schedule=opt['warmup_schedule'])
            if opt.get('have_lr_scheduler', False): opt['have_lr_scheduler'] = False
        elif opt['optimizer'] == 'adadelta':
            self.optimizer = optim.Adadelta(optimizer_parameters,
                                            opt['learning_rate'],
                                            rho=0.95)
        elif opt['optimizer'] == 'adam':
            self.optimizer = Adam(optimizer_parameters,
                                        lr=opt['learning_rate'],
                                        warmup=opt['warmup'],
                                        t_total=num_train_step,
                                        max_grad_norm=opt['grad_clipping'],
                                        schedule=opt['warmup_schedule'])
            if opt.get('have_lr_scheduler', False): opt['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % opt['optimizer'])

        if state_dict and 'optimizer' in state_dict:
            self.optimizer.load_state_dict(state_dict['optimizer'])

        if opt.get('have_lr_scheduler', False):
            if opt.get('scheduler_type', 'rop') == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=opt['lr_gamma'], patience=3)
            elif opt.get('scheduler_type', 'rop') == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=opt.get('lr_gamma', 0.95))
            else:
                milestones = [int(step) for step in opt.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones, gamma=opt.get('lr_gamma'))
        else:
            self.scheduler = None
        self.ema = None
        if opt['ema_opt'] > 0:
            self.ema = EMA(self.config['ema_gamma'], self.network)
        self.para_swapped = False

    def setup_ema(self):
        if self.config['ema_opt']:
            self.ema.setup()

    def update_ema(self):
        if self.config['ema_opt']:
            self.ema.update()

    def eval(self):
        if self.config['ema_opt']:
            self.ema.swap_parameters()
            self.para_swapped = True

    def train(self):
        if self.para_swapped:
            self.ema.swap_parameters()
            self.para_swapped = False

    def update(self, batch_meta, batch_data):
        self.network.train()
        labels = batch_data[batch_meta['label']]
        if batch_meta['pairwise']:
            labels = labels.contiguous().view(-1, batch_meta['pairwise_size'])[:, 0]
        if self.config['cuda']:
            y = Variable(labels.cuda(non_blocking=True), requires_grad=False)
        else:
            y = Variable(labels, requires_grad=False)
        task_id = batch_meta['task_id']
        task_type = batch_meta['task_type']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        logits = self.mnetwork(*inputs)
        if batch_meta['pairwise']:
            logits = logits.view(-1, batch_meta['pairwise_size'])

        if self.config.get('weighted_on', False):
            if self.config['cuda']:
                weight = Variable(batch_data[batch_meta['factor']].cuda(non_blocking=True))
            else:
                weight = Variable(batch_data[batch_meta['factor']])
            if task_type > 0:
                loss = torch.mean(F.mse_loss(logits.squeeze(), y, reduce=False) * weight)
            else:
                loss = torch.mean(F.cross_entropy(logits, y, reduce=False) * weight)
        else:
            if task_type > 0:
                loss = F.mse_loss(logits.squeeze(), y)
            else:
                loss = F.cross_entropy(logits, y)

        self.train_loss.update(loss.item(), logits.size(0))
        self.optimizer.zero_grad()

        loss.backward()
        if self.config['global_grad_clipping'] > 0:
            torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                          self.config['global_grad_clipping'])
        self.optimizer.step()
        self.updates += 1
        self.update_ema()

    def predict(self, batch_meta, batch_data):
        self.network.eval()
        task_id = batch_meta['task_id']
        task_type = batch_meta['task_type']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        score = self.mnetwork(*inputs)
        if batch_meta['pairwise']:
            score = score.contiguous().view(-1, batch_meta['pairwise_size'])
            if task_type < 1:
                score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.zeros(score.shape, dtype=int)
            positive = np.argmax(score, axis=1)
            for idx, pos in enumerate(positive):
                predict[idx, pos] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta['true_label']
        else:
            if task_type < 1:
                score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).tolist()
            score = score.reshape(-1).tolist()
        return score, predict, batch_meta['label']

    def save(self, filename):
        network_state = dict([(k, v.cpu()) for k, v in self.network.state_dict().items()])
        ema_state = dict(
            [(k, v.cpu()) for k, v in self.ema.model.state_dict().items()]) if self.ema is not None else dict()
        params = {
            'state': network_state,
            'optimizer': self.optimizer.state_dict(),
            'ema': ema_state,
            'config': self.config,
        }
        torch.save(params, filename)
        logger.info('model saved to {}'.format(filename))

    def cuda(self):
        self.network.cuda()
        if self.config['ema_opt']:
            self.ema.cuda()
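The EMA helper used above comes from the MT-DNN codebase and is not shown here. As a rough stand-in, an exponential moving average of the weights can be kept like this (class name and API are illustrative, not the real EMA interface):

import copy
import torch
import torch.nn as nn

class SimpleEMA:
    def __init__(self, gamma, model):
        self.gamma = gamma
        self.shadow = copy.deepcopy(model)   # holds the averaged weights
    def update(self, model):
        with torch.no_grad():
            for avg, live in zip(self.shadow.parameters(), model.parameters()):
                avg.mul_(self.gamma).add_(live, alpha=1.0 - self.gamma)

net = nn.Linear(4, 2)
ema = SimpleEMA(0.995, net)
# after each optimizer.step(): ema.update(net); evaluate with ema.shadow
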
Example #5
                                 torch.zeros(
                                     tgtTensor.size(0) - inputTensor.size(0),
                                     inputTensor.size(1)).cuda()))
        output = model(inputTensor.unsqueeze(0), attn_masked.unsqueeze(0))
        # loss = label_smoothing(output, tgtTensor)
        loss = focal_loss(output, tgtTensor)
        testing_loss_sum += loss.item()

    if (epoch + 1) % SAVE_EVERY == 0:
        torch.save(
            {
                'epoch': epoch + 1,
                'state': model.state_dict(),
                'full_model': model,
                'training_loss': training_loss_sum / len(training_data),
                'optimizer': optimizer.state_dict()
            }, f'checkpoint/bert-LanGen-epoch{epoch + 1}.pt')

        torch.save(
            {
                'epoch': epoch + 1,
                'state': model.state_dict(),
                'full_model': model,
                'training_loss': training_loss_sum / len(training_data),
                'optimizer': optimizer.state_dict()
            }, f'checkpoint/bert-LanGen-last.pt')

    log = f'epoch = {epoch + 1}, training_loss = {training_loss_sum / len(training_data)}, testing_loss = {testing_loss_sum / len(testing_data)}'
    training_losses.append(training_loss_sum / len(training_data))
    testing_losses.append(testing_loss_sum / len(testing_data))
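A hedged sketch of restoring the checkpoint dict saved above; the file name is the one written by the loop, but rebuilding the optimizer for the restored model is left out:

import torch

ckpt = torch.load('checkpoint/bert-LanGen-last.pt', map_location='cpu')
model = ckpt['full_model']            # the pickled module stored under 'full_model'
model.load_state_dict(ckpt['state'])  # or restore just the weights into a fresh model
print('resuming after epoch', ckpt['epoch'], 'training loss', ckpt['training_loss'])
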
                    progress_write_file.write(f"batch_time: {time.time()-st_time}, avg_batch_loss: {valid_loss/(batch_id+1)}, avg_batch_acc: {valid_acc/(batch_id+1)}\n")
                    progress_write_file.flush()
            print(f"\nEpoch {epoch_id} valid_loss: {valid_loss/(batch_id+1)}")

            # save model, optimizer and test_predictions if val_acc is improved
            if valid_acc>=max_dev_acc:

                # to file
                #name = "model-epoch{}.pth.tar".format(epoch_id)
                name = "model.pth.tar".format(epoch_id)
                torch.save({
                    'epoch_id': epoch_id,
                    'max_dev_acc': max_dev_acc,
                    'argmax_dev_acc': argmax_dev_acc,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict()}, 
                    os.path.join(CHECKPOINT_PATH,name))
                print("Model saved at {} in epoch {}".format(os.path.join(CHECKPOINT_PATH,name),epoch_id))

                # re-assign
                max_dev_acc, argmax_dev_acc = valid_acc, epoch_id
    else:

        #############################################
        # inference
        #############################################

        # load parameters
        model = load_pretrained(model, CHECKPOINT_PATH)

        # infer
def train(args):

    label_name = [
        'not related or not informative', 'other useful information',
        'donations and volunteering', 'affected individuals',
        'sympathy and support', 'infrastructure and utilities damage',
        'caution and advice'
    ]

    device = torch.device("cuda:0" if args['--cuda'] else "cpu")

    prefix = args['MODEL'] + '_' + args['BERT_CONFIG']

    bert_size = args['BERT_CONFIG'].split('-')[1]

    start_time = time.time()
    print('Importing data...', file=sys.stderr)
    df_train = pd.read_csv(args['--train'], index_col=0)
    df_val = pd.read_csv(args['--dev'], index_col=0)
    train_label = dict(df_train.InformationType_label.value_counts())
    label_max = float(max(train_label.values()))
    train_label_weight = torch.tensor(
        [label_max / train_label[i] for i in range(len(train_label))],
        device=device)
    print('Done! time elapsed %.2f sec' % (time.time() - start_time),
          file=sys.stderr)
    print('-' * 80, file=sys.stderr)

    start_time = time.time()
    print('Set up model...', file=sys.stderr)

    if args['MODEL'] == 'default':
        model = DefaultModel(args['BERT_CONFIG'], device, len(label_name))
        optimizer = BertAdam([{
            'params': model.bert.bert.parameters()
        }, {
            'params': model.bert.classifier.parameters(),
            'lr': float(args['--lr'])
        }],
                             lr=float(args['--lr-bert']),
                             max_grad_norm=float(args['--clip-grad']))
    elif args['MODEL'] == 'nonlinear':
        model = NonlinearModel(args['BERT_CONFIG'], device, len(label_name),
                               float(args['--dropout']))
        optimizer = BertAdam([{
            'params': model.bert.parameters()
        }, {
            'params': model.linear1.parameters(),
            'lr': float(args['--lr'])
        }, {
            'params': model.linear2.parameters(),
            'lr': float(args['--lr'])
        }, {
            'params': model.linear3.parameters(),
            'lr': float(args['--lr'])
        }],
                             lr=float(args['--lr-bert']),
                             max_grad_norm=float(args['--clip-grad']))
    elif args['MODEL'] == 'lstm':
        model = CustomBertLSTMModel(args['BERT_CONFIG'],
                                    device,
                                    float(args['--dropout']),
                                    len(label_name),
                                    lstm_hidden_size=int(
                                        args['--hidden-size']))
        optimizer = BertAdam([{
            'params': model.bert.parameters()
        }, {
            'params': model.lstm.parameters(),
            'lr': float(args['--lr'])
        }, {
            'params': model.hidden_to_softmax.parameters(),
            'lr': float(args['--lr'])
        }],
                             lr=float(args['--lr-bert']),
                             max_grad_norm=float(args['--clip-grad']))
    elif args['MODEL'] == 'cnn':
        model = CustomBertConvModel(args['BERT_CONFIG'],
                                    device,
                                    float(args['--dropout']),
                                    len(label_name),
                                    out_channel=int(args['--out-channel']))
        optimizer = BertAdam([{
            'params': model.bert.parameters()
        }, {
            'params': model.conv.parameters(),
            'lr': float(args['--lr'])
        }, {
            'params': model.hidden_to_softmax.parameters(),
            'lr': float(args['--lr'])
        }],
                             lr=float(args['--lr-bert']),
                             max_grad_norm=float(args['--clip-grad']))
    else:
        print('please input a valid model')
        exit(0)

    model = model.to(device)
    print('Use device: %s' % device, file=sys.stderr)
    print('Done! time elapsed %.2f sec' % (time.time() - start_time),
          file=sys.stderr)
    print('-' * 80, file=sys.stderr)

    model.train()

    cn_loss = torch.nn.CrossEntropyLoss(weight=train_label_weight,
                                        reduction='mean')
    torch.save(cn_loss, 'loss_func')  # for later testing

    train_batch_size = int(args['--batch-size'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    model_save_path = prefix + '_model.bin'

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = 0
    cum_examples = report_examples = epoch = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()
    print('Begin Maximum Likelihood training...')

    while True:
        epoch += 1

        for sents, targets in batch_iter(df_train,
                                         batch_size=train_batch_size,
                                         shuffle=True,
                                         bert=bert_size):  # for each epoch
            train_iter += 1

            optimizer.zero_grad()

            batch_size = len(sents)

            pre_softmax = model(sents)

            loss = cn_loss(
                pre_softmax,
                torch.tensor(targets, dtype=torch.long, device=device))

            loss.backward()

            optimizer.step()

            batch_losses_val = loss.item() * batch_size
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg. loss %.2f, '
                      'cum. examples %d, speed %.2f examples/sec, '
                      'time elapsed %.2f sec' %
                      (epoch, train_iter, report_loss / report_examples,
                       cum_examples, report_examples /
                       (time.time() - train_time), time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_examples = 0.

            # perform validation
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. examples %d' %
                    (epoch, train_iter, cum_loss / cum_examples, cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = 0.

                print('begin validation ...', file=sys.stderr)

                validation_loss = validation(
                    model, df_val, bert_size, cn_loss,
                    device)  # dev batch size can be a bit larger

                print('validation: iter %d, loss %f' %
                      (train_iter, validation_loss),
                      file=sys.stderr)

                is_better = len(
                    hist_valid_scores
                ) == 0 or validation_loss < min(hist_valid_scores)
                hist_valid_scores.append(validation_loss)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)

                    model.save(model_save_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')
                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            print('early stop!', file=sys.stderr)
                            exit(0)

                        # decay lr, and restore from previously best checkpoint
                        print(
                            'load previously best model and decay learning rate to %f%%'
                            % (float(args['--lr-decay']) * 100),
                            file=sys.stderr)

                        # load model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] *= float(args['--lr-decay'])

                        # reset patience
                        patience = 0

                if epoch == int(args['--max-epoch']):
                    print('reached maximum number of epochs!', file=sys.stderr)
                    exit(0)
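
The loop above relies on a `validation` helper that is not shown in this snippet. Below is a minimal sketch of what such a helper might look like, assuming the same `batch_iter`, model and `cn_loss` interfaces used in the training loop and the script's existing `torch` import; the name and signature here are assumptions, not the original implementation.

def validation(model, df_val, bert_size, cn_loss, device, batch_size=64):
    """Sketch: return the average loss over the validation set."""
    was_training = model.training
    model.eval()
    total_loss, total_examples = 0.0, 0
    with torch.no_grad():
        for sents, targets in batch_iter(df_val,
                                         batch_size=batch_size,
                                         shuffle=False,
                                         bert=bert_size):
            pre_softmax = model(sents)
            loss = cn_loss(
                pre_softmax,
                torch.tensor(targets, dtype=torch.long, device=device))
            total_loss += loss.item() * len(sents)
            total_examples += len(sents)
    if was_training:
        model.train()
    return total_loss / total_examples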
Example #8
0
class MTDNNModel(object):
    def __init__(self, opt, state_dict=None, num_train_step=-1):
        self.config = opt
        self.updates = state_dict['updates'] if state_dict and 'updates' in state_dict else 0
        self.local_updates = 0
        self.train_loss = AverageMeter()
        self.initial_from_local = True if state_dict else False
        self.network = SANBertNetwork(opt, initial_from_local=self.initial_from_local)
        if state_dict:
            missing_keys, unexpected_keys = self.network.load_state_dict(state_dict['state'], strict=False)
        self.mnetwork = nn.DataParallel(self.network) if opt['multi_gpu_on'] else self.network
        self.total_param = sum([p.nelement() for p in self.network.parameters() if p.requires_grad])
        if opt['cuda']:
            self.network.cuda()
        optimizer_parameters = self._get_param_groups()
        #print(optimizer_parameters)
        self._setup_optim(optimizer_parameters, state_dict, num_train_step) 
        self.para_swapped = False
        self.optimizer.zero_grad()
        self._setup_lossmap(self.config)
        self._setup_kd_lossmap(self.config)
        self._setup_adv_lossmap(self.config)
        self._setup_adv_training(self.config)


    def _setup_adv_training(self, config):
        self.adv_teacher = None
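        # When adversarial training is enabled, SmartPerturbation (defined in the
        # surrounding MT-DNN codebase) provides a SMART-style smoothness regularizer;
        # the positional arguments below are, in order, epsilon, the multi-GPU flag,
        # step size, noise variance, p-norm, number of perturbation steps K, the
        # fp16 flag and the encoder type.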
        if config.get('adv_train', False):
            self.adv_teacher = SmartPerturbation(config['adv_epsilon'],
                    config['multi_gpu_on'],
                    config['adv_step_size'],
                    config['adv_noise_var'],
                    config['adv_p_norm'],
                    config['adv_k'],
                    config['fp16'],
                    config['encoder_type'],
                    loss_map=self.adv_task_loss_criterion)


    def _get_param_groups(self):
        no_decay = ['bias', 'gamma', 'beta', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_parameters = [
            {'params': [p for n, p in self.network.named_parameters() if not any(nd in n for nd in no_decay)],
             'weight_decay': 0.01},
            {'params': [p for n, p in self.network.named_parameters() if any(nd in n for nd in no_decay)],
             'weight_decay': 0.0}
        ]
        return optimizer_parameters

    def _setup_optim(self, optimizer_parameters, state_dict=None, num_train_step=-1):  # Error occurs here
        #print(len(optimizer_parameters[0]['params']))
        if self.config['optimizer'] == 'sgd':
            self.optimizer = optim.SGD(optimizer_parameters, self.config['learning_rate'],
                                       weight_decay=self.config['weight_decay'])

        elif self.config['optimizer'] == 'adamax':
            self.optimizer = Adamax(optimizer_parameters,
                                    self.config['learning_rate'],
                                    warmup=self.config['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=self.config['grad_clipping'],
                                    schedule=self.config['warmup_schedule'],
                                    weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
        elif self.config['optimizer'] == 'radam':
            self.optimizer = RAdam(optimizer_parameters,
                                    self.config['learning_rate'],
                                    warmup=self.config['warmup'],
                                    t_total=num_train_step,
                                    max_grad_norm=self.config['grad_clipping'],
                                    schedule=self.config['warmup_schedule'],
                                    eps=self.config['adam_eps'],
                                    weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
            # The current radam does not support FP16.
            self.config['fp16'] = False
        elif self.config['optimizer'] == 'adam':
            self.optimizer = Adam(optimizer_parameters,
                                  lr=self.config['learning_rate'],
                                  warmup=self.config['warmup'],
                                  t_total=num_train_step,
                                  max_grad_norm=self.config['grad_clipping'],
                                  schedule=self.config['warmup_schedule'],
                                  weight_decay=self.config['weight_decay'])
            if self.config.get('have_lr_scheduler', False): self.config['have_lr_scheduler'] = False
        else:
            raise RuntimeError('Unsupported optimizer: %s' % self.config['optimizer'])
        
        if state_dict and 'optimizer' in state_dict:
            #print("Optimizer's state_dict:")
            #state_dict['optimizer']['param_groups'][0]['params']=state_dict['optimizer']['param_groups'][0]['params'][:77]
            #print(len(state_dict['optimizer']['param_groups'][0]['params']))
            #for var_name in state_dict['optimizer']:
            #    print(var_name, "\t", state_dict['optimizer'][var_name])
            #print(self.optimizer.state_dict()) ######
            #state_dict['optimizer'][var_name] =
            self.optimizer.load_state_dict(state_dict['optimizer']) 

        if self.config['fp16']:
            try:
                global amp
                from apex import amp
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, optimizer = amp.initialize(self.network, self.optimizer, opt_level=self.config['fp16_opt_level'])
            self.network = model
            self.optimizer = optimizer

        if self.config.get('have_lr_scheduler', False):
            if self.config.get('scheduler_type', 'rop') == 'rop':
                self.scheduler = ReduceLROnPlateau(self.optimizer, mode='max', factor=self.config['lr_gamma'], patience=3)
            elif self.config.get('scheduler_type', 'rop') == 'exp':
                self.scheduler = ExponentialLR(self.optimizer, gamma=self.config.get('lr_gamma', 0.95))
            else:
                milestones = [int(step) for step in self.config.get('multi_step_lr', '10,20,30').split(',')]
                self.scheduler = MultiStepLR(self.optimizer, milestones=milestones, gamma=self.config.get('lr_gamma'))
        else:
            self.scheduler = None

    def _setup_lossmap(self, config):
        task_def_list: List[TaskDef] = config['task_def_list']
        self.task_loss_criterion = []
        for idx, task_def in enumerate(task_def_list):
            cs = task_def.loss
            lc = LOSS_REGISTRY[cs](name='Loss func of task {}: {}'.format(idx, cs))
            self.task_loss_criterion.append(lc)

    def _setup_kd_lossmap(self, config):
        task_def_list: List[TaskDef] = config['task_def_list']
        self.kd_task_loss_criterion = []
        if config.get('mkd_opt', 0) > 0:
            for idx, task_def in enumerate(task_def_list):
                cs = task_def.kd_loss
                assert cs is not None
                lc = LOSS_REGISTRY[cs](name='KD Loss func of task {}: {}'.format(idx, cs))
                self.kd_task_loss_criterion.append(lc)

    def _setup_adv_lossmap(self, config):
        task_def_list: List[TaskDef] = config['task_def_list']
        self.adv_task_loss_criterion = []
        if config.get('adv_train', False):
            for idx, task_def in enumerate(task_def_list):
                cs = task_def.adv_loss
                assert cs is not None
                lc = LOSS_REGISTRY[cs](name='Adv Loss func of task {}: {}'.format(idx, cs))
                self.adv_task_loss_criterion.append(lc)


    def train(self):
        if self.para_swapped:
            self.para_swapped = False

    def _to_cuda(self, tensor):
        if tensor is None: return tensor

        if isinstance(tensor, list) or isinstance(tensor, tuple):
            y = [e.cuda(non_blocking=True) for e in tensor]
            for e in y:
                e.requires_grad = False
        else:
            y = tensor.cuda(non_blocking=True)
            y.requires_grad = False
        return y

    def update(self, batch_meta, batch_data, weight_alpha): 
        self.network.train()
        y = batch_data[batch_meta['label']]
        y = self._to_cuda(y) if self.config['cuda'] else y

        task_id = batch_meta['task_id']
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        weight = None
        
        if self.config['itw_on']:
            if self.config['cuda']:
                weight = torch.FloatTensor([batch_meta['weight']]).cuda(non_blocking=True) * weight_alpha
            else:
                weight = batch_meta['weight'] * weight_alpha
        
        """
        if self.config.get('weighted_on', False):
            if self.config['cuda']:
                weight = batch_data[batch_meta['factor']].cuda(non_blocking=True)
            else:
                weight = batch_data[batch_meta['factor']]
        """

        # fw to get logits
        logits = self.mnetwork(*inputs)

        # compute loss
        loss = 0
        if self.task_loss_criterion[task_id] and (y is not None):
            loss_criterion = self.task_loss_criterion[task_id]
            if isinstance(loss_criterion, RankCeCriterion) and batch_meta['pairwise_size'] > 1:
                # reshape the logits for ranking.
                loss = self.task_loss_criterion[task_id](logits, y, weight, ignore_index=-1, pairwise_size=batch_meta['pairwise_size'])
            else:
                loss = self.task_loss_criterion[task_id](logits, y, weight, ignore_index=-1)

        # compute kd loss
        if self.config.get('mkd_opt', 0) > 0 and ('soft_label' in batch_meta):
            soft_labels = batch_meta['soft_label']
            soft_labels = self._to_cuda(soft_labels) if self.config['cuda'] else soft_labels
            kd_lc = self.kd_task_loss_criterion[task_id]
            kd_loss = kd_lc(logits, soft_labels, weight, ignore_index=-1) if kd_lc else 0
            loss = loss + kd_loss

        # adv training
        if self.config.get('adv_train', False) and self.adv_teacher:
            # task info
            task_type = batch_meta['task_def']['task_type']
            adv_inputs = [self.mnetwork, logits] + inputs + [task_type, batch_meta.get('pairwise_size', 1)]
            adv_loss = self.adv_teacher.forward(*adv_inputs)
            loss = loss + self.config['adv_alpha'] * adv_loss

        self.train_loss.update(loss.item(), batch_data[batch_meta['token_id']].size(0))
        # scale loss
        loss = loss / self.config.get('grad_accumulation_step', 1)
        if self.config['fp16']:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.local_updates += 1
        if self.local_updates % self.config.get('grad_accumulation_step', 1) == 0:
            if self.config['global_grad_clipping'] > 0:
                if self.config['fp16']:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                                                   self.config['global_grad_clipping'])
                else:
                    torch.nn.utils.clip_grad_norm_(self.network.parameters(),
                                                  self.config['global_grad_clipping'])
            self.updates += 1
            # reset number of the grad accumulation
            self.optimizer.step()
            self.optimizer.zero_grad()

    def encode(self, batch_meta, batch_data):
        self.network.eval()
        inputs = batch_data[:3]
        sequence_output = self.network.encode(*inputs)[0]
        return sequence_output

    # TODO: similar as function extract, preserve since it is used by extractor.py
    # will remove after migrating to transformers package
    def extract(self, batch_meta, batch_data):
        self.network.eval()
        # 'token_id': 0; 'segment_id': 1; 'mask': 2
        inputs = batch_data[:3]
        all_encoder_layers, pooled_output = self.mnetwork.bert(*inputs)
        return all_encoder_layers, pooled_output

    def predict(self, batch_meta, batch_data):
        self.network.eval()
        task_id = batch_meta['task_id']
        task_def = TaskDef.from_dict(batch_meta['task_def'])
        task_type = task_def.task_type
        task_obj = tasks.get_task_obj(task_def)
        inputs = batch_data[:batch_meta['input_len']]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        score = self.mnetwork(*inputs)
        if task_obj is not None:
            score, predict = task_obj.test_predict(score)
        elif task_type == TaskType.Ranking:
            score = score.contiguous().view(-1, batch_meta['pairwise_size'])
            assert task_type == TaskType.Ranking
            score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.zeros(score.shape, dtype=int)
            positive = np.argmax(score, axis=1)
            for idx, pos in enumerate(positive):
                predict[idx, pos] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta['true_label']
        elif task_type == TaskType.SeqenceLabeling:
            mask = batch_data[batch_meta['mask']]
            score = score.contiguous()
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).reshape(mask.size()).tolist()
            valid_lengths = mask.sum(1).tolist()
            final_predict = []
            for idx, p in enumerate(predict):
                final_predict.append(p[:valid_lengths[idx]])
            score = score.reshape(-1).tolist()
            return score, final_predict, batch_meta['label']
        elif task_type == TaskType.Span:
            start, end = score
            scores, predictions = [], []
            if self.config['encoder_type'] == EncoderModelType.BERT:
                import experiments.squad.squad_utils as mrc_utils
                scores, predictions = mrc_utils.extract_answer(batch_meta, batch_data, start, end, self.config.get('max_answer_len', 5), do_lower_case=self.config.get('do_lower_case', False))
            return scores, predictions, batch_meta['answer']
        else:
            raise ValueError("Unknown task_type: %s" % task_type)
        return score, predict, batch_meta['label']

    def save(self, filename):
        network_state = dict([(k, v.cpu()) for k, v in self.network.state_dict().items()])
        params = {
            'state': network_state,
            'optimizer': self.optimizer.state_dict(),
            'config': self.config,
        }
        torch.save(params, filename)
        logger.info('model saved to {}'.format(filename))

    def load(self, checkpoint):
        model_state_dict = torch.load(checkpoint)
        if 'state' in model_state_dict:
            self.network.load_state_dict(model_state_dict['state'], strict=False)
        if 'optimizer' in model_state_dict:
            self.optimizer.load_state_dict(model_state_dict['optimizer'])
        if 'config' in model_state_dict:
            self.config.update(model_state_dict['config'])

    def cuda(self):
        self.network.cuda()
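
The `update` method above interleaves gradient accumulation, optional FP16 loss scaling and gradient clipping. Stripped of the multi-task machinery, the accumulate-then-step pattern it uses looks roughly like the standalone sketch below (illustrative names, not code from the MT-DNN repository).

import torch


def train_with_accumulation(model, loader, optimizer, loss_fn,
                            accumulation_steps=4, max_grad_norm=1.0):
    """Sketch of the accumulate-then-step pattern used in MTDNNModel.update()."""
    model.train()
    optimizer.zero_grad()
    for step, (x, y) in enumerate(loader, start=1):
        loss = loss_fn(model(x), y)
        # scale so the accumulated gradients match a single large-batch gradient
        (loss / accumulation_steps).backward()
        if step % accumulation_steps == 0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()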
Example #9
0
def train_bert_uncased(t_config, p_config, s_config):

    device = torch.device('cuda')
    seed_everything(s_config.seed)

    train = pd.read_csv('../input/train.csv').sample(
        t_config.num_to_load + t_config.valid_size, random_state=s_config.seed)
    train = prepare_train_text(train, p_config)
    train = train.fillna(0)

    tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
    train_processed = get_tokenized_samples(t_config.MAX_SEQUENCE_LENGTH,
                                            tokenizer, train['text_proc'])

    sequences = train_processed
    lengths = np.argmax(sequences == 0, axis=1)
    lengths[lengths == 0] = sequences.shape[1]

    MyModel = BertForSequenceClassification.from_pretrained(
        'bert-base-cased', num_labels=t_config.num_labels)
    MyModel.to(device)

    # Prepare target
    target_train = train['target'].values[:t_config.num_to_load]
    target_train_aux = train[[
        'severe_toxicity', 'obscene', 'identity_attack', 'insult', 'threat'
    ]].values[:t_config.num_to_load]
    target_train_identity = train[identity_columns].values[:t_config.
                                                           num_to_load]

    # Prepare training data
    inputs_train = train_processed[:t_config.num_to_load]
    weight_train = train['weight'].values[:t_config.num_to_load]
    lengths_train = lengths[:t_config.num_to_load]

    inputs_train = torch.tensor(inputs_train, dtype=torch.int64)
    Target_train = torch.Tensor(target_train)
    Target_train_aux = torch.Tensor(target_train_aux)
    Target_train_identity = torch.Tensor(target_train_identity)
    weight_train = torch.Tensor(weight_train)
    Lengths_train = torch.tensor(lengths_train, dtype=torch.int64)

    # Prepare dataset
    train_dataset = data.TensorDataset(inputs_train, Target_train,
                                       Target_train_aux, Target_train_identity,
                                       weight_train, Lengths_train)

    ids_train = lengths_train.argsort(kind="stable")
    ids_train_new = resort_index(ids_train, t_config.num_of_bucket,
                                 s_config.seed)

    train_loader = torch.utils.data.DataLoader(data.Subset(
        train_dataset, ids_train_new),
                                               batch_size=t_config.batch_size,
                                               collate_fn=clip_to_max_len,
                                               shuffle=False)
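    # The subset indices come from sorting examples by length and reshuffling within
    # buckets (ids_train_new above), so each batch groups sequences of similar length;
    # clip_to_max_len (defined elsewhere in this script) presumably trims each batch
    # to its longest real sequence, which is why shuffle=False is used here.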

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in list(MyModel.named_parameters())
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.01
    }, {
        'params': [
            p for n, p in list(MyModel.named_parameters())
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=t_config.learning_rate,
                         betas=[0.9, 0.999],
                         warmup=t_config.warmup,
                         t_total=t_config.num_epoch * len(train_loader) //
                         t_config.accumulation_steps)

    i = 0
    for n, p in list(MyModel.named_parameters()):
        if i < 10:
            p.requires_grad = False
        i += 1

    p = train['target'].mean()
    likelihood = np.log(p / (1 - p))
    model_bias = torch.tensor(likelihood).type(torch.float)
    MyModel.classifier.bias = nn.Parameter(model_bias.to(device))
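    # The four lines above initialize the classifier bias to the log-odds of the
    # positive-class base rate, so the untrained model's predicted probability starts
    # near the empirical target mean rather than at 0.5.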

    MyModel, optimizer = amp.initialize(MyModel,
                                        optimizer,
                                        opt_level="O1",
                                        verbosity=0)

    for epoch in range(t_config.num_epoch):
        i = 0

        print('Training start')

        optimizer.zero_grad()
        MyModel.train()
        for batch_idx, (input, target, target_aux, target_identity,
                        sample_weight) in tqdm_notebook(
                            enumerate(train_loader), total=len(train_loader)):

            y_pred = MyModel(
                input.to(device),
                attention_mask=(input > 0).to(device),
            )
            loss = F.binary_cross_entropy_with_logits(y_pred[0][:, 0],
                                                      target.to(device),
                                                      reduction='none')
            loss = (loss * sample_weight.to(device)).sum() / (
                sample_weight.sum().to(device))
            loss_aux = F.binary_cross_entropy_with_logits(
                y_pred[0][:, 1:6], target_aux.to(device),
                reduction='none').mean(axis=1)
            loss_aux = (loss_aux * sample_weight.to(device)).sum() / (
                sample_weight.sum().to(device))
            loss += loss_aux
            if t_config.num_labels == 15:
                loss_identity = F.binary_cross_entropy_with_logits(
                    y_pred[0][:, 6:],
                    target_identity.to(device),
                    reduction='none').mean(axis=1)
                loss_identity = (loss_identity * sample_weight.to(device)
                                 ).sum() / (sample_weight.sum().to(device))
                loss += loss_identity

            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()

            if (i + 1) % t_config.accumulation_steps == 0:
                optimizer.step()
                optimizer.zero_grad()

            i += 1

        torch.save(
            {
                'model_state_dict': MyModel.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
            }, f'{t_config.PATH}_{s_config.seed}')
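
`clip_to_max_len` is passed to the DataLoader above as `collate_fn` but is not defined in this snippet. A plausible sketch, assuming it receives the six-tensor samples from the TensorDataset, drops the length field, and clips padding to the longest real sequence in the batch (relying on the script's existing torch import; the actual implementation may differ):

def clip_to_max_len(batch):
    """Sketch of a collate_fn: stack sample fields and clip padding to the batch max length."""
    inputs, target, target_aux, target_identity, weight, lengths = map(
        torch.stack, zip(*batch))
    max_len = int(lengths.max())
    return inputs[:, :max_len], target, target_aux, target_identity, weight
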
class MTDNNModel(MTDNNPretrainedModel):
    """Instance of an MTDNN Model

    Arguments:
        MTDNNPretrainedModel {BertPretrainedModel} -- Inherited from Bert Pretrained
        config  {MTDNNConfig} -- MTDNN Configuration Object
        pretrained_model_name {str} -- Name of the pretrained model to initial checkpoint
        num_train_step  {int} -- Number of steps to take each training

    Raises:
        RuntimeError: [description]
        ImportError: [description]

    Returns:
        MTDNNModel -- An Instance of an MTDNN Model
    """
    def __init__(
        self,
        config: MTDNNConfig,
        task_defs: MTDNNTaskDefs,
        data_processor: MTDNNDataProcess,
        pretrained_model_name: str = "mtdnn-base-uncased",
        test_datasets_list: list = [],
        output_dir: str = "checkpoint",
    ):

        # Input validation
        assert (
            config.init_checkpoint in self.supported_init_checkpoints()
        ), f"Initial checkpoint must be in {self.supported_init_checkpoints()}"

        num_train_step = data_processor.get_num_all_batches()
        decoder_opts = data_processor.get_decoder_options_list()
        task_types = data_processor.get_task_types_list()
        dropout_list = data_processor.get_tasks_dropout_prob_list()
        loss_types = data_processor.get_loss_types_list()
        kd_loss_types = data_processor.get_kd_loss_types_list()
        tasks_nclass_list = data_processor.get_task_nclass_list()

        # data loaders
        multitask_train_dataloader = data_processor.get_train_dataloader()
        dev_dataloaders_list = data_processor.get_dev_dataloaders()
        test_dataloaders_list = data_processor.get_test_dataloaders()

        assert decoder_opts, "Decoder options list is required!"
        assert task_types, "Task types list is required!"
        assert dropout_list, "Task dropout list is required!"
        assert loss_types, "Loss types list is required!"
        assert kd_loss_types, "KD Loss types list is required!"
        assert tasks_nclass_list, "Tasks nclass list is required!"
        assert (multitask_train_dataloader
                ), "DataLoader for multiple tasks cannot be None"

        super(MTDNNModel, self).__init__(config)

        # Initialize model config and update with training options
        self.config = config
        self.update_config_with_training_opts(
            decoder_opts,
            task_types,
            dropout_list,
            loss_types,
            kd_loss_types,
            tasks_nclass_list,
        )
        wandb.init(project='mtl-uncertainty-final',
                   entity='feifang24',
                   config=self.config.to_dict())
        self.tasks = data_processor.tasks  # {task_name: task_idx}
        self.task_defs = task_defs
        self.multitask_train_dataloader = multitask_train_dataloader
        self.dev_dataloaders_list = dev_dataloaders_list
        self.test_dataloaders_list = test_dataloaders_list
        self.test_datasets_list = self._configure_test_ds(test_datasets_list)
        self.output_dir = output_dir

        self.batch_bald = BatchBALD(num_samples=10,
                                    num_draw=500,
                                    shuffle_prop=0.0,
                                    reverse=True,
                                    reduction='mean')
        self.loss_weights = [None] * self.num_tasks

        # Create the output_dir if it doesn't exist
        MTDNNCommonUtils.create_directory_if_not_exists(self.output_dir)

        self.pooler = None

        # Resume from model checkpoint
        if self.config.resume and self.config.model_ckpt:
            assert os.path.exists(
                self.config.model_ckpt), "Model checkpoint does not exist"
            logger.info(f"loading model from {self.config.model_ckpt}")
            self = self.load(self.config.model_ckpt)
            return

        # Setup the baseline network
        # - Define the encoder based on config options
        # - Set state dictionary based on configuration setting
        # - Download pretrained model if flag is set
        # TODO - Use Model.pretrained_model() after configuration file is hosted.
        if self.config.use_pretrained_model:
            with MTDNNCommonUtils.download_path() as file_path:
                path = pathlib.Path(file_path)
                self.local_model_path = MTDNNCommonUtils.maybe_download(
                    url=self.
                    pretrained_model_archive_map[pretrained_model_name],
                    log=logger,
                )
            self.bert_model = MTDNNCommonUtils.load_pytorch_model(
                self.local_model_path)
            self.state_dict = self.bert_model["state"]
        else:
            # Set the config base on encoder type set for initial checkpoint
            if config.encoder_type == EncoderModelType.BERT:
                self.bert_config = BertConfig.from_dict(self.config.to_dict())
                self.bert_model = BertModel.from_pretrained(
                    self.config.init_checkpoint)
                self.state_dict = self.bert_model.state_dict()
                self.config.hidden_size = self.bert_config.hidden_size
            if config.encoder_type == EncoderModelType.ROBERTA:
                # Download and extract from PyTorch hub if not downloaded before
                self.bert_model = torch.hub.load("pytorch/fairseq",
                                                 config.init_checkpoint)
                self.config.hidden_size = self.bert_model.args.encoder_embed_dim
                self.pooler = LinearPooler(self.config.hidden_size)
                new_state_dict = {}
                for key, val in self.bert_model.state_dict().items():
                    if key.startswith("model.decoder.sentence_encoder"
                                      ) or key.startswith(
                                          "model.classification_heads"):
                        key = f"bert.{key}"
                        new_state_dict[key] = val
                    # backward compatibility PyTorch <= 1.0.0
                    if key.startswith("classification_heads"):
                        key = f"bert.model.{key}"
                        new_state_dict[key] = val
                self.state_dict = new_state_dict

        self.updates = (self.state_dict["updates"] if self.state_dict
                        and "updates" in self.state_dict else 0)
        self.local_updates = 0
        self.train_loss = AverageMeter()
        self.train_loss_by_task = [
            AverageMeter() for _ in range(len(self.tasks))
        ]
        self.network = SANBERTNetwork(
            init_checkpoint_model=self.bert_model,
            pooler=self.pooler,
            config=self.config,
        )
        if self.state_dict:
            self.network.load_state_dict(self.state_dict, strict=False)
        self.mnetwork = (nn.DataParallel(self.network)
                         if self.config.multi_gpu_on else self.network)
        self.total_param = sum([
            p.nelement() for p in self.network.parameters() if p.requires_grad
        ])

        # Move network to GPU if device available and flag set
        if self.config.cuda:
            self.network.cuda(device=self.config.cuda_device)
        self.optimizer_parameters = self._get_param_groups()
        self._setup_optim(self.optimizer_parameters, self.state_dict,
                          num_train_step)
        self.para_swapped = False
        self.optimizer.zero_grad()
        self._setup_lossmap()

    @property
    def num_tasks(self):
        return len(self.tasks)

    def _configure_test_ds(self, test_datasets_list):
        if test_datasets_list: return test_datasets_list
        result = []
        for task in self.task_defs.get_task_names():
            if task == 'mnli':
                result.append('mnli_matched')
                result.append('mnli_mismatched')
            else:
                result.append(task)
        return result

    def _get_param_groups(self):
        no_decay = [
            "bias", "gamma", "beta", "LayerNorm.bias", "LayerNorm.weight"
        ]
        optimizer_parameters = [
            {
                "params": [
                    p for n, p in self.network.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.01,
            },
            {
                "params": [
                    p for n, p in self.network.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
        return optimizer_parameters

    def _setup_optim(self,
                     optimizer_parameters,
                     state_dict: dict = None,
                     num_train_step: int = -1):

        # Setup optimizer parameters
        if self.config.optimizer == "sgd":
            self.optimizer = optim.SGD(
                optimizer_parameters,
                self.config.learning_rate,
                weight_decay=self.config.weight_decay,
            )
        elif self.config.optimizer == "adamax":
            self.optimizer = Adamax(
                optimizer_parameters,
                self.config.learning_rate,
                warmup=self.config.warmup,
                t_total=num_train_step,
                max_grad_norm=self.config.grad_clipping,
                schedule=self.config.warmup_schedule,
                weight_decay=self.config.weight_decay,
            )

        elif self.config.optimizer == "radam":
            self.optimizer = RAdam(
                optimizer_parameters,
                self.config.learning_rate,
                warmup=self.config.warmup,
                t_total=num_train_step,
                max_grad_norm=self.config.grad_clipping,
                schedule=self.config.warmup_schedule,
                eps=self.config.adam_eps,
                weight_decay=self.config.weight_decay,
            )

            # The current radam does not support FP16.
            self.config.fp16 = False
        elif self.config.optimizer == "adam":
            self.optimizer = Adam(
                optimizer_parameters,
                lr=self.config.learning_rate,
                warmup=self.config.warmup,
                t_total=num_train_step,
                max_grad_norm=self.config.grad_clipping,
                schedule=self.config.warmup_schedule,
                weight_decay=self.config.weight_decay,
            )

        else:
            raise RuntimeError(
                f"Unsupported optimizer: {self.config.optimizer}")

        # Clear scheduler for certain optimizer choices
        if self.config.optimizer in ["adam", "adamax", "radam"]:
            if self.config.have_lr_scheduler:
                self.config.have_lr_scheduler = False

        if state_dict and "optimizer" in state_dict:
            self.optimizer.load_state_dict(state_dict["optimizer"])

        if self.config.fp16:
            try:
                global amp
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                self.network,
                self.optimizer,
                opt_level=self.config.fp16_opt_level)
            self.network = model
            self.optimizer = optimizer

        if self.config.have_lr_scheduler:
            if self.config.scheduler_type == "rop":
                self.scheduler = ReduceLROnPlateau(self.optimizer,
                                                   mode="max",
                                                   factor=self.config.lr_gamma,
                                                   patience=3)
            elif self.config.scheduler_type == "exp":
                self.scheduler = ExponentialLR(self.optimizer,
                                               gamma=self.config.lr_gamma
                                               or 0.95)
            else:
                milestones = [
                    int(step) for step in (
                        self.config.multi_step_lr or "10,20,30").split(",")
                ]
                self.scheduler = MultiStepLR(self.optimizer,
                                             milestones=milestones,
                                             gamma=self.config.lr_gamma)
        else:
            self.scheduler = None

    def _setup_lossmap(self):
        self.task_loss_criterion = []
        for idx, cs in enumerate(self.config.loss_types):
            assert cs is not None, "Loss type must be defined."
            lc = LOSS_REGISTRY[cs](name=f"Loss func of task {idx}: {cs}")
            self.task_loss_criterion.append(lc)

    def _setup_kd_lossmap(self):
        loss_types = self.config.kd_loss_types
        self.kd_task_loss_criterion = []
        if self.config.mkd_opt > 0:
            for idx, cs in enumerate(loss_types):
                assert cs, "Loss type must be defined."
                lc = LOSS_REGISTRY[cs](
                    name="Loss func of task {}: {}".format(idx, cs))
                self.kd_task_loss_criterion.append(lc)

    def _to_cuda(self, tensor):
        # Set tensor to gpu (non-blocking) if a PyTorch tensor
        if tensor is None:
            return tensor

        if isinstance(tensor, list) or isinstance(tensor, tuple):
            y = [
                e.cuda(device=self.config.cuda_device, non_blocking=True)
                for e in tensor
            ]
            for t in y:
                t.requires_grad = False
        else:
            y = tensor.cuda(device=self.config.cuda_device, non_blocking=True)
            y.requires_grad = False
        return y

    def train(self):
        if self.para_swapped:
            self.para_swapped = False

    def update(self, batch_meta, batch_data):
        self.network.train()
        target = batch_data[batch_meta["label"]]
        soft_labels = None

        task_type = batch_meta["task_type"]
        target = self._to_cuda(target) if self.config.cuda else target

        task_id = batch_meta["task_id"]
        inputs = batch_data[:batch_meta["input_len"]]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        weight = self.loss_weights[task_id]
        if self.config.weighted_on:
            if self.config.cuda:
                weight = batch_data[batch_meta["factor"]].cuda(
                    device=self.config.cuda_device, non_blocking=True)
            else:
                weight = batch_data[batch_meta["factor"]]
        logits = self.mnetwork(*inputs)

        # compute loss
        loss = 0
        if self.task_loss_criterion[task_id] and (target is not None):
            loss = self.task_loss_criterion[task_id](logits,
                                                     target,
                                                     weight,
                                                     ignore_index=-1)

        # compute kd loss
        if self.config.mkd_opt > 0 and ("soft_label" in batch_meta):
            soft_labels = batch_meta["soft_label"]
            soft_labels = (self._to_cuda(soft_labels)
                           if self.config.cuda else soft_labels)
            kd_lc = self.kd_task_loss_criterion[task_id]
            kd_loss = (kd_lc(logits, soft_labels, weight, ignore_index=-1)
                       if kd_lc else 0)
            loss = loss + kd_loss

        self.train_loss_by_task[task_id].update(
            loss.item() / (self.loss_weights[task_id]
                           if self.loss_weights[task_id] is not None else 1.),
            batch_data[batch_meta["token_id"]].size(0))
        self.train_loss.update(loss.item(),
                               batch_data[batch_meta["token_id"]].size(0))
        # scale loss
        loss = loss / (self.config.grad_accumulation_step or 1)
        if self.config.fp16:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.local_updates += 1
        if self.local_updates % self.config.grad_accumulation_step == 0:
            if self.config.global_grad_clipping > 0:
                if self.config.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer),
                        self.config.global_grad_clipping,
                    )
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.network.parameters(),
                        self.config.global_grad_clipping)
            self.updates += 1
            # reset number of the grad accumulation
            self.optimizer.step()
            self.optimizer.zero_grad()

    def eval_mode(self,
                  data: DataLoader,
                  metric_meta,
                  use_cuda=True,
                  with_label=True,
                  label_mapper=None,
                  task_type=TaskType.Classification):
        eval_loss = AverageMeter()
        if use_cuda:
            self.cuda()
        predictions = []
        golds = []
        scores = []
        uncertainties = []
        ids = []
        metrics = {}
        for idx, (batch_info, batch_data) in enumerate(data):
            if idx % 100 == 0:
                logger.info(f"predicting {idx}")
            batch_info, batch_data = MTDNNCollater.patch_data(
                use_cuda, batch_info, batch_data)
            score, pred, gold, loss, uncertainty = self._predict_batch(
                batch_info, batch_data)
            predictions.extend(pred)
            golds.extend(gold)
            scores.extend(score)
            uncertainties.extend(uncertainty)
            ids.extend(batch_info["uids"])
            eval_loss.update(loss.item(), len(batch_info["uids"]))

        if task_type == TaskType.Span:
            golds = merge_answers(ids, golds)
            predictions, scores = select_answers(ids, predictions, scores)
        if with_label:
            metrics = calc_metrics(metric_meta, golds, predictions, scores,
                                   label_mapper)
        return metrics, predictions, scores, golds, ids, (
            eval_loss.avg, eval_loss.count), np.mean(uncertainties)

    def _predict_batch(self, batch_meta, batch_data):
        self.network.eval()
        task_id = batch_meta["task_id"]
        task_type = batch_meta["task_type"]
        inputs = batch_data[:batch_meta["input_len"]]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)

        # get logits (and val loss if we have labels)
        label = batch_meta["label"]
        target = batch_data[label] if type(label) is int else torch.tensor(
            label)
        target = self._to_cuda(target) if self.config.cuda else target

        weight = None
        if self.config.weighted_on:
            if self.config.cuda:
                weight = batch_data[batch_meta["factor"]].cuda(
                    device=self.config.cuda_device, non_blocking=True)
            else:
                weight = batch_data[batch_meta["factor"]]

        score = self.mnetwork(*inputs)
        if self.config.mc_dropout_samples > 0:

            def apply_dropout(m):
                if isinstance(m, DropoutWrapper):
                    m.train()

            self.network.apply(apply_dropout)
            mc_sample_scores = torch.stack([
                self.mnetwork(*inputs)
                for _ in range(self.config.mc_dropout_samples)
            ], -1)
            mc_sample_scores = F.softmax(mc_sample_scores,
                                         dim=1).data.cpu().numpy()
            uncertainty = self.batch_bald.get_uncertainties(mc_sample_scores)
        else:
            uncertainty = 1.0

        loss = None
        if self.task_loss_criterion[task_id] and (target is not None):
            loss = self.task_loss_criterion[task_id](score,
                                                     target,
                                                     weight,
                                                     ignore_index=-1)

        if task_type == TaskType.Ranking:
            score = score.contiguous().view(-1, batch_meta["pairwise_size"])
            assert task_type == TaskType.Ranking
            score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.zeros(score.shape, dtype=int)
            positive = np.argmax(score, axis=1)
            for idx, pos in enumerate(positive):
                predict[idx, pos] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta["true_label"], loss
        elif task_type == TaskType.SequenceLabeling:
            mask = batch_data[batch_meta["mask"]]
            score = score.contiguous()
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).reshape(mask.size()).tolist()
            valid_lengths = mask.sum(1).tolist()
            final_predict = []
            for idx, p in enumerate(predict):
                final_predict.append(p[:valid_lengths[idx]])
            score = score.reshape(-1).tolist()
            return score, final_predict, batch_meta["label"], loss
        elif task_type == TaskType.Span:
            start, end = score
            scores, predictions = [], []
            if self.config.encoder_type == EncoderModelType.BERT:
                scores, predictions = extract_answer(
                    batch_meta,
                    batch_data,
                    start,
                    end,
                    self.config.get("max_answer_len", 5),
                )
            return scores, predictions, batch_meta["answer"], loss
        else:
            if task_type == TaskType.Classification:
                score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).tolist()
            score = score.reshape(-1).tolist()
        return score, predict, batch_meta["label"], loss, uncertainty

    def _rerank_batches(self,
                        batches,
                        start_idx,
                        task_id_to_weights,
                        softmax_task_weights=False):
        def weights_to_probs(weights):
            if softmax_task_weights:
                probs = softmax(weights)
            else:
                probs = weights / np.sum(weights)
            return probs

        # reshuffle all batches; sort them by task_id
        new_batches = [list(self.multitask_train_dataloader) for _ in range(5)]
        for i in range(len(new_batches)):
            random.shuffle(new_batches[i])  # this line somehow helps?
        new_batches = [b for batches in new_batches
                       for b in batches]  # flatten
        task_id_by_batch = [
            batch_meta["task_id"] for batch_meta, _ in new_batches
        ]
        batches_by_task = [[] for _ in range(self.num_tasks)]
        for batch_idx, task_id in enumerate(task_id_by_batch):
            batches_by_task[task_id].append(batch_idx)

        task_probs = weights_to_probs(task_id_to_weights)

        # multiply weight by num batches per task
        # task_probs = weights_to_probs(task_id_to_weights * np.asarray([len(batches) for batches in batches_by_task]))  # comment out as see fit

        if self.config.uncertainty_based_weight:
            rel_loss_weights = (1. / task_id_to_weights)
            self.loss_weights = (rel_loss_weights * self.num_tasks / np.sum(rel_loss_weights)) * \
                                    (np.mean(self.dev_loss_by_task) / self.dev_loss_by_task)
            # self.loss_weights = rel_loss_weights * np.mean(task_id_to_weights)
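            # In words: tasks with higher uncertainty receive smaller loss weights
            # (1 / uncertainty), the weights are normalized to average 1 across tasks,
            # and tasks whose dev loss is below the mean dev loss are scaled up further.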

        num_batches = len(batches[start_idx:])
        # sample num_batches many tasks w/ replacement
        task_indices_sampled = np.random.choice(self.num_tasks,
                                                num_batches,
                                                replace=True,
                                                p=task_probs)

        reranked_batches = [None] * num_batches
        counters = [0] * self.num_tasks
        for i, task_id in enumerate(task_indices_sampled):
            batch_idx = batches_by_task[task_id][counters[task_id] %
                                                 len(batches_by_task[task_id])]
            counters[task_id] += 1
            reranked_batches[i] = new_batches[batch_idx]

        weights_by_task_name = {}
        for task_name, task_id in self.tasks.items():
            weights_by_task_name[f'task_weight/{task_name}'] = task_probs[
                task_id]

        return [None] * start_idx + reranked_batches, weights_by_task_name

    def fit(self, epochs=0):
        """ Fit model to training datasets """
        epochs = epochs or self.config.epochs
        logger.info(f"Total number of params: {self.total_param}")
        FIRST_STEP_TO_LOG = 10
        for epoch in range(1, epochs + 1):
            logger.info(f"At epoch {epoch}")
            logger.info(
                f"Amount of data to go over: {len(self.multitask_train_dataloader)}"
            )

            start = datetime.now()
            # Create batches and train
            batches = list(self.multitask_train_dataloader)
            if self.config.uncertainty_based_sampling and epoch > 1:
                batches, weights_by_task_name = self._rerank_batches(
                    batches,
                    start_idx=0,
                    task_id_to_weights=self.smoothed_uncertainties_by_task)
            for idx in range(len(batches)):
                batch_meta, batch_data = batches[idx]
                batch_meta, batch_data = MTDNNCollater.patch_data(
                    self.config.cuda, batch_meta, batch_data)

                task_id = batch_meta["task_id"]
                self.update(batch_meta, batch_data)
                if (self.local_updates == FIRST_STEP_TO_LOG
                        or (self.local_updates) %
                    (self.config.log_per_updates *
                     self.config.grad_accumulation_step) == 0):

                    time_left = str((datetime.now() - start) / (idx + 1) *
                                    (len(self.multitask_train_dataloader) -
                                     idx - 1)).split(".")[0]
                    logger.info(
                        "Updates - [{0:6}] Training Loss - [{1:.5f}] Time Remaining - [{2}]"
                        .format(self.updates, self.train_loss.avg, time_left))
                    val_logs, uncertainties_by_task = self._eval(
                        epoch, save_scores=False, eval_type='dev')
                    test_logs, _ = self._eval(epoch,
                                              save_scores=False,
                                              eval_type='test')
                    if self.local_updates == FIRST_STEP_TO_LOG:
                        weights_by_task_name = {
                            f'task_weight/{task_name}': 1.0
                            for task_name in self.tasks
                        }
                    else:
                        if self.local_updates == self.config.log_per_updates * self.config.grad_accumulation_step:
                            self.smoothed_uncertainties_by_task = uncertainties_by_task
                            self.initial_train_loss_by_task = np.asarray(
                                [loss.avg for loss in self.train_loss_by_task])
                        else:
                            alpha = self.config.smooth_uncertainties
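                            # exponential moving average of the per-task uncertainties,
                            # so the sampling weights change gradually rather than
                            # jumping on every evaluation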
                            self.smoothed_uncertainties_by_task = alpha * self.smoothed_uncertainties_by_task + \
                                                                    (1 - alpha) * uncertainties_by_task
                        if self.config.uncertainty_based_sampling and idx < len(
                                batches) - 1:
                            batches, weights_by_task_name = self._rerank_batches(
                                batches,
                                start_idx=idx + 1,
                                task_id_to_weights=self.
                                smoothed_uncertainties_by_task)
                        if self.config.rate_based_weight:
                            current_train_loss_by_task = np.asarray(
                                [loss.avg for loss in self.train_loss_by_task])
                            rate_of_training_by_task = current_train_loss_by_task / self.initial_train_loss_by_task
                            self.loss_weights = (rate_of_training_by_task / np.mean(rate_of_training_by_task)) * \
                                                    (np.mean(current_train_loss_by_task) / current_train_loss_by_task)
                    self._log_training({
                        **val_logs,
                        **test_logs,
                        **weights_by_task_name
                    })

                if self.config.save_per_updates_on and (
                    (self.local_updates) %
                    (self.config.save_per_updates *
                     self.config.grad_accumulation_step) == 0):
                    model_file = os.path.join(
                        self.output_dir,
                        "model_{}_{}.pt".format(epoch, self.updates),
                    )
                    logger.info(f"Saving mt-dnn model to {model_file}")
                    self.save(model_file)

            # Eval and save checkpoint after each epoch
            logger.info('=' * 5 + f' End of EPOCH {epoch} ' + '=' * 5)
            logger.info(f'Train loss (epoch avg): {self.train_loss.avg}')
            val_logs, uncertainties_by_task = self._eval(epoch,
                                                         save_scores=True,
                                                         eval_type='dev')
            test_logs, _ = self._eval(epoch,
                                      save_scores=True,
                                      eval_type='test')
            self._log_training({
                **val_logs,
                **test_logs,
                **weights_by_task_name
            })

            # model_file = os.path.join(self.output_dir, "model_{}.pt".format(epoch))
            # logger.info(f"Saving mt-dnn model to {model_file}")
            # self.save(model_file)

    def _eval(self, epoch, save_scores, eval_type='dev'):
        if eval_type not in {'dev', 'test'}:
            raise ValueError(
                "eval_type must be one of the following: 'dev' or 'test'.")
        is_dev = eval_type == 'dev'

        log_dict = {}
        loss_agg = AverageMeter()
        loss_by_task = {}
        uncertainties_by_task = {}
        for idx, dataset in enumerate(self.test_datasets_list):
            logger.info(
                f"Evaluating on {eval_type} ds {idx}: {dataset.upper()}")
            prefix = dataset.split("_")[0]
            results = self._predict(idx,
                                    prefix,
                                    dataset,
                                    eval_type=eval_type,
                                    saved_epoch_idx=epoch,
                                    save_scores=save_scores)

            avg_loss = results['avg_loss']
            num_samples = results['num_samples']
            loss_agg.update(avg_loss, n=num_samples)
            loss_by_task[dataset] = avg_loss
            if is_dev:
                logger.info(
                    f"Task {dataset} -- {eval_type} loss: {avg_loss:.3f}")

            metrics = results['metrics']
            for key, val in metrics.items():
                if is_dev:
                    logger.info(
                        f"Task {dataset} -- {eval_type} {key}: {val:.3f}")
                log_dict[f'{dataset}/{eval_type}_{key}'] = val

            uncertainty = results['uncertainty']
            if is_dev:
                logger.info(
                    f"Task {dataset} -- {eval_type} uncertainty: {uncertainty:.3f}"
                )
            log_dict[
                f'{eval_type}_uncertainty_by_task/{dataset}'] = uncertainty
            if prefix not in uncertainties_by_task:
                uncertainties_by_task[prefix] = uncertainty
            else:
                # exploiting the fact that only mnli has two dev sets
                uncertainties_by_task[prefix] += uncertainty
                uncertainties_by_task[prefix] /= 2
        if is_dev: logger.info(f'{eval_type} loss: {loss_agg.avg}')
        log_dict[f'{eval_type}_loss'] = loss_agg.avg
        log_dict.update({
            f'{eval_type}_loss_by_task/{task}': loss
            for task, loss in loss_by_task.items()
        })

        loss_by_task_id = [None] * self.num_tasks
        for task_name, loss in loss_by_task.items():
            loss_by_task_id[self.tasks[task_name]] = loss
        loss_by_task_id = np.asarray(loss_by_task_id)

        if is_dev:
            self.dev_loss_by_task = loss_by_task_id
        else:
            self.test_loss_by_task = loss_by_task_id

        # convert uncertainties_by_task from dict to list, where list[i] = weight of task_id i
        uncertainties_by_task_id = [None] * self.num_tasks
        for task_name, weight in uncertainties_by_task.items():
            task_id = self.tasks[task_name]
            uncertainties_by_task_id[task_id] = weight
        uncertainties_by_task_id = np.asarray(uncertainties_by_task_id)

        return log_dict, uncertainties_by_task_id

    def _log_training(self, val_logs):
        train_loss_by_task = {
            f'train_loss_by_task/{task}': self.train_loss_by_task[task_idx].avg
            for task, task_idx in self.tasks.items()
        }
        train_loss_agg = {'train_loss': self.train_loss.avg}
        loss_weights_by_task = {}
        if self.config.uncertainty_based_weight or self.config.rate_based_weight:
            for task_name, task_id in self.tasks.items():
                loss_weights_by_task[
                    f'loss_weight/{task_name}'] = self.loss_weights[
                        task_id] if self.loss_weights[
                            task_id] is not None else 1.
        log_dict = {
            'global_step': self.updates,
            **train_loss_by_task,
            **train_loss_agg,
            **val_logs,
            **loss_weights_by_task
        }
        wandb.log(log_dict)

    def _predict(self,
                 eval_ds_idx,
                 eval_ds_prefix,
                 eval_ds_name,
                 eval_type='dev',
                 saved_epoch_idx=None,
                 save_scores=True):
        if eval_type not in {'dev', 'test'}:
            raise ValueError(
                "eval_type must be one of the following: 'dev' or 'test'.")
        is_dev = eval_type == 'dev'

        label_dict = self.task_defs.global_map.get(eval_ds_prefix, None)

        if is_dev:
            data: DataLoader = self.dev_dataloaders_list[eval_ds_idx]
        else:
            data: DataLoader = self.test_dataloaders_list[eval_ds_idx]

        if data is None:
            results = None
        else:
            with torch.no_grad():
                (
                    metrics,
                    predictions,
                    scores,
                    golds,
                    ids,
                    (eval_ds_avg_loss, eval_ds_num_samples),
                    uncertainty,
                ) = self.eval_mode(
                    data,
                    metric_meta=self.task_defs.metric_meta_map[eval_ds_prefix],
                    use_cuda=self.config.cuda,
                    with_label=True,
                    label_mapper=label_dict,
                    task_type=self.task_defs.task_type_map[eval_ds_prefix])
            results = {
                "metrics": metrics,
                "predictions": predictions,
                "uids": ids,
                "scores": scores,
                "uncertainty": uncertainty
            }
            if save_scores:
                score_file_prefix = f"{eval_ds_name}_{eval_type}_scores" \
                                    + (f'_{saved_epoch_idx}' if saved_epoch_idx is not None else "")
                score_file = os.path.join(self.output_dir,
                                          score_file_prefix + ".json")
                MTDNNCommonUtils.dump(score_file, results)
                if self.config.use_glue_format:
                    official_score_file = os.path.join(
                        self.output_dir, score_file_prefix + ".tsv")
                    submit(official_score_file, results, label_dict)

            results.update({
                "avg_loss": eval_ds_avg_loss,
                "num_samples": eval_ds_num_samples
            })

        return results

    def predict(self, trained_model_chckpt: str = None):
        """
        Inference of model on test datasets
        """
        # Load a trained checkpoint if a valid model checkpoint
        if trained_model_chckpt and gfile.exists(trained_model_chckpt):
            logger.info(
                f"Running predictions using: {trained_model_chckpt}. This may take 3 minutes."
            )
            self.load(trained_model_chckpt)
            logger.info("Checkpoint loaded.")

        self.config.batch_size_eval = 128
        self.config.use_glue_format = True

        # test eval
        for idx, dataset in enumerate(self.test_datasets_list):
            prefix = dataset.split("_")[0]
            results = self._predict(idx, prefix, dataset, eval_type='test')
            if results:
                logger.info(f"[new test scores saved for {dataset}.]")
            else:
                logger.info(f"Data not found for {dataset}.")

    def extract(self, batch_meta, batch_data):
        self.network.eval()
        # 'token_id': 0; 'segment_id': 1; 'mask': 2
        inputs = batch_data[:3]
        all_encoder_layers, pooled_output = self.mnetwork.bert(*inputs)
        return all_encoder_layers, pooled_output

    def save(self, filename):
        network_state = dict([(k, v.cpu())
                              for k, v in self.network.state_dict().items()])
        params = {
            "state": network_state,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config,
        }
        torch.save(params, gfile.GFile(filename, mode='wb'))
        logger.info("model saved to {}".format(filename))

    def load(self, checkpoint):
        model_state_dict = torch.load(gfile.GFile(checkpoint, mode='rb'))
        self.network.load_state_dict(model_state_dict["state"], strict=False)
        self.optimizer.load_state_dict(model_state_dict["optimizer"])
        self.config = model_state_dict["config"]

    def cuda(self):
        self.network.cuda(device=self.config.cuda_device)

    def supported_init_checkpoints(self):
        """List of allowed check points
        """
        return [
            "bert-base-uncased",
            "bert-base-cased",
            "bert-large-uncased",
            "mtdnn-base-uncased",
            "mtdnn-large-uncased",
            "roberta.base",
            "roberta.large",
        ]

    def update_config_with_training_opts(
        self,
        decoder_opts,
        task_types,
        dropout_list,
        loss_types,
        kd_loss_types,
        tasks_nclass_list,
    ):
        # Update configurations with options obtained from preprocessing training data
        setattr(self.config, "decoder_opts", decoder_opts)
        setattr(self.config, "task_types", task_types)
        setattr(self.config, "tasks_dropout_p", dropout_list)
        setattr(self.config, "loss_types", loss_types)
        setattr(self.config, "kd_loss_types", kd_loss_types)
        setattr(self.config, "tasks_nclass_list", tasks_nclass_list)
Example #11
0
class MTDNNModel(MTDNNPretrainedModel):
    """Instance of an MTDNN Model
    
    Arguments:
        MTDNNPretrainedModel {BertPretrainedModel} -- Inherited from Bert Pretrained
        config  {MTDNNConfig} -- MTDNN Configuration Object 
        pretrained_model_name {str} -- Name of the pretrained model to initial checkpoint
        num_train_step  {int} -- Number of steps to take each training
    
    Raises:
        RuntimeError: [description]
        ImportError: [description]
    
    Returns:
        MTDNNModel -- An Instance of an MTDNN Model
    """
    def __init__(
        self,
        config: MTDNNConfig,
        task_defs: MTDNNTaskDefs,
        pretrained_model_name: str = "mtdnn-base-uncased",
        num_train_step: int = -1,
        decoder_opts: list = None,
        task_types: list = None,
        dropout_list: list = None,
        loss_types: list = None,
        kd_loss_types: list = None,
        tasks_nclass_list: list = None,
        multitask_train_dataloader: DataLoader = None,
        dev_dataloaders_list: list = None,  # list of dataloaders
        test_dataloaders_list: list = None,  # list of dataloaders
        test_datasets_list: list = ["mnli_mismatched", "mnli_matched"],
        output_dir: str = "checkpoint",
        log_dir: str = "tensorboard_logdir",
    ):

        # Input validation
        assert (
            config.init_checkpoint in self.supported_init_checkpoints()
        ), f"Initial checkpoint must be in {self.supported_init_checkpoints()}"

        assert decoder_opts, "Decoder options list is required!"
        assert task_types, "Task types list is required!"
        assert dropout_list, "Task dropout list is required!"
        assert loss_types, "Loss types list is required!"
        assert kd_loss_types, "KD Loss types list is required!"
        assert tasks_nclass_list, "Tasks nclass list is required!"
        assert (multitask_train_dataloader
                ), "DataLoader for multiple tasks cannot be None"
        assert test_datasets_list, "Pass a list of test dataset prefixes"

        super(MTDNNModel, self).__init__(config)

        # Initialize model config and update with training options
        self.config = config
        self.update_config_with_training_opts(
            decoder_opts,
            task_types,
            dropout_list,
            loss_types,
            kd_loss_types,
            tasks_nclass_list,
        )
        self.task_defs = task_defs
        self.multitask_train_dataloader = multitask_train_dataloader
        self.dev_dataloaders_list = dev_dataloaders_list
        self.test_dataloaders_list = test_dataloaders_list
        self.test_datasets_list = test_datasets_list
        self.output_dir = output_dir
        self.log_dir = log_dir

        # Create the output_dir if it doesn't exist
        MTDNNCommonUtils.create_directory_if_not_exists(self.output_dir)
        self.tensor_board = SummaryWriter(log_dir=self.log_dir)

        self.pooler = None

        # Resume from model checkpoint
        if self.config.resume and self.config.model_ckpt:
            assert os.path.exists(
                self.config.model_ckpt), "Model checkpoint does not exist"
            logger.info(f"loading model from {self.config.model_ckpt}")
            self.load(self.config.model_ckpt)
            return

        # Setup the baseline network
        # - Define the encoder based on config options
        # - Set state dictionary based on configuration setting
        # - Download pretrained model if flag is set
        # TODO - Use Model.pretrained_model() after configuration file is hosted.
        if self.config.use_pretrained_model:
            with MTDNNCommonUtils.download_path() as file_path:
                path = pathlib.Path(file_path)
                self.local_model_path = MTDNNCommonUtils.maybe_download(
                    url=self.
                    pretrained_model_archive_map[pretrained_model_name],
                    log=logger,
                )
            self.bert_model = MTDNNCommonUtils.load_pytorch_model(
                self.local_model_path)
            self.state_dict = self.bert_model["state"]
        else:
            # Set the config base on encoder type set for initial checkpoint
            if config.encoder_type == EncoderModelType.BERT:
                self.bert_config = BertConfig.from_dict(self.config.to_dict())
                self.bert_model = BertModel.from_pretrained(
                    self.config.init_checkpoint)
                self.state_dict = self.bert_model.state_dict()
                self.config.hidden_size = self.bert_config.hidden_size
            if config.encoder_type == EncoderModelType.ROBERTA:
                # Download and extract from PyTorch hub if not downloaded before
                self.bert_model = torch.hub.load("pytorch/fairseq",
                                                 config.init_checkpoint)
                self.config.hidden_size = self.bert_model.args.encoder_embed_dim
                self.pooler = LinearPooler(self.config.hidden_size)
                new_state_dict = {}
                for key, val in self.bert_model.state_dict().items():
                    if key.startswith("model.decoder.sentence_encoder"
                                      ) or key.startswith(
                                          "model.classification_heads"):
                        key = f"bert.{key}"
                        new_state_dict[key] = val
                    # backward compatibility PyTorch <= 1.0.0
                    if key.startswith("classification_heads"):
                        key = f"bert.model.{key}"
                        new_state_dict[key] = val
                self.state_dict = new_state_dict

        self.updates = (self.state_dict["updates"] if self.state_dict
                        and "updates" in self.state_dict else 0)
        self.local_updates = 0
        self.train_loss = AverageMeter()
        self.network = SANBERTNetwork(
            init_checkpoint_model=self.bert_model,
            pooler=self.pooler,
            config=self.config,
        )
        if self.state_dict:
            self.network.load_state_dict(self.state_dict, strict=False)
        self.mnetwork = (nn.DataParallel(self.network)
                         if self.config.multi_gpu_on else self.network)
        self.total_param = sum([
            p.nelement() for p in self.network.parameters() if p.requires_grad
        ])

        # Move network to GPU if device available and flag set
        if self.config.cuda:
            self.network.cuda(device=self.config.cuda_device)
        self.optimizer_parameters = self._get_param_groups()
        self._setup_optim(self.optimizer_parameters, self.state_dict,
                          num_train_step)
        self.para_swapped = False
        self.optimizer.zero_grad()
        self._setup_lossmap()

    def _get_param_groups(self):
        no_decay = [
            "bias", "gamma", "beta", "LayerNorm.bias", "LayerNorm.weight"
        ]
        optimizer_parameters = [
            {
                "params": [
                    p for n, p in self.network.named_parameters()
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.01,
            },
            {
                "params": [
                    p for n, p in self.network.named_parameters()
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]
        return optimizer_parameters

    def _setup_optim(self,
                     optimizer_parameters,
                     state_dict: dict = None,
                     num_train_step: int = -1):

        # Setup optimizer parameters
        if self.config.optimizer == "sgd":
            self.optimizer = optim.SGD(
                optimizer_parameters,
                self.config.learning_rate,
                weight_decay=self.config.weight_decay,
            )
        elif self.config.optimizer == "adamax":
            self.optimizer = Adamax(
                optimizer_parameters,
                self.config.learning_rate,
                warmup=self.config.warmup,
                t_total=num_train_step,
                max_grad_norm=self.config.grad_clipping,
                schedule=self.config.warmup_schedule,
                weight_decay=self.config.weight_decay,
            )

        elif self.config.optimizer == "radam":
            self.optimizer = RAdam(
                optimizer_parameters,
                self.config.learning_rate,
                warmup=self.config.warmup,
                t_total=num_train_step,
                max_grad_norm=self.config.grad_clipping,
                schedule=self.config.warmup_schedule,
                eps=self.config.adam_eps,
                weight_decay=self.config.weight_decay,
            )

            # The current radam does not support FP16.
            self.config.fp16 = False
        elif self.config.optimizer == "adam":
            self.optimizer = Adam(
                optimizer_parameters,
                lr=self.config.learning_rate,
                warmup=self.config.warmup,
                t_total=num_train_step,
                max_grad_norm=self.config.grad_clipping,
                schedule=self.config.warmup_schedule,
                weight_decay=self.config.weight_decay,
            )

        else:
            raise RuntimeError(
                f"Unsupported optimizer: {self.config.optimizer}")

        # Clear scheduler for certain optimizer choices
        if self.config.optimizer in ["adam", "adamax", "radam"]:
            if self.config.have_lr_scheduler:
                self.config.have_lr_scheduler = False

        if state_dict and "optimizer" in state_dict:
            self.optimizer.load_state_dict(state_dict["optimizer"])

        if self.config.fp16:
            try:
                # apex must be importable for mixed-precision training
                global amp
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            model, optimizer = amp.initialize(
                self.network,
                self.optimizer,
                opt_level=self.config.fp16_opt_level)
            self.network = model
            self.optimizer = optimizer

        if self.config.have_lr_scheduler:
            if self.config.scheduler_type == "rop":
                self.scheduler = ReduceLROnPlateau(self.optimizer,
                                                   mode="max",
                                                   factor=self.config.lr_gamma,
                                                   patience=3)
            elif self.config.scheduler_type == "exp":
                self.scheduler = ExponentialLR(self.optimizer,
                                               gamma=self.config.lr_gamma
                                               or 0.95)
            else:
                milestones = [
                    int(step) for step in (
                        self.config.multi_step_lr or "10,20,30").split(",")
                ]
                self.scheduler = MultiStepLR(self.optimizer,
                                             milestones=milestones,
                                             gamma=self.config.lr_gamma)
        else:
            self.scheduler = None

    def _setup_lossmap(self):
        self.task_loss_criterion = []
        for idx, cs in enumerate(self.config.loss_types):
            assert cs is not None, "Loss type must be defined."
            lc = LOSS_REGISTRY[cs](name=f"Loss func of task {idx}: {cs}")
            self.task_loss_criterion.append(lc)

    def _setup_kd_lossmap(self):
        loss_types = self.config.kd_loss_types
        self.kd_task_loss_criterion = []
        if self.config.mkd_opt > 0:
            for idx, cs in enumerate(loss_types):
                assert cs, "Loss type must be defined."
                lc = LOSS_REGISTRY[cs](
                    name="Loss func of task {}: {}".format(idx, cs))
                self.kd_task_loss_criterion.append(lc)

    def _to_cuda(self, tensor):
        # Set tensor to gpu (non-blocking) if a PyTorch tensor
        if tensor is None:
            return tensor

        if isinstance(tensor, list) or isinstance(tensor, tuple):
            y = [
                e.cuda(device=self.config.cuda_device, non_blocking=True)
                for e in tensor
            ]
            for t in y:
                t.requires_grad = False
        else:
            y = tensor.cuda(device=self.config.cuda_device, non_blocking=True)
            y.requires_grad = False
        return y

    def train(self):
        if self.para_swapped:
            self.para_swapped = False

    def update(self, batch_meta, batch_data):
        self.network.train()
        target = batch_data[batch_meta["label"]]
        soft_labels = None

        task_type = batch_meta["task_type"]
        target = self._to_cuda(target) if self.config.cuda else target

        task_id = batch_meta["task_id"]
        inputs = batch_data[:batch_meta["input_len"]]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        weight = None
        if self.config.weighted_on:
            if self.config.cuda:
                weight = batch_data[batch_meta["factor"]].cuda(
                    device=self.config.cuda_device, non_blocking=True)
            else:
                weight = batch_data[batch_meta["factor"]]
        logits = self.mnetwork(*inputs)

        # compute loss
        loss = 0
        if self.task_loss_criterion[task_id] and (target is not None):
            loss = self.task_loss_criterion[task_id](logits,
                                                     target,
                                                     weight,
                                                     ignore_index=-1)

        # compute kd loss
        if self.config.mkd_opt > 0 and ("soft_label" in batch_meta):
            soft_labels = batch_meta["soft_label"]
            soft_labels = (self._to_cuda(soft_labels)
                           if self.config.cuda else soft_labels)
            kd_lc = self.kd_task_loss_criterion[task_id]
            kd_loss = (kd_lc(logits, soft_labels, weight, ignore_index=-1)
                       if kd_lc else 0)
            loss = loss + kd_loss

        self.train_loss.update(loss.item(),
                               batch_data[batch_meta["token_id"]].size(0))
        # scale loss
        loss = loss / (self.config.grad_accumulation_step or 1)
        if self.config.fp16:
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()
        self.local_updates += 1
        if self.local_updates % self.config.grad_accumulation_step == 0:
            if self.config.global_grad_clipping > 0:
                if self.config.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(self.optimizer),
                        self.config.global_grad_clipping,
                    )
                else:
                    torch.nn.utils.clip_grad_norm_(
                        self.network.parameters(),
                        self.config.global_grad_clipping)
            self.updates += 1
            # apply the accumulated gradients, then reset them
            self.optimizer.step()
            self.optimizer.zero_grad()

    def eval_mode(
        self,
        data: DataLoader,
        metric_meta,
        use_cuda=True,
        with_label=True,
        label_mapper=None,
        task_type=TaskType.Classification,
    ):
        if use_cuda:
            self.cuda()
        predictions = []
        golds = []
        scores = []
        ids = []
        metrics = {}
        for idx, (batch_info, batch_data) in enumerate(data):
            if idx % 100 == 0:
                logger.info(f"predicting {idx}")
            batch_info, batch_data = MTDNNCollater.patch_data(
                use_cuda, batch_info, batch_data)
            score, pred, gold = self._predict_batch(batch_info, batch_data)
            predictions.extend(pred)
            golds.extend(gold)
            scores.extend(score)
            ids.extend(batch_info["uids"])

        if task_type == TaskType.Span:
            golds = merge_answers(ids, golds)
            predictions, scores = select_answers(ids, predictions, scores)
        if with_label:
            metrics = calc_metrics(metric_meta, golds, predictions, scores,
                                   label_mapper)
        return metrics, predictions, scores, golds, ids

    def _predict_batch(self, batch_meta, batch_data):
        self.network.eval()
        task_id = batch_meta["task_id"]
        task_type = batch_meta["task_type"]
        inputs = batch_data[:batch_meta["input_len"]]
        if len(inputs) == 3:
            inputs.append(None)
            inputs.append(None)
        inputs.append(task_id)
        score = self.mnetwork(*inputs)
        if task_type == TaskType.Ranking:
            score = score.contiguous().view(-1, batch_meta["pairwise_size"])
            assert task_type == TaskType.Ranking
            score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.zeros(score.shape, dtype=int)
            positive = np.argmax(score, axis=1)
            for idx, pos in enumerate(positive):
                predict[idx, pos] = 1
            predict = predict.reshape(-1).tolist()
            score = score.reshape(-1).tolist()
            return score, predict, batch_meta["true_label"]
        elif task_type == TaskType.SequenceLabeling:
            mask = batch_data[batch_meta["mask"]]
            score = score.contiguous()
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).reshape(mask.size()).tolist()
            valid_lengths = mask.sum(1).tolist()
            final_predict = []
            for idx, p in enumerate(predict):
                final_predict.append(p[:valid_lengths[idx]])
            score = score.reshape(-1).tolist()
            return score, final_predict, batch_meta["label"]
        elif task_type == TaskType.Span:
            start, end = score
            scores, predictions = [], []
            if self.config.encoder_type == EncoderModelType.BERT:
                scores, predictions = extract_answer(
                    batch_meta,
                    batch_data,
                    start,
                    end,
                    self.config.get("max_answer_len", 5),
                )
            return scores, predictions, batch_meta["answer"]
        else:
            if task_type == TaskType.Classification:
                score = F.softmax(score, dim=1)
            score = score.data.cpu()
            score = score.numpy()
            predict = np.argmax(score, axis=1).tolist()
            score = score.reshape(-1).tolist()
        return score, predict, batch_meta["label"]

    def fit(self, epochs=0):
        """ Fit model to training datasets """
        epochs = epochs or self.config.epochs
        logger.info(f"Total number of params: {self.total_param}")
        for epoch in range(epochs):
            logger.info(f"At epoch {epoch}")
            logger.info(
                f"Amount of data to go over: {len(self.multitask_train_dataloader)}"
            )

            start = datetime.now()
            # Create batches and train
            for idx, (batch_meta, batch_data) in enumerate(
                    self.multitask_train_dataloader):
                batch_meta, batch_data = MTDNNCollater.patch_data(
                    self.config.cuda, batch_meta, batch_data)

                task_id = batch_meta["task_id"]
                self.update(batch_meta, batch_data)
                if (self.local_updates == 1 or (self.local_updates) %
                    (self.config.log_per_updates *
                     self.config.grad_accumulation_step) == 0):

                    time_left = str((datetime.now() - start) / (idx + 1) *
                                    (len(self.multitask_train_dataloader) -
                                     idx - 1)).split(".")[0]
                    logger.info(
                        "Task - [{0:2}] Updates - [{1:6}] Training Loss - [{2:.5f}] Time Remaining - [{3}]"
                        .format(
                            task_id,
                            self.updates,
                            self.train_loss.avg,
                            time_left,
                        ))
                    if self.config.use_tensor_board:
                        self.tensor_board.add_scalar(
                            "train/loss",
                            self.train_loss.avg,
                            global_step=self.updates,
                        )

                if self.config.save_per_updates_on and (
                    (self.local_updates) %
                    (self.config.save_per_updates *
                     self.config.grad_accumulation_step) == 0):
                    model_file = os.path.join(
                        self.output_dir,
                        "model_{}_{}.pt".format(epoch, self.updates),
                    )
                    logger.info(f"Saving mt-dnn model to {model_file}")
                    self.save(model_file)

            # TODO: Alternatively, we need to refactor save function
            # and move into prediction
            # Saving each checkpoint after model training
            model_file = os.path.join(self.output_dir,
                                      "model_{}.pt".format(epoch))
            logger.info(f"Saving mt-dnn model to {model_file}")
            self.save(model_file)

    def predict(self,
                trained_model_chckpt: str = None,
                saved_epoch_idx: int = 0):
        """ 
        Inference of model on test datasets
        """

        # Load a trained checkpoint if a valid model checkpoint
        if trained_model_chckpt and os.path.exists(trained_model_chckpt):
            logger.info(f"Running predictions using: {trained_model_chckpt}")
            self.load(trained_model_chckpt)

        # Iterate over the test datasets and evaluate
        start = datetime.now()
        for idx, dataset in enumerate(self.test_datasets_list):
            prefix = dataset.split("_")[0]
            label_dict = self.task_defs.global_map.get(prefix, None)
            dev_data: DataLoader = self.dev_dataloaders_list[idx]
            if dev_data is not None:
                with torch.no_grad():
                    (
                        dev_metrics,
                        dev_predictions,
                        scores,
                        golds,
                        dev_ids,
                    ) = self.eval_mode(
                        dev_data,
                        metric_meta=self.task_defs.metric_meta_map[prefix],
                        use_cuda=self.config.cuda,
                        label_mapper=label_dict,
                        task_type=self.task_defs.task_type_map[prefix],
                    )
                for key, val in dev_metrics.items():
                    if self.config.use_tensor_board:
                        self.tensor_board.add_scalar(
                            f"dev/{dataset}/{key}",
                            val,
                            global_step=saved_epoch_idx)
                    if isinstance(val, str):
                        logger.info(
                            f"Task {dataset} -- epoch {saved_epoch_idx} -- Dev {key}:\n {val}"
                        )
                    else:
                        logger.info(
                            f"Task {dataset} -- epoch {saved_epoch_idx} -- Dev {key}: {val:.3f}"
                        )
                score_file = os.path.join(
                    self.output_dir,
                    f"{dataset}_dev_scores_{saved_epoch_idx}.json")
                results = {
                    "metrics": dev_metrics,
                    "predictions": dev_predictions,
                    "uids": dev_ids,
                    "scores": scores,
                }

                # Save results to file
                MTDNNCommonUtils.dump(score_file, results)
                if self.config.use_glue_format:
                    official_score_file = os.path.join(
                        self.output_dir,
                        "{}_dev_scores_{}.tsv".format(dataset,
                                                      saved_epoch_idx),
                    )
                    submit(official_score_file, results, label_dict)

            # test eval
            test_data: DataLoader = self.test_dataloaders_list[idx]
            if test_data is not None:
                with torch.no_grad():
                    (
                        test_metrics,
                        test_predictions,
                        scores,
                        golds,
                        test_ids,
                    ) = self.eval_mode(
                        test_data,
                        metric_meta=self.task_defs.metric_meta_map[prefix],
                        use_cuda=self.config.cuda,
                        with_label=False,
                        label_mapper=label_dict,
                        task_type=self.task_defs.task_type_map[prefix],
                    )
                score_file = os.path.join(
                    self.output_dir,
                    f"{dataset}_test_scores_{saved_epoch_idx}.json")
                results = {
                    "metrics": test_metrics,
                    "predictions": test_predictions,
                    "uids": test_ids,
                    "scores": scores,
                }
                MTDNNCommonUtils.dump(score_file, results)
                if self.config.use_glue_format:
                    official_score_file = os.path.join(
                        self.output_dir,
                        f"{dataset}_test_scores_{saved_epoch_idx}.tsv")
                    submit(official_score_file, results, label_dict)
                logger.info("[new test scores saved.]")

        # Close tensorboard connection if opened
        self.close_connections()

    def close_connections(self):
        # Close tensor board connection
        if self.config.use_tensor_board:
            self.tensor_board.close()

    def extract(self, batch_meta, batch_data):
        self.network.eval()
        # 'token_id': 0; 'segment_id': 1; 'mask': 2
        inputs = batch_data[:3]
        all_encoder_layers, pooled_output = self.mnetwork.bert(*inputs)
        return all_encoder_layers, pooled_output

    def save(self, filename):
        network_state = dict([(k, v.cpu())
                              for k, v in self.network.state_dict().items()])
        params = {
            "state": network_state,
            "optimizer": self.optimizer.state_dict(),
            "config": self.config,
        }
        torch.save(params, filename)
        logger.info("model saved to {}".format(filename))

    def load(self, checkpoint):
        model_state_dict = torch.load(checkpoint)
        self.network.load_state_dict(model_state_dict["state"], strict=False)
        self.optimizer.load_state_dict(model_state_dict["optimizer"])
        self.config = model_state_dict["config"]

    def cuda(self):
        self.network.cuda(device=self.config.cuda_device)

    def supported_init_checkpoints(self):
        """List of allowed check points
        """
        return [
            "bert-base-uncased",
            "bert-base-cased",
            "bert-large-uncased",
            "mtdnn-base-uncased",
            "mtdnn-large-uncased",
            "roberta.base",
            "roberta.large",
        ]

    def update_config_with_training_opts(
        self,
        decoder_opts,
        task_types,
        dropout_list,
        loss_types,
        kd_loss_types,
        tasks_nclass_list,
    ):
        # Update configurations with options obtained from preprocessing training data
        setattr(self.config, "decoder_opts", decoder_opts)
        setattr(self.config, "task_types", task_types)
        setattr(self.config, "tasks_dropout_p", dropout_list)
        setattr(self.config, "loss_types", loss_types)
        setattr(self.config, "kd_loss_types", kd_loss_types)
        setattr(self.config, "tasks_nclass_list", tasks_nclass_list)
Example #12
0
def train(args):

    label_name = ['fake', 'real']

    device = torch.device("cuda:0" if args['CUDA'] == 'gpu' else "cpu")

    prefix = args['MODEL'] + '_' + args['BERT_CONFIG']

    bert_size = args['BERT_CONFIG'].split('-')[1]

    start_time = time.time()
    print('Importing data...', file=sys.stderr)
    df_train = pd.read_csv(args['--train'], index_col=0)
    df_val = pd.read_csv(args['--dev'], index_col=0)

    train_label = dict(df_train.information_label.value_counts())

    print("Train label", train_label)

    label_max = float(max(train_label.values()))

    print("Label max", label_max)

    train_label_weight = torch.tensor(
        [label_max / train_label[i] for i in range(len(train_label))],
        device=device)

    print(train_label_weight)

    print('Done! time elapsed %.2f sec' % (time.time() - start_time),
          file=sys.stderr)
    print('-' * 80, file=sys.stderr)

    start_time = time.time()
    print('Set up model...', file=sys.stderr)

    if args['MODEL'] == 'cnn':
        model = CustomBertConvModel(args['BERT_CONFIG'],
                                    device,
                                    float(args['--dropout']),
                                    len(label_name),
                                    out_channel=int(args['--out-channel']))
        optimizer = BertAdam([{
            'params': model.bert.parameters()
        }, {
            'params': model.conv.parameters(),
            'lr': float(args['--lr'])
        }, {
            'params': model.hidden_to_softmax.parameters(),
            'lr': float(args['--lr'])
        }],
                             lr=float(args['--lr-bert']),
                             max_grad_norm=float(args['--clip-grad']))
    elif args['MODEL'] == 'lstm':
        model = CustomBertLSTMModel(args['BERT_CONFIG'],
                                    device,
                                    float(args['--dropout']),
                                    len(label_name),
                                    lstm_hidden_size=int(
                                        args['--hidden-size']))

        optimizer = BertAdam([{
            'params': model.bert.parameters()
        }, {
            'params': model.lstm.parameters(),
            'lr': float(args['--lr'])
        }, {
            'params': model.hidden_to_softmax.parameters(),
            'lr': float(args['--lr'])
        }],
                             lr=float(args['--lr-bert']),
                             max_grad_norm=float(args['--clip-grad']))
    else:
        print("Please specify a valid model: 'cnn' or 'lstm'", file=sys.stderr)
        exit(1)

    model = model.to(device)
    print('Use device: %s' % device, file=sys.stderr)
    print('Done! time elapsed %.2f sec' % (time.time() - start_time),
          file=sys.stderr)
    print('-' * 80, file=sys.stderr)

    model.train()

    cn_loss = torch.nn.CrossEntropyLoss(weight=train_label_weight,
                                        reduction='mean')
    torch.save(cn_loss, 'loss_func')  # for later testing

    train_batch_size = int(args['--batch-size'])
    valid_niter = int(args['--valid-niter'])
    log_every = int(args['--log-every'])
    model_save_path = prefix + '_model.bin'

    num_trial = 0
    train_iter = patience = cum_loss = report_loss = 0
    cum_examples = report_examples = epoch = 0
    hist_valid_scores = []
    train_time = begin_time = time.time()
    print('Begin Maximum Likelihood training...')

    while True:
        epoch += 1

        for sents, targets in batch_iter(df_train,
                                         batch_size=train_batch_size,
                                         shuffle=True,
                                         bert=bert_size):  # for each epoch
            train_iter += 1  # increase training iteration
            # Set gradients to zero before backpropagation;
            # PyTorch accumulates gradients across subsequent backward passes.
            optimizer.zero_grad()
            batch_size = len(sents)
            pre_softmax = model(sents).double()

            loss = cn_loss(
                pre_softmax,
                torch.tensor(targets, dtype=torch.long, device=device))
            # The gradients are "stored" by the tensors themselves once you call backwards
            # on the loss.
            loss.backward()
            # After computing the gradients for all tensors in the model,
            # optimizer.step() iterates over every parameter it is responsible
            # for and uses its stored .grad to update its value.
            optimizer.step()

            # loss.item() is the mean loss over the mini-batch (the criterion
            # uses reduction='mean'), so multiply by batch_size to accumulate
            # the summed loss.
            batch_losses_val = loss.item() * batch_size
            report_loss += batch_losses_val
            cum_loss += batch_losses_val

            report_examples += batch_size
            cum_examples += batch_size

            if train_iter % log_every == 0:
                print('epoch %d, iter %d, avg. loss %.2f, '
                      'cum. examples %d, speed %.2f examples/sec, '
                      'time elapsed %.2f sec' %
                      (epoch, train_iter, report_loss / report_examples,
                       cum_examples, report_examples /
                       (time.time() - train_time), time.time() - begin_time),
                      file=sys.stderr)

                train_time = time.time()
                report_loss = report_examples = 0.

            # perform validation
            if train_iter % valid_niter == 0:
                print(
                    'epoch %d, iter %d, cum. loss %.2f, cum. examples %d' %
                    (epoch, train_iter, cum_loss / cum_examples, cum_examples),
                    file=sys.stderr)

                cum_loss = cum_examples = 0

                print('begin validation....', file=sys.stderr)

                validation_loss = validation(
                    model, df_val, bert_size, cn_loss,
                    device)  # dev batch size can be a bit larger

                print('validation: iter %d, loss %f' %
                      (train_iter, validation_loss),
                      file=sys.stderr)

                is_better = len(
                    hist_valid_scores
                ) == 0 or validation_loss < min(hist_valid_scores)
                hist_valid_scores.append(validation_loss)

                if is_better:
                    patience = 0
                    print('save currently the best model to [%s]' %
                          model_save_path,
                          file=sys.stderr)

                    model.save(model_save_path)

                    # also save the optimizers' state
                    torch.save(optimizer.state_dict(),
                               model_save_path + '.optim')

                elif patience < int(args['--patience']):
                    patience += 1
                    print('hit patience %d' % patience, file=sys.stderr)

                    if patience == int(args['--patience']):
                        num_trial += 1
                        print('hit #%d trial' % num_trial, file=sys.stderr)
                        if num_trial == int(args['--max-num-trial']):
                            print('early stop!', file=sys.stderr)
                            exit(0)

                        # decay lr, and restore from previously best checkpoint
                        print(
                            'load previously best model and decay learning rate to %f%%'
                            % (float(args['--lr-decay']) * 100),
                            file=sys.stderr)

                        # load model
                        params = torch.load(
                            model_save_path,
                            map_location=lambda storage, loc: storage)
                        model.load_state_dict(params['state_dict'])
                        model = model.to(device)

                        print('restore parameters of the optimizers',
                              file=sys.stderr)
                        optimizer.load_state_dict(
                            torch.load(model_save_path + '.optim'))

                        # set new lr
                        for param_group in optimizer.param_groups:
                            param_group['lr'] *= float(args['--lr-decay'])

                        # reset patience
                        patience = 0

                if epoch == int(args['--max-epoch']):
                    print('reached maximum number of epochs!', file=sys.stderr)
                    exit(0)

def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--mode",
        type=str,
        required=True,
        choices=['train', 'eval'],
        help="Specifies whether to run in train or eval (inference) mode.")

    parser.add_argument("--bert_model",
                        type=str,
                        required=True,
                        choices=['bbu', 'blu', 'bbmu'],
                        help="Bert pre-trained model to be used.")

    parser.add_argument(
        "--expt",
        type=str,
        required=True,
        choices=['aux', 'sbjn', 'nth'],
        help="Type of experiment being conducted. Used to generate the labels."
    )

    # Other parameters
    parser.add_argument("--batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training and inference.")

    parser.add_argument("--n_epochs",
                        default=8,
                        type=int,
                        help="Total number of training epochs to perform.")

    parser.add_argument("--nth_n",
                        default=None,
                        type=int,
                        help="Value of n (ONE-INDEXED) in nth experiments.")

    parser.add_argument(
        "--output_layers",
        default=[-1],
        nargs='+',
        type=int,
        help=
        "Space-separated list of (1 or more) layer indices whose embeddings will be used to train classifiers."
    )

    parser.add_argument(
        "--eager_eval",
        default=False,
        action='store_true',
        help=
        "Whether to run full evaluation (w/ generalization set) after each training epoch."
    )

    parser.add_argument(
        "--load_checkpt",
        default=None,
        type=str,
        help="Path to a checkpoint to be loaded, for training or inference.")

    parser.add_argument(
        "--data_path",
        default='data',
        type=str,
        help="Relative directory where train/test data is stored.")

    parser.add_argument(
        "--expt_path",
        default=None,
        type=str,
        help=
        "Relative directory to store all data associated with current experiment."
    )

    args = parser.parse_args()
    bert_model = args.bert_model
    batch_size = args.batch_size
    n_epochs = args.n_epochs
    data_path = Path(args.data_path)
    tokenizer = tokenization.BertTokenizer.from_pretrained(
        MODEL_ABBREV[bert_model], do_lower_case=True)

    global TRAIN_FILE, TEST_FILE, GEN_FILE
    TRAIN_FILE = args.expt + TRAIN_FILE
    TEST_FILE = args.expt + TEST_FILE
    GEN_FILE = args.expt + GEN_FILE

    # error check
    if args.mode == 'eval' and args.load_checkpt is None:
        raise Exception(
            f"{__file__}: error: the following arguments are required in eval mode: --load-checkpt"
        )
    if args.expt == 'nth' and args.nth_n is None:
        raise Exception(
            f"{__file__}: error: the following arguments are required in nth expts: --nth_n"
        )

    # experiment dir
    Path('experiments').mkdir(exist_ok=True)
    if args.expt_path is None:
        expt_path = Path("experiments/{0}_{1}".format(
            args.expt, time.strftime("%Y%m%d-%H%M%S")))
    else:
        expt_path = Path(args.expt_path)
    expt_path.mkdir(exist_ok=True)

    # cuda
    print('Initializing...')
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    n_gpu = torch.cuda.device_count()

    # if processed data doesn't exist, create it
    for filename in (TRAIN_FILE, TEST_FILE, GEN_FILE):
        if (expt_path / f"{filename}.{bert_model}.proc").exists():
            print(f"Found {expt_path}/{filename}.{bert_model}.proc")
        elif (data_path / filename).exists():
            process_data(data_path / filename,
                         expt_path / f"{filename}.{bert_model}.proc",
                         tokenizer, args.expt, args.nth_n)
        else:
            raise FileNotFoundError(
                f"{data_path / filename} not found! Download from https://github.com/tommccoy1/subj_aux/tree/master/data"
            )

    # load processed data
    train_path, test_path, gen_path = (expt_path /
                                       f"{filename}.{bert_model}.proc"
                                       for filename in (TRAIN_FILE, TEST_FILE,
                                                        GEN_FILE))
    with train_path.open() as train, test_path.open() as test, gen_path.open(
    ) as gen:
        xy_train = ([int(tok) for tok in str.split(line)] for line in train)
        x_train, y_train = zip(*((line[:-1], line[-1]) for line in xy_train))
        xy_test = ([int(tok) for tok in str.split(line)] for line in test)
        x_test, y_test = zip(*((line[:-1], line[-1]) for line in xy_test))
        xy_gen = ([int(tok) for tok in str.split(line)] for line in gen)
        x_gen, y_gen = zip(*((line[:-1], line[-1]) for line in xy_gen))
    n_train, n_test, n_gen = len(x_train), len(x_test), len(x_gen)

    # initialize BERT model
    model = BertForTokenClassificationLayered.from_pretrained(
        MODEL_ABBREV[bert_model], output_layers=args.output_layers)
    model.to(device)

    # distribute model over GPUs, if available
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: freeze all BERT weights
    for name, param in model.named_parameters():
        param.requires_grad = bool('classifier' in name)
    param_optimizer = [
        p for p in model.named_parameters() if 'classifier' in p[0]
    ]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    # optimizer
    optimizer = BertAdam(optimizer_grouped_parameters, lr=0.001)

    # load from past checkpoint if given
    if args.load_checkpt is not None:
        print(f"Loading from {args.load_checkpt}...")
        checkpt = torch.load(Path(args.load_checkpt))
        epoch = checkpt['epoch']
        model_dict = model.state_dict()
        model_dict.update(checkpt['model_state_dict'])
        model.load_state_dict(model_dict)
        optimizer.load_state_dict(checkpt['optimizer_state_dict'])
    else:
        epoch = 0

    if args.mode == 'train':
        train_msg = f"""Starting training...
            - # GPUs available: {n_gpu}
            - BERT model: {MODEL_ABBREV[bert_model]}
            - Experiment: {args.expt}{', n=' + str(args.nth_n) if args.nth_n is not None else ''}
            - Batch size: {batch_size}
            - # training epochs: {n_epochs}
            - Output layers: {args.output_layers}
        """
        print(train_msg)

        layered_best_valid_loss = [float('inf') for _ in args.output_layers]
        for _ in range(n_epochs):
            epoch += 1
            model.train()
            layered_running_loss = [0. for _ in args.output_layers]
            for batch_num in range(1, n_train // batch_size + 1):
                with torch.no_grad():
                    idx = random.choices(range(n_train), k=batch_size)
                    x_batch = [x_train[i] for i in idx]
                    y_batch = [y_train[j] for j in idx]
                    x_batch, _, y_onehot, mask = prep_batch(
                        x_batch, y_batch, device)

                optimizer.zero_grad()
                layered_loss_batch = model(x_batch,
                                           attention_mask=mask,
                                           labels=y_onehot)
                for idx, loss_batch in enumerate(layered_loss_batch):
                    loss = loss_batch.sum()
                    loss.backward()
                    # accumulate a plain float; keeping the tensor would retain
                    # the autograd history for the whole epoch
                    layered_running_loss[idx] += loss.item()
                optimizer.step()

                if batch_num % 500 == 0:
                    for output_layer, running_loss in zip(
                            args.output_layers, layered_running_loss):
                        layer_str = f"Layer[{output_layer}]" if args.output_layers != [
                            -1
                        ] else ''
                        print(
                            f"Epoch {epoch} ({batch_num * batch_size}/{n_train})",
                            layer_str,
                            f"loss: {running_loss / (batch_num  * batch_size):.7f}",
                            flush=True)

            # compute validation loss and accuracy
            layered_valid_loss, layered_valid_acc, layered_valid_results = run_eval(
                model,
                x_test,
                y_test,
                tokenizer,
                device,
                batch_size=batch_size,
                check_results=True)

            print(f"END of epoch {epoch}, saving checkpoint...")

            for idx, output_layer, running_loss, valid_loss, valid_acc, valid_results in zip(
                    count(), args.output_layers, layered_running_loss,
                    layered_valid_loss, layered_valid_acc,
                    layered_valid_results):
                layer_str = f"Layer[{output_layer}]" if args.output_layers != [
                    -1
                ] else ''
                valid_msg = f"Validation loss: {valid_loss:.7f}, acc: {valid_acc:.7f}"
                layer_idx_str = f"[{output_layer}]" if args.output_layers != [
                    -1
                ] else ''
                results_str = f"{bert_model}{layer_idx_str}.results"
                checkpt_str = f"{bert_model}{layer_idx_str}_e{epoch}_l{valid_loss:8f}.ckpt"

                print(layer_str, valid_msg, flush=True)
                with (data_path / f"{TEST_FILE}").open() as test:
                    test_raw = test.readlines()
                # write results to file
                with (expt_path /
                      f"{TEST_FILE}.{results_str}").open(mode='w') as f_valid:
                    for t_idx, result, pred_subword, pred_idx, true_subword, true_idx in valid_results:
                        f_valid.write(
                            f"#{t_idx}\t{result}\tPrediction: {pred_subword} ({pred_idx})\tTrue: {true_subword} ({true_idx})\t{test_raw[t_idx]}\n"
                        )

                torch.save(
                    {
                        'epoch': epoch,
                        'model_state_dict': {
                            param: value
                            for param, value in model.state_dict().items()
                            if 'classifier' in param
                        },
                        'optimizer_state_dict': optimizer.state_dict(),
                        'train_loss': layered_running_loss,
                        'valid_loss': layered_valid_loss,
                    }, expt_path / checkpt_str)

                if valid_loss < layered_best_valid_loss[idx]:
                    layered_best_valid_loss[idx] = valid_loss
                    print(layer_str,
                          f"best model is {checkpt_str}",
                          flush=True)

                    best_str = f"{bert_model}{layer_idx_str}_best.ckpt"
                    torch.save(
                        {
                            'epoch': epoch,
                            'model_state_dict': {
                                param: value
                                for param, value in model.state_dict().items()
                                if 'classifier' in param
                            },
                            'optimizer_state_dict': optimizer.state_dict(),
                            'train_loss': layered_running_loss,
                            'valid_loss': layered_valid_loss,
                        }, expt_path / best_str)
            if args.eager_eval:
                # compute gen loss and accuracy
                layered_gen_loss, layered_gen_acc, layered_gen_results = run_eval(
                    model,
                    x_gen,
                    y_gen,
                    tokenizer,
                    device,
                    batch_size=batch_size,
                    check_results=True)

                for idx, output_layer, gen_loss, gen_acc, gen_results in zip(
                        count(), args.output_layers, layered_gen_loss,
                        layered_gen_acc, layered_gen_results):
                    layer_str = f"Layer[{output_layer}]" if args.output_layers != [
                        -1
                    ] else ''
                    gen_msg = f"Generalization loss: {gen_loss:.7f}, acc: {gen_acc:.7f}"
                    layer_idx_str = f"[{output_layer}]" if args.output_layers != [
                        -1
                    ] else ''
                    results_str = f"{bert_model}{layer_idx_str}.results"

                    print(layer_str, gen_msg, flush=True)
                    with (data_path / f"{GEN_FILE}").open() as gen:
                        gen_raw = gen.readlines()
                    # write results to file
                    with (expt_path /
                          f"{GEN_FILE}.{results_str}").open(mode='w') as f_gen:
                        for g_idx, result, pred_subword, pred_idx, true_subword, true_idx in gen_results:
                            # strip the trailing newline kept by readlines()
                            f_gen.write(
                                f"#{g_idx}\t{result}\tPrediction: {pred_subword} ({pred_idx})\tTrue: {true_subword} ({true_idx})\t{gen_raw[g_idx].rstrip()}\n"
                            )

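    # evaluation-only mode: score the test and generalization sets with the
    # previously loaded model and dump per-example results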
    elif args.mode == 'eval':

        eval_msg = f"""Starting evaluation...
            - # GPUs available: {n_gpu}
            - Checkpoint: {args.load_checkpt}
            - BERT model: {MODEL_ABBREV[bert_model]}
            - Experiment: {args.expt}{', n=' + str(args.nth_n) if args.nth_n is not None else ''}
            - Batch size: {batch_size}
            - # training epochs: {n_epochs}
            - Output layers: {args.output_layers}
        """
        print(eval_msg)

        test_raw_path, gen_raw_path = (data_path / f"{filename}"
                                       for filename in (TEST_FILE, GEN_FILE))
        with test_raw_path.open() as test, gen_raw_path.open() as gen:
            test_raw = test.readlines()
            gen_raw = gen.readlines()

        # compute validation loss and accuracy for test and gen
        layered_valid_loss, layered_valid_acc, layered_valid_results = run_eval(
            model, x_test, y_test, tokenizer, device, check_results=True)
        layered_gen_loss, layered_gen_acc, layered_gen_results = run_eval(
            model, x_gen, y_gen, tokenizer, device, check_results=True)

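        # report metrics and write result files separately for every probed
        # output layer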
        for output_layer, valid_loss, valid_acc, valid_results, gen_loss, gen_acc, gen_results in zip(
                args.output_layers, layered_valid_loss, layered_valid_acc,
                layered_valid_results, layered_gen_loss, layered_gen_acc,
                layered_gen_results):
            layer_str = f"Layer[{output_layer}]" if args.output_layers != [
                -1
            ] else ''
            layer_idx_str = f"[{output_layer}]" if args.output_layers != [
                -1
            ] else ''
            valid_msg = f"Validation loss: {valid_loss:.7f}, acc: {valid_acc:.7f}"
            gen_msg = f"Generalization loss: {gen_loss:.7f}, acc: {gen_acc:.7f}"
            results_str = f"{bert_model}{layer_idx_str}.results"

            for msg in (valid_msg, gen_msg):
                print(layer_str, msg)

            # write results to file
            with (expt_path / f"{TEST_FILE}.{results_str}").open(
                    mode='w') as f_valid, (expt_path /
                                           f"{GEN_FILE}.{results_str}").open(
                                               mode='w') as f_gen:
                for idx, result, pred_subword, pred_idx, true_subword, true_idx in valid_results:
                    # raw lines already end in '\n'; strip it so each record
                    # occupies a single line
                    f_valid.write(
                        f"#{idx}\t{result}\tPrediction: {pred_subword} ({pred_idx})\tTrue: {true_subword} ({true_idx})\t{test_raw[idx].rstrip()}\n"
                    )

                for idx, result, pred_subword, pred_idx, true_subword, true_idx in gen_results:
                    f_gen.write(
                        f"#{idx}\t{result}\tPrediction: {pred_subword} ({pred_idx})\tTrue: {true_subword} ({true_idx})\t{gen_raw[idx].rstrip()}\n"
                    )

    else:
        raise ValueError('--mode must be set to either "train" or "eval"')