Example 1
def get_categorial_loss(attrs, loss, at, at_loss):
    if loss == 'cross_entropy':
        loss_fns = {}
        for attr in attrs:
            loss_fns[attr] = {}
            loss_fns[attr]['attr'] = binary_cn
            if attr.rec_trainable:
                loss_fns[attr]['rec'] = binary_cn
            if at:
                if at_loss == 'MSE':
                    loss_fns[attr]['at_loss'] = MSELoss()
                else:
                    loss_fns[attr]['at_loss'] = KLDivLoss()
        return loss_fns
    elif loss == 'cross_entropy_weight':
        # return F.cross_entropy, F.cross_entropy
        loss_fns = {}
        weights = get_categorial_weight()
        for attr in attrs:
            loss_fns[attr] = {}
            loss_fns[attr]['attr'] = partial(binary_cn, weight=weights[attr.key][0])
            if attr.rec_trainable:
                loss_fns[attr]['rec'] = partial(binary_cn, weight=weights[attr.key][1])
            if at:
                if at_loss == 'MSE':
                    loss_fns[attr]['at_loss'] = MSELoss()
                else:
                    loss_fns[attr]['at_loss'] = KLDivLoss()
        return loss_fns
    elif loss == 'ohem':
        loss_fns = {}
        for attr in attrs:
            loss_fns[attr] = {}
            loss_fns[attr]['attr'] = ohem_loss
            if attr.rec_trainable:
                loss_fns[attr]['rec'] = reverse_ohem_loss
            if at:
                if at_loss == 'MSE':
                    loss_fns[attr]['at_loss'] = MSELoss()
                else:
                    loss_fns[attr]['at_loss'] = KLDivLoss()
        return loss_fns
        # return Ohem, ohem_loss
    elif loss == 'focal':
        loss_fns = {}
        weights = get_categorial_weight()
        for attr in attrs:
            loss_fns[attr] = {}
            loss_fns[attr]['attr'] = partial(focal_loss, alpha=weights[attr.key][0] / (weights[attr.key][0] + 1))
            if attr.rec_trainable:
                loss_fns[attr]['rec'] = partial(focal_loss, alpha=weights[attr.key][1] / (weights[attr.key][1] + 1))
            if at:
                if at_loss == 'MSE':
                    loss_fns[attr]['at_loss'] = MSELoss()
                else:
                    loss_fns[attr]['at_loss'] = KLDivLoss()
        return loss_fns
        # return focal_loss, focal_loss
    else:
        raise Exception("Loss '{}' is not supported".format(loss))
Example 2
def run_batch_generation_for_latentCopy(args, model, batch):
    batch = tuple(input_tensor.to(args.device) for input_tensor in batch)
    input_ids, token_type_ids, lm_labels, input_masks, input_masks_with_knowledge, knowledgeROIs = batch
    ori_model = model.module if hasattr(model, "module") else model
    ori_model.model_stage = 0
    model_outputs = model(input_ids=input_ids,
                          token_type_ids=None,
                          labels=lm_labels,
                          attention_mask=None)
    z, z_distribution = model_outputs[:2]
    ori_model.model_stage = 1
    model_outputs = model(input_ids=input_ids,
                          token_type_ids=None,
                          labels=lm_labels,
                          attention_mask=input_masks)
    z_prior, z_prior_distribution = model_outputs[:2]
    ori_model.model_stage = 2
    model_outputs = model(input_ids=input_ids,
                          token_type_ids=None,
                          labels=lm_labels,
                          attention_mask=input_masks_with_knowledge,
                          z_hidden_embeds=z,
                          knowledgeROIs=knowledgeROIs)
    KLDiv_Loss = KLDivLoss(reduction='batchmean')
    kld_loss = KLDiv_Loss(z_prior_distribution.log(), z_distribution) if getattr(args, "latent_modify", '') != 'real' \
      else KLDiv_Loss(z_distribution.log(), z_prior_distribution)

    lm_loss, bow_loss, norm_loss, lm_logits = model_outputs[:4]
    return lm_loss, lm_logits, (bow_loss, norm_loss), kld_loss
Example 3
    def __init__(self, size=0, padding_idx=0, smoothing=0.0):
        super(LabelSmoothing, self).__init__()

        self.criterion = KLDivLoss(size_average=False, reduce=False)
        self.confidence = 1.0 - smoothing
        self.smoothing = smoothing
        self.true_dist = None
Example 4
    def forward( self,input_ids=None,attention_mask=None,token_type_ids=None,position_ids=None,
                 head_mask=None, inputs_embeds=None, labels=None, label_mask=None,):
        """
        :param input_ids: 输入的id
        :param attention_mask:
        :param token_type_ids: segment id
        :param position_ids: 模型使用position_id来识别哪个token在哪个位置
        :param head_mask:
        :param inputs_embeds:
        :param labels:
        :param label_mask:
        :return:
        """
        #首先调用原始的roberta的模型,得到输出, 返回 [last_hidden_states, pooled_output, hidden_states, attentions]
        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        # Use last_hidden_states as the features
        final_embedding = outputs[0]
        sequence_output = self.dropout(final_embedding)
        logits = self.classifier(sequence_output)
        # Outputs are [logits, final_embedding, hidden_states, attentions]
        outputs = (logits, final_embedding,) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:
            # Training mode:
            # only compute the loss on the tokens we actually care about
            if attention_mask is not None or label_mask is not None:
                active_loss = True
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                if label_mask is not None:
                    active_loss = active_loss & label_mask.view(-1)
                # Keep only the unmasked logits for the loss computation below
                active_logits = logits.view(-1, self.num_labels)[active_loss]
            # Check whether the shapes match, e.g. labels.shape: torch.Size([16, 128]) [batch_size, seq_length] vs. logits.shape: torch.Size([16, 128, 11]) [batch_size, seq_length, num_classes]
            if labels.shape == logits.shape:
                # KL-divergence loss: use the masked positions if a mask is given, otherwise all positions
                loss_fct = KLDivLoss()
                if attention_mask is not None or label_mask is not None:
                    active_labels = labels.view(-1, self.num_labels)[active_loss]
                    loss = loss_fct(active_logits, active_labels)
                else:
                    loss = loss_fct(logits, labels)
            else:
                loss_fct = CrossEntropyLoss()
                if attention_mask is not None or label_mask is not None:
                    active_labels = labels.view(-1)[active_loss]
                    # Loss for one batch, e.g. active_logits.shape: torch.Size([485, 11]), active_labels.shape: 485
                    loss = loss_fct(active_logits, active_labels)
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            # Prepend the loss to outputs as well
            outputs = (loss,) + outputs

        return outputs  # (loss), logits, final_embedding, (hidden_states), (attentions)
Example 5
def child_adience_ldl_loss(model_out, gt):
    dist, minor, adience = model_out
    ldl, minor_gt, adience_gt = gt
    lf = NLLLoss(reduction='mean')
    kl = KLDivLoss(reduction='batchmean')

    return kl(dist, ldl) + lf(minor, minor_gt) + lf(adience, adience_gt)
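A quick shape-only usage sketch for this loss. It assumes the model heads emit log-probabilities (which is what both NLLLoss and KLDivLoss expect as their first argument), while the ground truth supplies a probability distribution for the KL term and class indices for the two NLL terms; every size below is invented purely for illustration.

import torch

dist = torch.log_softmax(torch.randn(4, 101), dim=-1)   # age-distribution head (log-probs)
minor = torch.log_softmax(torch.randn(4, 2), dim=-1)     # minor/adult head (log-probs)
adience = torch.log_softmax(torch.randn(4, 8), dim=-1)   # age-group head (log-probs)
ldl = torch.softmax(torch.randn(4, 101), dim=-1)         # target label distribution (probs)
minor_gt = torch.randint(0, 2, (4,))
adience_gt = torch.randint(0, 8, (4,))
print(child_adience_ldl_loss((dist, minor, adience), (ldl, minor_gt, adience_gt)))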
Example 6
def main(config):
    model = MODELS[config.model](backbone=config.backbone, n_channels=config.channel,
                                 n_classes=config.classes)
    model_cam = MODELS['deeplabdilate2d_cam'](backbone='resnet34', n_channels=config.channel,
                                              n_classes=config.classes)
    if int(config.backbone[6:]) > 34:
        in_channels = 2048 // 2 + 2048 // 4 + 2048 // 8
    else:
        in_channels = 512 // 2 + 512 // 4 + 512 // 8
    latent_discriminator = DISCRIMINATOR['latent_discriminator'](in_channels=in_channels, filters=64,
                                                                 backbone=config.backbone)
    seg_discriminator = DISCRIMINATOR['discriminator'](in_channels=config.classes, filters=64)
    criterion = {
        'loss': CrossEntropyLoss(),
        'bceloss': BCEWithLogitsLoss(),
        'wbceloss': WeightedBCEWithLogitsLoss(),
        'klloss': KLDivLoss(),
        'emloss': EMLoss(),
        'entropyloss': EntropyLoss(),
        'celoss': CrossEntropyLoss()
    }
    seg_help = DASEGHelper(model, criterion,
                           config)
    optimizer = seg_help.reset_optim()
    latent_optimizer = OPTIM['adam'](
        params=filter(lambda p: p.requires_grad, latent_discriminator.parameters()),
        lr=seg_help.config.learning_rate_d, betas=(seg_help.config.beta_1, seg_help.config.beta_2))
    seg_optimizer = OPTIM['adam'](
        params=filter(lambda p: p.requires_grad, seg_discriminator.parameters()),
        lr=seg_help.config.learning_rate_d, betas=(seg_help.config.beta_1, seg_help.config.beta_2))
    seg_help.move_to_cuda()
    try:
        model_cam = seg_help.load_pretrained_cam_seg_model(model_cam)
    except FileExistsError as e:
        raise ValueError('file does not exist')

    if seg_help.use_cuda:
        latent_discriminator.to(seg_help.equipment)
        seg_discriminator.to(seg_help.equipment)
        model_cam.to(seg_help.equipment)
        if len(seg_help.config.gpu_count) > 1 and seg_help.config.train:
            latent_discriminator = torch.nn.DataParallel(latent_discriminator, device_ids=seg_help.config.gpu_count)
            seg_discriminator = torch.nn.DataParallel(seg_discriminator, device_ids=seg_help.config.gpu_count)
            model_cam = torch.nn.DataParallel(model_cam, device_ids=seg_help.config.gpu_count)

    # optimizer, epoch_start = seg_help.load_hist_model_optim(optimizer)
    print("data name ", seg_help.config.data_name)
    train_loader, _ = seg_help.get_covid_infection_seg_data_loader_2d_slice(
        data_root=seg_help.config.data_root, pos=0.6)
    unsu_loader = seg_help.get_covid_infection_seg_unsu_data_loader_2d(
        data_root=seg_help.config.unsu_root, pos=0.6)
    train(seg_help, model_cam, train_loader, unsu_loader, latent_discriminator,
          seg_discriminator, optimizer, latent_optimizer, seg_optimizer)

    print("\n-----------load best state of model -----------")
    seg_help.load_best_state()
    seg_help.log.flush()

    seg_help.summary_writer.close()
Example 7
 def __init__(self, size, padding_idx, smoothing=0.1):
     super(LabelSmoothing, self).__init__()
     self.criterion = KLDivLoss(reduction='sum')
     self.padding_idx = padding_idx
     self.confidence = 1.0 - smoothing
     self.smoothing = smoothing
     self.size = size
     self.true_dist = None
Example 8
def compute_KL_loss(low, med, high, T):
    """
    Compute the KL divergence loss from 3 layers
    :param low: lowest level feature layer
    :param med: middle feature layer
    :param high: highest level feature layer
    :param T: temperature softening parameter
    :return: KL_loss
    """
    low, med, high = F.softmax(low, dim=0), F.softmax(med, dim=0), F.softmax(high, dim=0)
    low = apply_clsa_softening(low, T).log()
    med = apply_clsa_softening(med, T).log()

    return (T**2) * (KLDivLoss(reduction='batchmean')(med, high) +
                     KLDivLoss(reduction='batchmean')(low, high))
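For reference, torch.nn.KLDivLoss expects its first argument to be log-probabilities and its second to be probabilities, which is why the two lower-level distributions above go through .log() while high is left as a plain softmax output. Below is a minimal self-contained sketch of the same temperature-scaled KL term; apply_clsa_softening is project-specific, so a plain temperature softmax over the class dimension is assumed in its place.

import torch
import torch.nn.functional as F
from torch.nn import KLDivLoss

def kd_kl_term(student_logits, teacher_logits, T=4.0):
    # input: student log-probabilities, target: teacher probabilities
    log_p_student = F.log_softmax(student_logits / T, dim=-1)
    p_teacher = F.softmax(teacher_logits / T, dim=-1)
    # T**2 keeps gradient magnitudes roughly constant as T changes
    return (T ** 2) * KLDivLoss(reduction='batchmean')(log_p_student, p_teacher)

student, teacher = torch.randn(8, 10), torch.randn(8, 10)
print(kd_kl_term(student, teacher))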
Example 9
 def __init__(self, student_config, teacher_config, device, args):
     self.mse_loss = MSELoss()
     self.kl_loss = KLDivLoss(reduction='batchmean')
     self.cosine_loss = CosineEmbeddingLoss()
     self.distill_config = student_config.distillation_config
     self.device = device
     self.student_config = student_config
     self.teacher_config = teacher_config
     self.batch_size = args.train_batch_size
Example 10
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        label_mask=None,
    ):

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        final_embedding = outputs[0]
        sequence_output = self.dropout(final_embedding)
        logits = self.classifier(sequence_output)

        outputs = (logits, final_embedding, ) + outputs[2:]  # add hidden states and attention if they are here
        if labels is not None:

            # Only keep active parts of the loss
            if attention_mask is not None or label_mask is not None:
                active_loss = True
                if attention_mask is not None:
                    active_loss = attention_mask.view(-1) == 1
                if label_mask is not None:
                    active_loss = active_loss & label_mask.view(-1)
                active_logits = logits.view(-1, self.num_labels)[active_loss]


            if labels.shape == logits.shape:
                loss_fct = KLDivLoss()
                if attention_mask is not None or label_mask is not None:
                    active_labels = labels.view(-1, self.num_labels)[active_loss]
                    loss = loss_fct(active_logits, active_labels)
                else:
                    loss = loss_fct(logits, labels)
            else:
                loss_fct = CrossEntropyLoss()
                if attention_mask is not None or label_mask is not None:
                    active_labels = labels.view(-1)[active_loss]
                    loss = loss_fct(active_logits, active_labels)
                else:
                    loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))


            outputs = (loss,) + outputs

        return outputs  # (loss), scores, final_embedding, (hidden_states), (attentions)
Example 11
def mean_ldl_loss(model_out, gt):
    dist, age = gt

    kl = KLDivLoss(reduction='batchmean')
    #sml = SmoothL1Loss(reduction='mean')

    #ages = torch.matmul(torch.exp(model_out), settings.CLASSES).view(-1)

    return kl(model_out, dist)  # + sml(ages, age)
Example 12
def main(config):
    model = MODELS[config.model](backbone=config.backbone,
                                 n_channels=config.channel,
                                 n_classes=config.classes)
    model_cam = MODELS['deeplabdilate2d_cam'](backbone='resnet34',
                                              n_channels=config.channel,
                                              n_classes=config.classes)
    model_seg = MODELS['deeplabdilate2d_camv19'](backbone=config.backbone,
                                                 n_channels=config.channel,
                                                 n_classes=config.classes)
    criterion = {
        'loss': CrossEntropyLoss(),
        'bceloss': BCEWithLogitsLoss(),
        'wbceloss': WeightedBCEWithLogitsLoss(),
        'klloss': KLDivLoss(),
        'emloss': EMLoss(),
        'entropyloss': EntropyLoss(),
        'celoss': CrossEntropyLoss()
    }
    seg_help = DASEGHelper(model, criterion, config)
    optimizer = seg_help.reset_optim()
    seg_help.move_to_cuda()
    try:
        model_cam = seg_help.load_pretrained_cam_seg_model(model_cam)
    except FileExistsError as e:
        raise ValueError('file does not exist')
    try:
        model_seg = seg_help.load_pretrained_da_seg_model(model_seg)
    except FileExistsError as e:
        raise ValueError('file does not exist')
    if len(seg_help.config.gpu_count) > 1:
        seg_help.model.module.load_state_dict(model_seg.state_dict())
    else:
        seg_help.model.load_state_dict(model_seg.state_dict())
    model_seg.eval()
    if seg_help.use_cuda:
        model_cam.to(seg_help.equipment)
        model_seg.to(seg_help.equipment)
        if len(seg_help.config.gpu_count) > 1 and seg_help.config.train:
            model_cam = torch.nn.DataParallel(
                model_cam, device_ids=seg_help.config.gpu_count)
            model_seg = torch.nn.DataParallel(
                model_seg, device_ids=seg_help.config.gpu_count)

    # optimizer, epoch_start = seg_help.load_hist_model_optim(optimizer)
    print("data name ", seg_help.config.data_name)
    train_loader, _ = seg_help.get_covid_infection_seg_data_loader_2d_slice(
        data_root=seg_help.config.data_root, pos=0.6)
    unsu_loader = seg_help.get_covid_infection_seg_unsu_data_loader_2d(
        data_root=seg_help.config.unsu_root, pos=0.6)
    train(seg_help, model_cam, model_seg, train_loader, unsu_loader, optimizer)
    seg_help.log.flush()

    seg_help.summary_writer.close()
Example 13
 def __init__(self, size, padding_idx, smoothing=0.0):
     """
         @param: size, 分类的类别大小,
         @param: padding_idx, pad特殊标记对应的字典下标
     """
     super(LabelSmoothing, self).__init__()
     self.criterion = KLDivLoss(size_average=False)
     self.padding_idx = padding_idx
     self.confidence = 1.0 - smoothing
     self.smoothing = smoothing
     self.size = size
     self.true_dist = None
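The companion forward pass is not part of this snippet. Below is a minimal sketch of the usual pairing (torch is assumed to be imported): x holds log-probabilities of shape (batch, size), target holds gold class indices, and the smoothed target keeps self.confidence on the gold class, spreads self.smoothing over the remaining non-padding classes, and zeroes out rows whose target is the padding index.

    def forward(self, x, target):
        assert x.size(1) == self.size
        true_dist = x.data.clone()
        # spread the smoothing mass over all classes except gold and padding
        true_dist.fill_(self.smoothing / (self.size - 2))
        # put the remaining confidence on the gold class
        true_dist.scatter_(1, target.data.unsqueeze(1), self.confidence)
        # never assign probability to the padding token
        true_dist[:, self.padding_idx] = 0
        # fully ignore positions whose target is padding
        mask = torch.nonzero(target.data == self.padding_idx, as_tuple=False).squeeze(-1)
        if mask.numel() > 0:
            true_dist.index_fill_(0, mask, 0.0)
        self.true_dist = true_dist
        return self.criterion(x, true_dist.clone().detach())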
Example 14
    def __init__(self,
                 nclass=-1,
                 weight=None,
                 size_average=True,
                 ignore_index=-1):
        super(SegmentationMultiLosses, self).__init__(weight, size_average,
                                                      ignore_index)
        self.nclass = nclass

        self.eps = 0.001
        self.l_log_softmax = Softmax(dim=1)
        self.l_kl = KLDivLoss(reduction="none")
Example 15
 def __init__(self, model, criterion, X, Y, SMLoss_mode=0):
     super(BaseTrainer, self).__init__()
     weight_id = torch.cat([torch.ones(751),torch.zeros(1)])
     weight_part = torch.FloatTensor([1.5,1,0.5,0.5,1,1.5])
     weight_part = torch.FloatTensor([1.,1.,1.,1.,1.,1.])
     self.model = model
     self.criterion = criterion
     self.criterion_part = nn.CrossEntropyLoss().cuda()
     self.criterion_ID = nn.CrossEntropyLoss(weight = weight_id).cuda()
     self.indx=X
     self.indy=Y
     self.SML_mode=SMLoss_mode
     self.KLoss = KLDivLoss()
Example 16
	def __init__(self, reduction: str = 'mean', log_activation: Module = LogSoftmax(dim=-1)):
		"""
			Jensen-Shannon Divergence loss with logits.

			Use the following formula :

			>>> 'JS(p,q) = 0.5 * (KL(LS(p),m) + KL(LS(q),m)), with m = LS(0.5 * (p+q))'
			>>> 'where LS = LogSoftmax and KL = KL-Divergence.'

			:param reduction: The reduction function to apply. (default: 'mean')
			:param log_activation: The log-activation function for compute predictions from logits. (default: LogSoftmax(dim=-1))
		"""
		super().__init__()
		self.kl_div = KLDivLoss(reduction=reduction, log_target=True)
		self.log_activation = log_activation
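The forward is not included in this snippet. One plausible sketch following the docstring formula, keeping in mind that KLDivLoss(input, target) computes KL(target || input) and that log_target=True means both arguments are log-probabilities:

	def forward(self, logits_p, logits_q):
		log_p = self.log_activation(logits_p)
		log_q = self.log_activation(logits_q)
		# m = LogSoftmax of the averaged logits, as in the docstring formula
		log_m = self.log_activation(0.5 * (logits_p + logits_q))
		# 0.5 * (KL(p || m) + KL(q || m))
		return 0.5 * (self.kl_div(log_m, log_p) + self.kl_div(log_m, log_q))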
Example 17
    def kl_loss(S: Tensor) -> Tensor:
        r"""The additional KL divergence-based loss

        .. math::
            P_{i,j} &= \frac{S_{i,j}^2 / \sum_{n=1}^N S_{n,j}}{\sum_{k=1}^K
            S_{i,k}^2 / \sum_{n=1}^N S_{n,k}}

            \mathcal{L}_{\textrm{KL}} &= \textrm{KLDiv}(\mathbf{P} \Vert
            \mathbf{S})
        """
        S_2 = S**2
        P = S_2 / S.sum(dim=1, keepdim=True)
        P = P / (S_2.sum(dim=2, keepdim=True) / S.sum(dim=1, keepdim=True))
        P[torch.isnan(P)] = 0.

        loss = KLDivLoss(reduction='batchmean', log_target=False)
        return loss(S.log(), P)
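A quick usage sketch, assuming kl_loss is reachable as a plain function and S is a dense soft cluster-assignment matrix of shape (batch_size, num_nodes, num_clusters) whose entries are strictly positive (e.g. a softmax output), so that S.log() is finite:

import torch

S = torch.softmax(torch.randn(2, 6, 3), dim=-1)  # (batch_size, num_nodes, num_clusters)
print(kl_loss(S))  # scalar: KL(P || S) with 'batchmean' reduction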
Example 18
    def kl_loss(S: Tensor) -> Tensor:
        r"""The additional KL divergence-based loss

        .. math::
            P_{i,j} &= \frac{S_{i,j}^2 / \sum_{n=1}^N S_{n,j}}{\sum_{k=1}^K
            S_{i,k}^2 / \sum_{n=1}^N S_{n,k}}

            \mathcal{L}_{\textrm{KL}} &= \textrm{KLDiv}(\mathbf{P} \Vert
            \mathbf{S})
        """
        S_2 = S**2
        P = S_2 / S.sum(dim=1, keepdim=True)
        denom = P.sum(dim=2, keepdim=True)
        denom[S.sum(dim=2, keepdim=True) == 0.0] = 1.0
        P /= denom

        loss = KLDivLoss(reduction='batchmean', log_target=False)
        return loss(S.clamp(EPS).log(), P.clamp(EPS))
Example 19
 def __init__(self,
              model,
              criterion,
              X,
              Y,
              SMLoss_mode=0,
              Triplet_margin=1.0):
     super(BaseTrainer, self).__init__()
     weight_id = torch.cat([torch.ones(751), torch.zeros(1)])
     self.model = model
     self.criterion = criterion
     self.criterion_part = nn.CrossEntropyLoss().cuda()
     self.criterion_ID = nn.CrossEntropyLoss(weight=weight_id).cuda()
     self.criterion_tri = PartialTripletLoss(margin=Triplet_margin).cuda()
     #        self.criterion_tri = TripletLoss(margin=Triplet_margin).cuda()
     self.indx = X
     self.indy = Y
     self.SML_mode = SMLoss_mode
     self.KLoss = KLDivLoss()
Example 20
def stacked_ldl_loss(model_out, epoch, consistency_rampup):

    base_vid_loss = KLDivLoss(reduction='mean').to(settings.DEVICE)

    with torch.no_grad():
        video_ages = torch.matmul(torch.exp(model_out),
                                  settings.CLASSES).view(-1)
        means = torch.tensor(
            list(map(torch.mean, video_ages.split(settings.FRAMES_PER_VID))))
        stds = torch.tensor(
            list(map(torch.std, video_ages.split(settings.FRAMES_PER_VID))))
        norms = Normal(means, stds)
        a = torch.tensor(range(0, 101)).reshape(-1, 1).to(torch.float)
        labels = torch.exp(norms.log_prob(a)).t().to(settings.DEVICE)
        target = labels.repeat(settings.FRAMES_PER_VID, 1)

    w = get_current_consistency_weight(epoch,
                                       consistency_rampup=consistency_rampup)
    return base_vid_loss(model_out, target) * w
Example 21
    def __init__(self,
                 action_space,
                 observation_space,
                 primary='q',
                 gamma=0.99,
                 adp_delta=0.01,
                 adp_bins=7,
                 mutual_steps=1000,
                 do_target_q=False,
                 q_target_lag=100,
                 model_lag=100,
                 initial_epsilon=1.0,
                 final_epsilon=0.01,
                 epsilon_decay_steps=5000):
        self._mutual_steps = mutual_steps
        self._mutual_loss_fn = KLDivLoss(reduction='sum')
        self._steps = 0
        self._adp = ADP(action_space=action_space,
                        observation_space=observation_space,
                        bins=adp_bins,
                        gamma=gamma,
                        delta=adp_delta)
        self._q = QLearner(action_space=action_space,
                           observation_space=observation_space,
                           Q='simple',
                           opt_args={'lr': 0.01},
                           gamma=gamma,
                           memory_len=1000,
                           target_lag=q_target_lag,
                           initial_epsilon=initial_epsilon,
                           final_epsilon=final_epsilon,
                           exploration_steps=epsilon_decay_steps)

        self.model_lag = model_lag

        if primary == 'q':
            self._primary = self._q
        elif primary == 'adp':
            self._primary = self._adp
        else:
            raise Exception('Invalid option')

        self.disagreement_losses = []
Example 22
    def __init__(self, softmax_dimension: int, padding_token: int,
                 smoothing_factor: float) -> None:
        super(LabelSmoothedLoss, self).__init__()
        self.softmax_dimension = softmax_dimension
        self.padding_token = padding_token
        # fraction of probability mass redistributed away from the target:
        self.smoothing_factor = smoothing_factor
        # fraction of probability mass kept on the target token:
        self.confidence = 1.0 - smoothing_factor
        # fraction of redistributed probability assigned to each one of the
        # non-target tokens in the vocabulary (padding token excluded)
        self.redistributed_probability_each = smoothing_factor /\
            (softmax_dimension - 2)

        # loss criterion - requiring inputs as log-probabilities:
        self.loss_criterion = KLDivLoss(reduction='sum', log_target=False)
        # predictions expected as log-probabilities, labels as probabilities

        # initialization of label distributions:
        self.smoothed_tgt_distributions = None
Example 23
    def train(self, train_tuple, eval_tuple):
        dset, loader, evaluator = train_tuple
        iter_wrapper = (lambda x: tqdm(x, total=len(loader))
                        ) if args.tqdm else (lambda x: x)

        best_valid = 0.
        optim_steps = 0
        for epoch in range(args.epochs):
            quesid2ans = {}
            for i, (ques_id, feats, boxes, sent, target, iou_question, iou_answer, sem_question_words, sem_answer_words, bboxes_words,)\
                 in iter_wrapper(enumerate(loader)):

                self.model.train()
                self.optim.zero_grad()

                # DEBUG: print pointer (set batch size to 1)
                # print(dset.id2datum[ques_id[0]]['sent'])
                # print(dset.id2datum[ques_id[0]]['label'])
                # q_pointer = dset.id2datum[ques_id[0]]['pointer']['question']
                # for w_index in q_pointer:
                #     print(w_index)

                feats, boxes, target = feats.cuda(), boxes.cuda(), target.cuda()
                iou_question, iou_answer = iou_question.cuda(), iou_answer.cuda()
                sem_question_words, sem_answer_words, bboxes_words = \
                    sem_question_words.cuda(), sem_answer_words.cuda(), bboxes_words.cuda()
                logit, iou_target, iou_score = self.model(
                    feats, boxes, sent, iou_question, iou_answer,
                    sem_question_words, sem_answer_words, bboxes_words)
                assert logit.dim() == target.dim() == 2
                if args.mce_loss:
                    max_value, target = target.max(1)
                    loss = self.mce_loss(logit, target) * logit.size(1)
                else:
                    loss = self.bce_loss(logit, target)
                    loss = loss * logit.size(1)
                #print('CE', loss.item())

                if args.answer_loss == 'glove':
                    gold_glove = (self.labelans2glove.unsqueeze(0) *
                                  target.unsqueeze(-1)).sum(1)
                    #gold_ans = self.train_tuple.dataset.label2ans[target.argmax(dim=1)[0]]
                    #print('gold:', gold_ans)
                    pred_glove = (
                        self.labelans2glove.unsqueeze(0) *
                        torch.softmax(logit, dim=1).unsqueeze(-1)).sum(1)
                    #pred_ans = self.train_tuple.dataset.label2ans[logit.argmax(dim=1)[0]]
                    #print('pred:', pred_ans)
                    sim_answer = self.cosineSim(gold_glove, pred_glove).mean()
                    loss += -10 * sim_answer
                    #print('Similarity', sim_answer)
                    #input(' ')

                if optim_steps % 1000 == 0:
                    self.writerTbrd.add_scalar('vqa_loss_train', loss.item(),
                                               optim_steps)

                # task_pointer = 'KLDiv'
                ALPHA = args.alpha_pointer

                def iou_preprocess(iou, obj_conf=None):
                    TRESHOLD = 0.1
                    TOPK = 3
                    # norm_iou = np.exp(iou) / np.sum(np.exp(iou), axis=0)  #iou / (iou.sum() + 1e-9)
                    # f_iou = norm_iou * (iou.sum() >= TRESHOLD)
                    sorted_values = torch.sort(iou, descending=True, dim=-1)[0]
                    t_top = sorted_values[:, :, TOPK - 1]
                    iou_topk = iou.masked_fill(iou < t_top.unsqueeze(-1), -1e9)
                    f_iou = torch.softmax(iou_topk, dim=-1)
                    treshold_mask = (iou_topk.clamp(min=.0).sum(-1) >=
                                     TRESHOLD).float()
                    if args.task_pointer == 'KLDiv':
                        return f_iou, treshold_mask
                    elif args.task_pointer == 'Triplet':
                        # Remove top10 most similar objects
                        t_bot = sorted_values[:, :, 10]
                        iou_botk = (iou < t_bot.unsqueeze(-1)).float()
                        # Take topk most confident objects
                        conf_top = torch.sort(obj_conf.unsqueeze(1) * iou_botk,
                                              descending=True,
                                              dim=-1)[0][:, :, TOPK - 1]
                        conf_mask = obj_conf.unsqueeze(1).expand(
                            -1, iou.size(1), -1) >= conf_top.unsqueeze(-1)
                        neg_score = iou_botk * conf_mask.float()
                        return f_iou, treshold_mask, neg_score

                if args.task_pointer == 'KLDiv':
                    iou_target_preprocess, treshold_mask = iou_preprocess(
                        iou_target)
                    loss_pointer_fct = KLDivLoss(reduction='none')
                    iou_pred = torch.log_softmax(iou_score, dim=-1)
                    matching_loss = loss_pointer_fct(
                        input=iou_pred, target=iou_target_preprocess)
                    matching_loss = ALPHA * (matching_loss.sum(-1) *
                                             treshold_mask).sum() / (
                                                 (treshold_mask).sum() + 1e-9)
                    if optim_steps % 1000 == 0:
                        self.writerTbrd.add_scalar('pointer_loss_train',
                                                   matching_loss.item(),
                                                   optim_steps)
                    loss += matching_loss

                # ? by Corentin: Matching loss
                # def iou_preprocess(iou):
                #     TRESHOLD = 0.1
                #     TOPK = 1
                #     # norm_iou = np.exp(iou) / np.sum(np.exp(iou), axis=0)  #iou / (iou.sum() + 1e-9)
                #     # f_iou = norm_iou * (iou.sum() >= TRESHOLD)
                #     t = torch.sort(iou, descending=True, dim=-1)[0][:, :, TOPK-1]
                #     iou_topk = iou.masked_fill(iou < t.unsqueeze(-1), -1e9)
                #     f_iou = torch.softmax(iou_topk, dim=-1)
                #     treshold_mask = (iou_topk.clamp(min=.0).sum(-1) >= TRESHOLD).float()
                #     return f_iou, treshold_mask
                # # discard iou_target when total iou is under treshold
                # # it includes unsupervised datum
                # iou_target_preprocess, treshold_mask = iou_preprocess(iou_target)
                # iou_pred = torch.log_softmax(iou_pred, dim=-1)
                # # KL loss
                # matching_loss = []
                # matching_loss = self.KL_loss(input=iou_pred, target=iou_target_preprocess)
                # matching_loss = (matching_loss.sum(-1) * treshold_mask).sum() / treshold_mask.sum()
                # if optim_steps % 1000 == 0:
                #     self.writerTbrd.add_scalar('pointer_loss_train', matching_loss.item(), optim_steps)
                # ALPHA = 5.0
                # loss += ALPHA * matching_loss
                # ? **************************

                loss.backward()
                nn.utils.clip_grad_norm_(self.model.parameters(), 5.)
                self.optim.step()
                optim_steps += 1

                score, label = logit.max(1)
                for qid, l in zip(ques_id, label.cpu().numpy()):
                    ans = dset.label2ans[l]
                    quesid2ans[qid] = ans

                # if self.valid_tuple is not None and optim_steps % 1152 == 0:  # Do Validation
                #     valid_score = self.evaluate(eval_tuple)
                #     fastepoch = int(optim_steps / 1152)
                #     print("fastEpoch %d: Valid %0.2f\n" % (fastepoch, valid_score * 100.,))

            log_str = "\nEpoch %d: Train %0.2f\n" % (
                epoch, evaluator.evaluate(quesid2ans) * 100.)

            if self.valid_tuple is not None:  # Do Validation
                valid_score = self.evaluate(eval_tuple)
                self.writerTbrd.add_scalar('vqa_acc_valid', valid_score, epoch)
                if valid_score > best_valid:
                    best_valid = valid_score
                    self.save("BEST")

                log_str += "Epoch %d: Valid %0.2f\n" % (epoch, valid_score * 100.) + \
                           "Epoch %d: Best %0.2f\n" % (epoch, best_valid * 100.)

            print(log_str, end='')

            with open(self.output + "/log.log", 'a') as f:
                f.write(log_str)
                f.flush()

        self.save("LAST")
Example 24
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        label_mask=None,
    ):

        outputs = self.roberta(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        final_embedding = outputs[0]
        sequence_output = self.dropout(final_embedding)
        logits = self.classifier(sequence_output)

        outputs = (
            logits,
            final_embedding,
        ) + outputs[2:]  # add hidden states and attention if they are here

        loss_dict = {}
        if labels is not None:
            # logits = self.logsoftmax(logits)
            # Only keep active parts of the loss
            active_loss = True
            if attention_mask is not None:
                active_loss = attention_mask.view(-1) == 1
                # active_loss = True
                # if attention_mask is not None:
                #     active_loss = attention_mask.view(-1) == 1
                # if label_mask is not None:
                #     active_loss = active_loss & label_mask.view(-1)
                # active_logits = logits.view(-1, self.num_labels)[active_loss]

            for key in labels:
                label = labels[key]
                if label is None:
                    continue
                # if key=="pseudo" and label_mask is not None:
                if label_mask is not None:
                    all_active_loss = active_loss & label_mask.view(-1)
                else:
                    all_active_loss = active_loss
                active_logits = logits.view(-1,
                                            self.num_labels)[all_active_loss]

                if label.shape == logits.shape:
                    loss_fct = KLDivLoss()
                    # loss_fct = SoftFocalLoss(gamma=2)
                    if attention_mask is not None or label_mask is not None:
                        active_labels = label.view(
                            -1, self.num_labels)[all_active_loss]
                        loss = loss_fct(active_logits, active_labels)
                    else:
                        loss = loss_fct(logits, label)
                else:
                    loss_fct = CrossEntropyLoss()
                    # loss_fct = FocalLoss(gamma=2)
                    # loss_fct = NLLLoss()
                    if attention_mask is not None or label_mask is not None:
                        active_labels = label.view(-1)[all_active_loss]
                        loss = loss_fct(active_logits, active_labels)
                    else:
                        loss = loss_fct(logits.view(-1, self.num_labels),
                                        label.view(-1))
                loss_dict[key] = loss

            outputs = (loss_dict, ) + outputs

        return outputs  # (loss dict), scores, final_embedding, (hidden_states), (attentions)
Example 25
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        sent_bounds=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
            Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
            1]``.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]

        sequence_output = self.dropout(sequence_output)

        device = sequence_output.device
        pos_matrix = torch.arange(sequence_output.size()[1], device=device).view(1, 1, -1)
        if_in_sent = torch.logical_and(sent_bounds[:, :, 1].unsqueeze(-1) <= pos_matrix,
                                           pos_matrix <= sent_bounds[:, :, 2].unsqueeze(-1))

        if self.pooling_type == 'average':
            pooling_matrix = torch.where(if_in_sent, torch.tensor((1), device=device), torch.tensor((0), device=device)).float()
            sent_len = torch.sum(pooling_matrix, 2).unsqueeze(2)
            sent_len[sent_len==0] = 1
            pooling_matrix = pooling_matrix / sent_len
            sentence_hiddens = torch.bmm(sequence_output.transpose(-1, -2), pooling_matrix.transpose(-1, -2)).transpose(-1, -2)
        elif self.pooling_type == 'max':
            pooling_matrix = torch.where(if_in_sent.unsqueeze(-1),  sequence_output.unsqueeze(1), torch.tensor((0.0), device=device)).float()
            sentence_hiddens = torch.max(pooling_matrix, dim=2)[0]
        logits = self.output_layer(sentence_hiddens).squeeze(-1)

        mask = torch.where(sent_bounds[:, :, 0] >= 0, torch.tensor(0.0, device=device), torch.tensor((-10000.0), device=device))
        logits += mask

        loss = None
        if labels is not None:
            loss_fct = KLDivLoss()
            # Only keep active parts of the loss
            loss = loss_fct(F.log_softmax(logits, dim=-1), F.softmax(labels, dim=-1))

        if not return_dict:
            output = (logits,) + outputs[2:]
            return ((loss,) + output) if loss is not None else output

        return TokenClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
Example 26
def train_student(model,
                  teacher_model,
                  tokenizer,
                  params,
                  train_examples,
                  valid_examples=None,
                  name=None,
                  checkpoint_files={
                      'config': 'bert_config.json',
                      'model_weigths': 'model_trained.pth'
                  },
                  temperature=3,
                  alpha=0.9,
                  all_logits_teacher=None):

    if name is not None:
        checkpoint_config = checkpoint_files[
            'config'][:-5] + '_' + name + '.json'
        checkpoint_model_weigths = checkpoint_files[
            'model_weigths'][:-4] + '_' + name + '.pth'
    else:
        checkpoint_config = checkpoint_files['config'][:-5] + '_student.json'
        checkpoint_model_weigths = checkpoint_files[
            'model_weigths'][:-4] + '_student.pth'

    random.seed(params['seed'])
    np.random.seed(params['seed'])
    torch.manual_seed(params['seed'])

    train_steps_per_epoch = int(
        len(train_examples) / params['train_batch_size'])
    num_train_optimization_steps = train_steps_per_epoch * params[
        'num_train_epochs']

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=params['learning_rate'],
                         warmup=params['warmup_proportion'],
                         t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0

    train_features = feature_processors.convert_examples_to_features(
        train_examples, params['label_list'], params['max_seq_length'],
        tokenizer)

    print("***** Running training *****")
    print("Num examples:", len(train_examples))
    print("Batch size:  ", params['train_batch_size'])
    print("Num steps:   ", num_train_optimization_steps)
    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.long)

    if all_logits_teacher is None:
        eval_teacher_dataset = TensorDataset(all_input_ids, all_input_mask,
                                             all_segment_ids, all_label_ids)

        all_logits_teacher = eval_teacher_soft_targets(teacher_model,
                                                       eval_teacher_dataset,
                                                       params['label_list'], params)

    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_label_ids, all_logits_teacher)

    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=params['train_batch_size'])

    model.train()
    for epoch_num in range(int(params['num_train_epochs'])):
        print('Epoch: {}'.format(epoch_num + 1))
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        for step, batch in enumerate(
                tqdm_notebook(train_dataloader, desc="Iteration")):
            batch = tuple(t.to(params['device']) for t in batch)
            input_ids, input_mask, segment_ids, label_ids, teacher_logits = batch

            logits_model = model(input_ids, segment_ids, input_mask)

            loss_first = KLDivLoss()(F.log_softmax(logits_model / temperature, dim=-1),
                                     F.softmax(teacher_logits / temperature, dim=-1))
            loss_second = CrossEntropyLoss()(logits_model.view(
                -1, model.num_labels), label_ids.view(-1))
            loss = loss_first * (temperature**
                                 2) * alpha + (1. - alpha) * loss_second
            #             loss = loss_first * alpha + (1. - alpha) * loss_second
            loss.backward()

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            optimizer.step()
            optimizer.zero_grad()
            global_step += 1

        train_result = {
            'train_loss': tr_loss / nb_tr_steps,
            'train_global_step': global_step,
        }
        print(train_result)
        if valid_examples is not None:
            valid_result, valid_prob_preds = evaluate(model, tokenizer, params,
                                                      valid_examples)
            print(valid_result)
            model.train()


#     Save a trained model and the associated configuration
    if not os.path.exists(params['output_dir']):
        os.makedirs(params['output_dir'])

    model_to_save = model.module if hasattr(model, 'module') else model
    output_model_file = os.path.join(params['output_dir'],
                                     checkpoint_model_weigths)
    torch.save(model_to_save.state_dict(), output_model_file)
    #     output_config_file = os.path.join(params['output_dir'], checkpoint_config) #### another file
    #     with open(output_config_file, 'w') as f:
    #         f.write(model_to_save.config.to_json_string())

    #     # Load a trained model and config that you have fine-tuned
    #     config = BertConfig(output_config_file)
    #     model = BertForSequenceClassification(config, num_labels=model.num_labels)
    #     model.load_state_dict(torch.load(output_model_file))
    #     model.to(device)

    result = {
        'train_loss': tr_loss / nb_tr_steps,
        'train_global_step': global_step
    }
    if valid_examples is not None:
        result['eval_loss'] = valid_result['eval_loss']
        result['eval_accuracy'] = valid_result['eval_accuracy']
    return model, result
Example 27
    def forward(self,
                input_ids,
                token_type_ids=None,
                attention_mask=None,
                labels=None,
                label_ids=None,
                position_ids=None,
                head_mask=None):
        """
        Forward pass
        :return: output[0] is loss value,
                 output[1] is log_softmax value,
                 the rest are hidden states and attentions
        """
        # Get BERT output
        outputs = self.bert(input_ids,
                            position_ids=position_ids,
                            token_type_ids=token_type_ids,
                            attention_mask=attention_mask,
                            head_mask=head_mask)
        # Get the value of [CLS]
        x_pool = outputs[1]

        # Get the last hidden-state from BERT output
        # Remove first token [CLS]
        x_sequence = outputs[0]
        x_sequence = x_sequence[:, 1:]
        # Pass the sequence through LSTM layers
        x_sequence = x_sequence.reshape(self.max_seq_length - 1,
                                        self.batch_size, -1)
        lstm_out, self.lstm_hidden = self.lstm(x_sequence, self.lstm_hidden)

        # Concat CLS with LSTM output (last hidden values)
        x = torch.cat((x_pool, lstm_out[-1]), dim=1)

        # Pass the concatenated features through the dense classification head
        y = self.dropout(x)
        y = self.hidden2dense(x)
        y = self.relu(y)
        y = self.dropout(y)

        logits = self.classifier(y)
        log_softmax = F.log_softmax(logits, dim=-1)

        # Add log_softmax value to outputs
        outputs = (log_softmax, ) + outputs[2:]

        # Calculate loss
        if labels is not None:
            if self.num_labels == 1:
                # Loss for regression problem (not use in Offensive task)
                loss_fct = MSELoss()
                loss = loss_fct(log_softmax.view(-1), labels.view(-1))
            else:
                # Loss is the combination of loss on both soft and hard labels
                loss_fct_soft = KLDivLoss()
                loss_fct_hard = CrossEntropyLoss()
                loss = (1 - self.soft_label_ratio) * loss_fct_hard(logits.view(-1, self.num_labels), label_ids.view(-1)) \
                       + self.soft_label_ratio * loss_fct_soft(log_softmax[:, 1].view(-1), labels.view(-1))
        else:
            # For inference phase
            loss = 0

        # Add loss to outputs
        outputs = (loss, ) + outputs
        return outputs
Example 28
    def train(self, model_path: Optional[str] = None, trial: Union["optuna.Trial", Dict[str, Any]] = None):
        """
        Main training entry point.
        Args:
            model_path (:obj:`str`, `optional`):
                Local path to the model if the model to train has been instantiated from a local path. If present,
                training will resume from the optimizer/scheduler states loaded here.
            trial (:obj:`optuna.Trial` or :obj:`Dict[str, Any]`, `optional`):
                The trial run or the hyperparameter dictionary for hyperparameter search.
        """
        # This might change the seed so needs to run first.
        self._hp_search_setup(trial)

        # Model re-init
        if self.model_init is not None:
            # Seed must be set before instantiating the model when using model_init.
            set_seed(self.args.seed)
            model = self.model_init()
            self.model = model.to(self.args.device)

            # Reinitializes optimizer and scheduler
            self.optimizer, self.lr_scheduler = None, None

        # Data loader and number of training steps
        train_dataloader = self.get_train_dataloader()
        num_update_steps_per_epoch = len(train_dataloader) // self.args.gradient_accumulation_steps
        num_update_steps_per_epoch = max(num_update_steps_per_epoch, 1)
        if self.args.max_steps > 0:
            t_total = self.args.max_steps
            num_train_epochs = self.args.max_steps // num_update_steps_per_epoch + int(
                self.args.max_steps % num_update_steps_per_epoch > 0
            )
        else:
            t_total = int(num_update_steps_per_epoch * self.args.num_train_epochs)
            num_train_epochs = self.args.num_train_epochs
            self.args.max_steps = t_total

        self.create_optimizer_and_scheduler(num_training_steps=t_total)

        # Check if saved optimizer or scheduler states exist
        if (
            model_path is not None
            and os.path.isfile(os.path.join(model_path, "optimizer.pt"))
            and os.path.isfile(os.path.join(model_path, "scheduler.pt"))
        ):
            # Load in optimizer and scheduler states
            self.optimizer.load_state_dict(
                torch.load(os.path.join(model_path, "optimizer.pt"), map_location=self.args.device)
            )
            self.lr_scheduler.load_state_dict(torch.load(os.path.join(model_path, "scheduler.pt")))

        model = self.model
        if self.args.fp16 and _use_apex:
            if not is_apex_available():
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
            model, self.optimizer = amp.initialize(model, self.optimizer, opt_level=self.args.fp16_opt_level)

        # multi-gpu training (should be after apex fp16 initialization)
        if self.args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        # Distributed training (should be after apex fp16 initialization)
        if self.args.local_rank != -1:
            model = torch.nn.parallel.DistributedDataParallel(
                model,
                device_ids=[self.args.local_rank],
                output_device=self.args.local_rank,
                find_unused_parameters=True,
            )

        if self.tb_writer is not None:
            self.tb_writer.add_text("args", self.args.to_json_string())
            self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})

        # Train!
        if is_torch_tpu_available():
            total_train_batch_size = self.args.train_batch_size * xm.xrt_world_size()
        else:
            total_train_batch_size = (
                self.args.train_batch_size
                * self.args.gradient_accumulation_steps
                * (torch.distributed.get_world_size() if self.args.local_rank != -1 else 1)
            )
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", self.num_examples(train_dataloader))
        logger.info("  Num Epochs = %d", num_train_epochs)
        logger.info("  Instantaneous batch size per device = %d", self.args.per_device_train_batch_size)
        logger.info("  Total train batch size (w. parallel, distributed & accumulation) = %d", total_train_batch_size)
        logger.info("  Gradient Accumulation steps = %d", self.args.gradient_accumulation_steps)
        logger.info("  Total optimization steps = %d", t_total)

        self.global_step = 0
        self.epoch = 0
        epochs_trained = 0
        steps_trained_in_current_epoch = 0
        # Check if continuing training from a checkpoint
        if model_path is not None:
            # set global_step to global_step of last saved checkpoint from model path
            try:
                self.global_step = int(model_path.split("-")[-1].split(os.path.sep)[0])

                epochs_trained = self.global_step // num_update_steps_per_epoch
                steps_trained_in_current_epoch = self.global_step % (num_update_steps_per_epoch)

                logger.info("  Continuing training from checkpoint, will skip to saved global_step")
                logger.info("  Continuing training from epoch %d", epochs_trained)
                logger.info("  Continuing training from global step %d", self.global_step)
                logger.info("  Will skip the first %d steps in the first epoch", steps_trained_in_current_epoch)
            except ValueError:
                self.global_step = 0
                logger.info("  Starting fine-tuning.")

        tr_loss_sum = 0.0
        loss_sum = defaultdict(float)
        best = {self.best_metric: None}
        model.zero_grad()
        disable_tqdm = self.args.disable_tqdm or not self.is_local_process_zero()
        train_pbar = trange(epochs_trained, int(np.ceil(num_train_epochs)), desc="Epoch", disable=disable_tqdm)
        for epoch in range(epochs_trained, int(np.ceil(num_train_epochs))):
            if isinstance(train_dataloader, DataLoader) and isinstance(train_dataloader.sampler, DistributedSampler):
                train_dataloader.sampler.set_epoch(epoch)

            if is_torch_tpu_available():
                parallel_loader = pl.ParallelLoader(train_dataloader, [self.args.device]).per_device_loader(
                    self.args.device
                )
                epoch_iterator = parallel_loader
            else:
                epoch_iterator = train_dataloader

            # Reset the past mems state at the beginning of each epoch if necessary.
            if self.args.past_index >= 0:
                self._past = None

            epoch_pbar = tqdm(epoch_iterator, desc="Iteration", disable=disable_tqdm)
            for step, inputs in enumerate(epoch_iterator):

                # Skip past any already trained steps if resuming training
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    epoch_pbar.update(1)
                    continue

                model.train()
                inputs = self._prepare_inputs(inputs)

                inputs["output_attentions"] = self.length_drop_args.length_config is not None

                layer_config = sample_layer_configuration(
                    model.config.num_hidden_layers,
                    layer_dropout_prob=self.length_drop_args.layer_dropout_prob,
                    layer_dropout=0,
                )
                inputs["layer_config"] = layer_config

                inputs["length_config"] = self.length_drop_args.length_config

                outputs = model(**inputs)
                # Save past state if it exists
                if self.args.past_index >= 0:
                    self._past = outputs[self.args.past_index]
                task_loss = self.div_loss(outputs[0])
                if self.length_drop_args.length_adaptive:
                    loss_sum["full"] += task_loss.item()
                loss = task_loss
                if self.length_drop_args.length_adaptive:
                    loss = loss / (self.length_drop_args.num_sandwich + 2)

                tr_loss_sum += loss.item()
                if self.args.fp16 and _use_native_amp:
                    self.scaler.scale(loss).backward()
                elif self.args.fp16 and _use_apex:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                else:
                    loss.backward()

                # inplace distillation
                if self.length_drop_args.length_adaptive:
                    logits = outputs[1].detach()

                    for i in range(self.length_drop_args.num_sandwich + 1):
                        inputs["output_attentions"] = True

                        layer_config = sample_layer_configuration(
                            model.config.num_hidden_layers,
                            layer_dropout_prob=self.length_drop_args.layer_dropout_prob,
                            layer_dropout=(self.length_drop_args.layer_dropout_bound if i == 0 else None),
                            layer_dropout_bound=self.length_drop_args.layer_dropout_bound,
                        )
                        inputs["layer_config"] = layer_config

                        length_config = sample_length_configuration(
                            self.args.max_seq_length,
                            model.config.num_hidden_layers,
                            layer_config,
                            length_drop_ratio=(self.length_drop_args.length_drop_ratio_bound if i == 0 else None),
                            length_drop_ratio_bound=self.length_drop_args.length_drop_ratio_bound,
                        )
                        inputs["length_config"] = length_config

                        outputs_sub = model(**inputs)
                        task_loss_sub = self.div_loss(outputs_sub[0])
                        if i == 0:
                            loss_sum["smallest"] += task_loss_sub.item()
                            loss_sum["sub"] += 0
                        else:
                            loss_sum["sub"] += task_loss_sub.item() / self.length_drop_args.num_sandwich

                        logits_sub = outputs_sub[1]
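                        # KL divergence between the full model's logits (detached above, so it acts
                        # as a fixed teacher) and the sub-network's logits; only the sub-network
                        # receives gradients from this term.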
                        loss_fct = KLDivLoss(reduction="batchmean")
                        kl_loss = loss_fct(F.log_softmax(logits, -1), F.softmax(logits_sub, -1))
                        loss = self.div_loss(kl_loss)
                        loss_sum["kl"] += loss.item() / (self.length_drop_args.num_sandwich + 1)
                        loss = loss / (self.length_drop_args.num_sandwich + 2)

                        tr_loss_sum += loss.item()
                        if self.args.fp16 and _use_native_amp:
                            self.scaler.scale(loss).backward()
                        elif self.args.fp16 and _use_apex:
                            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                                scaled_loss.backward()
                        else:
                            loss.backward()

                if (step + 1) % self.args.gradient_accumulation_steps == 0 or (
                    # last step in an epoch that is shorter than gradient_accumulation_steps
                    len(epoch_iterator) <= self.args.gradient_accumulation_steps
                    and (step + 1) == len(epoch_iterator)
                ):
                    if self.args.fp16 and _use_native_amp:
                        self.scaler.unscale_(self.optimizer)
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)
                    elif self.args.fp16 and _use_apex:
                        torch.nn.utils.clip_grad_norm_(amp.master_params(self.optimizer), self.args.max_grad_norm)
                    else:
                        torch.nn.utils.clip_grad_norm_(model.parameters(), self.args.max_grad_norm)

                    if is_torch_tpu_available():
                        xm.optimizer_step(self.optimizer)
                    elif self.args.fp16 and _use_native_amp:
                        self.scaler.step(self.optimizer)
                        self.scaler.update()
                    else:
                        self.optimizer.step()

                    self.lr_scheduler.step()
                    model.zero_grad()
                    self.global_step += 1
                    self.epoch = epoch + (step + 1) / len(epoch_iterator)

                    if (self.args.logging_steps > 0 and self.global_step % self.args.logging_steps == 0) or (
                        self.global_step == 1 and self.args.logging_first_step
                    ):
                        # backward compatibility for pytorch schedulers
                        lr = (
                            self.lr_scheduler.get_last_lr()[0]
                            if version.parse(torch.__version__) >= version.parse("1.4")
                            else self.lr_scheduler.get_lr()[0]
                        )
                        loss = tr_loss_sum / self.args.logging_steps
                        tr_loss_sum = 0.0
                        logs = {"lr": lr, "loss": loss}
                        log_str = f"[{self.global_step:5d}] lr {lr:g} | loss {loss:2.3f}"

                        for key, value in loss_sum.items():
                            value /= self.args.logging_steps
                            loss_sum[key] = 0.0
                            logs[f"{key}_loss"] = value
                            log_str += f" | {key}_loss {value:2.3f}"

                        self.log(logs, "train")
                        logger.info(log_str)

                    '''
                    if (
                        self.args.evaluation_strategy == EvaluationStrategy.STEPS
                        and self.global_step % self.args.eval_steps == 0
                    ):
                        results = self.evaluate()
                        self._report_to_hp_search(trial, epoch, results)
                    '''

                    if self.args.save_steps > 0 and self.global_step % self.args.save_steps == 0:
                        # In all cases (even distributed/parallel), self.model is always a reference
                        # to the model we want to save.
                        if hasattr(model, "module"):
                            assert (
                                model.module is self.model
                            ), f"Module {model.module} should be a reference to self.model"
                        else:
                            assert model is self.model, f"Model {model} should be a reference to self.model"

                        if self.args.evaluate_during_training:
                            results = self.evaluate()
                            results = {k[5:]: v for k, v in results.items() if k.startswith("eval_")}
                            self.log(results, "dev")
                            msg = " | ".join([f"{k} {v:.3f}" for k, v in results.items()])
                            logger.info(f"  [{self.global_step:5d}] {msg}")

                        # Save model checkpoint
                        if self.args.save_only_best:
                            output_dirs = []
                        else:
                            checkpoint_folder = f"{PREFIX_CHECKPOINT_DIR}-{self.global_step}"
                            if self.hp_search_backend is not None and trial is not None:
                                run_id = (
                                    trial.number
                                    if self.hp_search_backend == HPSearchBackend.OPTUNA
                                    else tune.get_trial_id()
                                )
                                checkpoint_folder += f"-run-{run_id}"
                            output_dirs = [os.path.join(self.args.output_dir, checkpoint_folder)]
                            
                        if self.args.evaluate_during_training:
                            if best[self.best_metric] is None or results[self.best_metric] > best[self.best_metric]:
                                logger.info("Congratulations, best model so far!")
                                output_dirs.append(os.path.join(self.args.output_dir, "checkpoint-best"))
                                best = results

                        for output_dir in output_dirs:
                            self.save_model(output_dir)

                            if self.is_world_master() and self.tokenizer is not None:
                                self.tokenizer.save_pretrained(output_dir)

                            if self.is_world_process_zero():
                                self._rotate_checkpoints(use_mtime=True)

                            '''
                            if is_torch_tpu_available():
                                xm.rendezvous("saving_optimizer_states")
                                xm.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                xm.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            elif self.is_world_process_zero():
                                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                                torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))
                            '''

                epoch_pbar.update(1)
                if 0 < self.args.max_steps <= self.global_step:
                    break
            epoch_pbar.close()
            train_pbar.update(1)

            '''
            if self.args.evaluation_strategy == EvaluationStrategy.EPOCH:
                results = self.evaluate()
                self._report_to_hp_search(trial, epoch, results)
            '''

            if self.args.tpu_metrics_debug or self.args.debug:
                if is_torch_tpu_available():
                    # tpu-comment: Logging debug metrics for PyTorch/XLA (compile, execute times, ops, etc.)
                    xm.master_print(met.metrics_report())
                else:
                    logger.warning(
                        "You enabled PyTorch/XLA debug metrics but you don't have a TPU "
                        "configured. Check your training configuration if this is unexpected."
                    )
            if 0 < self.args.max_steps <= self.global_step:
                break

        train_pbar.close()
        if self.tb_writer:
            self.tb_writer.close()
        if self.args.past_index and hasattr(self, "_past"):
            # Clean the state at the end of training
            delattr(self, "_past")

        logger.info("\n\nTraining completed. Do not forget to share your model on huggingface.co/models =)\n\n")
        return self.global_step, best
Example n. 29
0
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        soft_labels=None,
    ):
        r"""
        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`, defaults to :obj:`None`):
            Labels for computing the sequence classification/regression loss.
            Indices should be in :obj:`[0, ..., config.num_labels - 1]`.
            If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).

    Returns:
        :obj:`tuple(torch.FloatTensor)` comprising various elements depending on the configuration (:class:`~transformers.RobertaConfig`) and inputs:
        loss (:obj:`torch.FloatTensor` of shape :obj:`(1,)`, `optional`, returned when :obj:`labels` is provided):
            Classification (or regression if config.num_labels==1) loss.
        logits (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, config.num_labels)`):
            Classification (or regression if config.num_labels==1) scores (before SoftMax).
        hidden_states (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_hidden_states=True``):
            Tuple of :obj:`torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer)
            of shape :obj:`(batch_size, sequence_length, hidden_size)`.

            Hidden-states of the model at the output of each layer plus the initial embedding outputs.
        attentions (:obj:`tuple(torch.FloatTensor)`, `optional`, returned when ``config.output_attentions=True``):
            Tuple of :obj:`torch.FloatTensor` (one for each layer) of shape
            :obj:`(batch_size, num_heads, sequence_length, sequence_length)`.

            Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
            heads.

    Examples::

        from transformers import RobertaTokenizer, RobertaForSequenceClassification
        import torch

        tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
        model = RobertaForSequenceClassification.from_pretrained('roberta-base')
        input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)  # Batch size 1
        labels = torch.tensor([1]).unsqueeze(0)  # Batch size 1
        outputs = model(input_ids, labels=labels)
        loss, logits = outputs[:2]
        """
        outputs = self.filter(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )
        logits = self.output(outputs)

        outputs = (logits, )

        if labels is not None:
            if self.num_labels == 1:
                #  We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
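                # The head appears to emit a pair of logits per example: index 0 for the
                # target-language input and index 1 for its English counterpart.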
                other_lang_logits = logits[:, 0, :].view(-1, self.num_labels)
                eng_logits = logits[:, 1, :].view(-1, self.num_labels)

                loss_other = loss_fct(other_lang_logits, labels.view(-1))
                loss_eng = loss_fct(eng_logits, labels.view(-1))

                if self.config.alpha > 0:
                    alpha = self.config.alpha
                    T = self.config.temperature
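                    # Knowledge-distillation term: KL between the temperature-softened
                    # target-language predictions and the provided soft labels, scaled by T^2
                    # to keep gradient magnitudes comparable across temperatures.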
                    loss_KD = KLDivLoss()(
                        F.log_softmax(other_lang_logits / T, dim=1),
                        F.softmax(soft_labels / T, dim=1)) * (T * T)
                    loss = (1. - alpha) * loss_other + alpha * loss_KD
                else:
                    loss = loss_other

                if not self.config.first_loss_only:
                    loss += loss_eng
            outputs = (loss, ) + outputs

        return outputs  # (loss), logits, (hidden_states), (attentions)
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
    num_warmup_steps = int(args.warmup_ratio * t_total)

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ["bias", "LayerNorm.weight"]
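    # Split parameters into three groups: weight-decayed weights, bias/LayerNorm parameters
    # without weight decay, and (if present) "retention" parameters, which get their own
    # learning rate (args.lr_soft_extract) and no weight decay.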
    retention_params = []
    wd_params = []
    no_wd_params = []
    for n, p in model.named_parameters():
        if "retention" in n:
            retention_params.append(p)
        elif any(nd in n for nd in no_decay):
            no_wd_params.append(p)
        else:
            wd_params.append(p)
    optimizer_grouped_parameters = [{
        "params": wd_params,
        "weight_decay": args.weight_decay,
        "lr": args.learning_rate
    }, {
        "params": no_wd_params,
        "weight_decay": 0.0,
        "lr": args.learning_rate
    }]
    if len(retention_params) > 0:
        optimizer_grouped_parameters.append({
            "params": retention_params,
            "weight_decay": 0.0,
            "lr": args.lr_soft_extract
        })

    optimizer = AdamW(optimizer_grouped_parameters, eps=args.adam_epsilon)
    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=t_total)

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(args.model_name_or_path, "scheduler.pt")):
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )

        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        try:
            # set global_step to the global_step of the last saved checkpoint from the model path
            checkpoint_suffix = args.model_name_or_path.split("-")[-1].split(
                "/")[0]
            global_step = int(checkpoint_suffix)
            epochs_trained = global_step // (len(train_dataloader) //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                len(train_dataloader) // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Starting fine-tuning.")

    tr_loss_sum = 0.0
    loss_sum = defaultdict(float)
    best = {'f1': None}
    model.zero_grad()
    # Added here for reproducibility
    set_seed(args)

    for epoch in range(epochs_trained, int(args.num_train_epochs)):
        for step, batch in enumerate(train_dataloader):
            # Skip past any already trained steps if resuming training
            if steps_trained_in_current_epoch > 0:
                steps_trained_in_current_epoch -= 1
                continue

            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                "input_ids": batch[0],
                "attention_mask": batch[1],
                "token_type_ids": batch[2],
                "start_positions": batch[3],
                "end_positions": batch[4],
            }

            if args.model_type in [
                    "xlm", "roberta", "distilbert", "camembert"
            ]:
                del inputs["token_type_ids"]

            if args.model_type in ["xlnet", "xlm"]:
                inputs.update({"cls_index": batch[5], "p_mask": batch[6]})
                if args.version_2_with_negative:
                    inputs.update({"is_impossible": batch[7]})
                if hasattr(model, "config") and hasattr(
                        model.config, "lang2id"):
                    inputs.update({
                        "langs":
                        (torch.ones(batch[0].shape, dtype=torch.int64) *
                         args.lang_id).to(args.device)
                    })

            inputs["output_attentions"] = args.length_config is not None
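            # Attention maps are requested only when a length configuration will drop tokens;
            # they are presumably used to score token importance for the length drop.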

            layer_config = sample_layer_configuration(
                model.config.num_hidden_layers,
                layer_dropout_prob=args.layer_dropout_prob,
                layer_dropout=0,
            )
            inputs["layer_config"] = layer_config

            inputs["length_config"] = args.length_config

            outputs = model(**inputs)
            # model outputs are always tuple in transformers (see doc)
            task_loss = div_loss(outputs[0], args)
            loss_sum["full"] += task_loss.item()
            loss = task_loss
            if args.length_adaptive:
                loss = loss / (args.num_sandwich + 2)

            tr_loss_sum += loss.item()
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            # inplace distillation
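            # Sandwich rule: train the smallest sub-network (i == 0) and num_sandwich random
            # sub-networks, each distilled from the full model's detached start/end logits.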
            if args.length_adaptive:
                start_logits = outputs[1].detach()
                end_logits = outputs[2].detach()

                for i in range(args.num_sandwich + 1):
                    inputs["output_attentions"] = True

                    layer_config = sample_layer_configuration(
                        model.config.num_hidden_layers,
                        layer_dropout_prob=args.layer_dropout_prob,
                        layer_dropout=(args.layer_dropout_bound
                                       if i == 0 else None),
                        layer_dropout_bound=args.layer_dropout_bound,
                    )
                    inputs["layer_config"] = layer_config

                    length_config = sample_length_configuration(
                        args.max_seq_length,
                        model.config.num_hidden_layers,
                        layer_config,
                        length_drop_ratio=(args.length_drop_ratio_bound
                                           if i == 0 else None),
                        length_drop_ratio_bound=args.length_drop_ratio_bound,
                    )
                    inputs["length_config"] = length_config

                    outputs_sub = model(**inputs)
                    task_loss_sub = div_loss(outputs_sub[0], args)
                    if i == 0:
                        loss_sum["smallest"] += task_loss_sub.item()
                        loss_sum["sub"] += 0
                    else:
                        loss_sum["sub"] += task_loss_sub.item(
                        ) / args.num_sandwich

                    start_logits_sub = outputs_sub[1]
                    end_logits_sub = outputs_sub[2]
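                    # Distill the sub-network's span predictions toward the full model's detached
                    # start/end logits; the two KL terms are averaged before applying div_loss.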
                    loss_fct = KLDivLoss(reduction="batchmean")
                    start_kl_loss = loss_fct(F.log_softmax(start_logits, -1),
                                             F.softmax(start_logits_sub, -1))
                    end_kl_loss = loss_fct(F.log_softmax(end_logits, -1),
                                           F.softmax(end_logits_sub, -1))
                    loss = div_loss((start_kl_loss + end_kl_loss) / 2, args)
                    loss_sum["kl"] += loss.item() / (args.num_sandwich + 1)
                    loss = loss / (args.num_sandwich + 2)

                    tr_loss_sum += loss.item()
                    if args.fp16:
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

            if (step + 1) % args.gradient_accumulation_steps == 0 or (
                    # last step in an epoch that is shorter than gradient_accumulation_steps
                    len(train_dataloader) <= args.gradient_accumulation_steps
                    and (step + 1) == len(train_dataloader)):
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                # Log metrics
                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    lr = scheduler.get_lr()[0]
                    loss = tr_loss_sum / args.logging_steps
                    tr_loss_sum = 0.0
                    logs = {"lr": lr, "loss": loss}
                    log_str = f"[{global_step:5d}] lr {lr:g} | loss {loss:2.3f}"

                    for key, value in loss_sum.items():
                        value /= args.logging_steps
                        loss_sum[key] = 0.0
                        logs[f"{key}_loss"] = value
                        log_str += f" | {key}_loss {value:2.3f}"

                    for k, v in logs.items():
                        tb_writer.add_scalar(k, v, global_step)
                    logger.info(log_str)

                # Save model checkpoint
                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args.local_rank == -1 and args.evaluate_during_training:
                        results, eval_time = evaluate(args,
                                                      model,
                                                      tokenizer,
                                                      prefix="")
                        for key, value in results.items():
                            tb_writer.add_scalar("eval_{}".format(key), value,
                                                 global_step)
                        msg = " | ".join(
                            [f"{k} {v:.2f}" for k, v in results.items()])
                        logger.info(f"  [{global_step:5d}] {msg}")

                    if args.save_only_best:
                        output_dirs = []
                    else:
                        output_dirs = [
                            os.path.join(args.output_dir,
                                         f"checkpoint-{global_step}")
                        ]
                    # `results` only exists when evaluation actually ran above (single-GPU case)
                    if args.local_rank == -1 and args.evaluate_during_training and (
                            best['f1'] is None or results['f1'] > best['f1']):
                        logger.info("Congratulations, best model so far!")
                        output_dirs.append(
                            os.path.join(args.output_dir, "checkpoint-best"))
                        best = {
                            'step': global_step,
                            'f1': results['f1'],
                            'em': results['exact'],
                            'eval_time': eval_time
                        }

                    for output_dir in output_dirs:
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        logger.info("Saving model checkpoint to %s",
                                    output_dir)
                        # Take care of distributed/parallel training
                        model_to_save = model.module if hasattr(
                            model, "module") else model
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                        torch.save(
                            args, os.path.join(output_dir,
                                               "training_args.bin"))
                        # torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt"))
                        # torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt"))

            if 0 < args.max_steps <= global_step:
                break

        if 0 < args.max_steps <= global_step:
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, best