Code Example #1
    def train(self, epochs, pretrain_file=None):
        logging.info(
            "%s INFO: Begin training",
            time.strftime("%m/%d/%Y %I:%M:%S %p", time.localtime()),
        )

        iter_ = 0

        start_epoch, accu, iou, f1, train_loss, test_loss, losses = self._load_init(
            pretrain_file
        )
        loss_weights = torch.ones(
            self.cfg.N_CLASSES, dtype=torch.float32, device=self.device
        )
        if self.cfg.WEIGHTED_LOSS or self.cfg.REVOLVER_WEIGHTED:
            weights = self.gt_dataset.compute_frequency()
            if self.cfg.REVOLVER_WEIGHTED:
                self.train_dataset.set_sparsifier_weights(weights)
            if self.cfg.WEIGHTED_LOSS:
                loss_weights = (
                    torch.from_numpy(weights).type(torch.FloatTensor).to(self.device)
                )

        train_loader = self.train_dataset.get_loader(
            self.cfg.BATCH_SIZE, self.cfg.WORKERS
        )
        for e in tqdm(range(start_epoch, epochs + 1), total=epochs + 1 - start_epoch):
            logging.info(
                "\n%s Epoch %s",
                time.strftime("%m/%d/%Y %I:%M:%S %p", time.localtime()),
                e,
            )
            self.net.train()
            steps_pbar = tqdm(
                train_loader, total=self.cfg.EPOCH_SIZE // self.cfg.BATCH_SIZE
            )
            for data in steps_pbar:
                features, labels = data
                self.optimizer.zero_grad()
                features = features.float().to(self.device)
                labels = labels.float().to(self.device)
                output = self.net(features)
                loss = CrossEntropyLoss(loss_weights)(output, labels.long())
                loss.backward()
                self.optimizer.step()
                losses.append(loss.item())
                iter_ += 1
                steps_pbar.set_postfix({"loss": loss.item()})
            train_loss.append(np.mean(losses[-1 * self.cfg.EPOCH_SIZE :]))
            loss, iou_, acc_, f1_ = self.test()
            test_loss.append(loss)
            accu.append(acc_)
            iou.append(iou_ * 100)
            f1.append(f1_ * 100)
            del (loss, iou_, acc_)
            if e % 5 == 0:
                self._save_net(e, accu, iou, f1, train_loss, test_loss, losses)
            self.scheduler.step()
        # Save final state
        self._save_net(epochs, accu, iou, f1, train_loss, test_loss, losses, False)
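
The loop above relies on `self.gt_dataset.compute_frequency()` to produce per-class weights that are handed to `CrossEntropyLoss` so that rare classes contribute more to the loss. The snippet below is a minimal, self-contained sketch of that idea; the helper name `inverse_frequency_weights` and the exact normalization are assumptions, and the repository's `compute_frequency()` may behave differently.

import numpy as np
import torch
from torch.nn import CrossEntropyLoss

def inverse_frequency_weights(labels, n_classes):
    """Hypothetical helper: weight each class by its inverse frequency."""
    counts = np.bincount(labels.reshape(-1), minlength=n_classes).astype(np.float64)
    weights = counts.sum() / np.maximum(counts, 1.0)  # guard against empty classes
    return weights / weights.sum()                    # optional normalization

labels = torch.randint(0, 4, (2, 64, 64))             # e.g. segmentation ground truth
weights = inverse_frequency_weights(labels.numpy(), n_classes=4)
loss_weights = torch.from_numpy(weights).float()
criterion = CrossEntropyLoss(weight=loss_weights)     # rare classes now weigh more
logits = torch.randn(2, 4, 64, 64)                    # (batch, classes, H, W)
loss = criterion(logits, labels)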
Code Example #2
    def test_forward(self):
        # Arrange
        k = 2
        pad_index = -1

        predicted = torch.tensor([[[0.5, .5], [0.5, .6], [1.0, 0.0],
                                   [0.0, 1.0], [0.0, 1.0]],
                                  [[0.5, .5], [0.5, .6], [1.0, 0.0],
                                   [0.0, 1.0], [0.0, 1.0]]])
        target = torch.tensor([[pad_index, 0, 1, 0, pad_index],
                               [pad_index, 0, 1, 0, pad_index]])

        non_pad_indices = torch.tensor([1, 2, 3])

        expected_loss = CrossEntropyLoss()(predicted[:,
                                                     non_pad_indices].permute(
                                                         0, 2, 1),
                                           target[:, non_pad_indices])

        sut = NerCrossEntropyLoss(pad_index_label=pad_index)

        # Act
        actual = sut.forward(predicted.permute(0, 2, 1), target)

        # Assert
        self.assertEqual(round(expected_loss.item(), 2),
                         round(actual.item(), 2))
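
The test above only pins down the behaviour of `NerCrossEntropyLoss`: token positions whose label equals `pad_index_label` must not contribute to the loss. One implementation that satisfies this test simply delegates to `CrossEntropyLoss(ignore_index=...)`; this is a sketch consistent with the test, not necessarily the class actually under test.

import torch
from torch.nn import CrossEntropyLoss, Module

class NerCrossEntropyLoss(Module):
    """Token-level cross entropy that ignores padded positions (sketch)."""

    def __init__(self, pad_index_label=-1):
        super().__init__()
        # With reduction='mean' (the default) the loss is averaged over non-ignored positions only.
        self._loss = CrossEntropyLoss(ignore_index=pad_index_label)

    def forward(self, predicted, target):
        # predicted: (batch, num_classes, seq_len), target: (batch, seq_len)
        return self._loss(predicted, target)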
Code Example #3
def train(model, optimizer, dataset, args):

    model = model.to(device)
    optimizer = opt(optimizer, model, args)
    # scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=args.training_steps)
    
    model.train()

    loss_hist = []
    acc_hist = []
    accuracy = 0

    print("Begin Training...")
    for epoch in range(args.epochs):
        print(f"Epoch {epoch}")
        ctime = time.time()
        for data_e, (inp, tar) in enumerate(dataset):
            tar = tar.to(device)
            input_ids, token_type_ids, attention_mask = inp["input_ids"].to(device), \
                                                        inp["token_type_ids"].to(device), \
                                                        inp["attention_mask"].to(device)
            
            output = model(input_ids=input_ids,
                           token_type_ids=token_type_ids,
                           attention_mask=attention_mask,
                           cls_pos=args.cls_pos)

            loss = CrossEntropyLoss()(output, tar)

            loss = loss / args.gradient_accumulation_steps
            loss.backward()
            
            accuracy += (tar == output.argmax(1)).float().mean().item() / args.gradient_accumulation_steps

            # Step the optimizer only after gradient_accumulation_steps batches have been accumulated
            if (data_e + 1) % args.gradient_accumulation_steps == 0:
                loss_hist.append(loss.item())
                optimizer.step()
                # scheduler.step()
                optimizer.zero_grad()
                acc_hist.append(accuracy)
                accuracy = 0

            if not data_e % 2000:
                print(f"Batch {data_e} Loss : {loss.item()}")
                print(f"Ground Truth: {tar.tolist()} \t Predicted: {output.argmax(1).tolist()}")
        
        print(
            f"Time taken for epoch{epoch+1} : {round( (time.time() - ctime) / 60, 2 )} MINUTES"
        )
        torch.save(model, args.pretrainedPATH + f"saved_checkpoint_{args.save_checkpoint}.pt")
        # model.save_pretrained(
        #     args.pretrainedPATH + f"saved_checkpoint_{args.save_checkpoint}"
        # )
        print(
            f"Model saved at {args.pretrainedPATH}saved_checkpoint_{args.save_checkpoint}"
        )
        args.save_checkpoint += 1

    return loss_hist, acc_hist
Code Example #4
    def test_forward_one_item(self):
        # Arrange
        k = 1
        predicted = torch.tensor([[0.1, 0.9]])
        target = torch.tensor([0])
        expected_loss = CrossEntropyLoss()(predicted, target)

        sut = TopKCrossEntropyLoss(k)

        # Act
        actual = sut.forward(predicted, target)

        # Assert
        self.assertEqual(round(expected_loss.item(), 2), round(actual.item(), 2))
Code Example #5
    def test_forward(self):
        # Arrange
        k = 2
        predicted = torch.tensor([[0.5, .5], [1.0, 0.0], [0.0, 1.0]])
        target = torch.tensor([0, 1, 0])
        expected_loss = CrossEntropyLoss()(predicted[torch.tensor([1, 2])], target[torch.tensor([1, 2])])

        sut = TopKCrossEntropyLoss(k)

        # Act
        actual = sut.forward(predicted, target)

        # Assert
        self.assertEqual(round(expected_loss.item(), 2), round(actual.item(), 2))
Code Example #6
def train(model, device, train_loader, optimizer, epoch, log_interval=1):
    model.to(device)
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        print(data.shape)

        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        print(target)
        loss = CrossEntropyLoss()(output, target)  # instantiate the loss module, then call it on (logits, targets)
        loss.backward()
        optimizer.step()
        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.item()))
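
`CrossEntropyLoss` is an `nn.Module`, so it has to be instantiated first and then called on `(logits, targets)`; `torch.nn.functional.cross_entropy` is the functional equivalent. A minimal check of the two calling patterns:

import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss

logits = torch.randn(4, 10)               # (batch, num_classes)
targets = torch.randint(0, 10, (4,))      # class indices

criterion = CrossEntropyLoss()            # instantiate the module first...
loss_module = criterion(logits, targets)  # ...then call it
loss_functional = F.cross_entropy(logits, targets)
assert torch.allclose(loss_module, loss_functional)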
Code Example #7
    def test_forward(self):
        # Arrange
        k = 2
        predicted = torch.tensor([[[0.5, .5], [.7, 0.3], [0.0, 1.0]],
                                  [[0.2, .8], [.8, 0.2], [0.0, 1.0]],
                                  [[0.5, .5], [1.0, 0.0], [0.0, 1.0]]])
        target = torch.tensor([[0, 1, 0], [0, 1, 0], [0, 1, 0]])
        indices_high_loss = torch.tensor([1, 2])

        expected_loss = CrossEntropyLoss()(
            predicted[indices_high_loss, :].permute(0, 2, 1),
            target[indices_high_loss, :])

        sut = TopKCrossEntropyLoss(k)

        # Act
        actual = sut.forward(predicted.permute(0, 2, 1), target)

        # Assert
        self.assertEqual(round(expected_loss.item(), 2),
                         round(actual.item(), 2))
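
The three `TopKCrossEntropyLoss` tests (single item, flat batch, and sequence input) all describe the same behaviour: compute a cross-entropy value per example, keep only the `k` examples with the highest loss, and average those. The sketch below satisfies these tests under that reading; the class in the repository may well be implemented differently.

import torch
from torch.nn import CrossEntropyLoss, Module

class TopKCrossEntropyLoss(Module):
    """Average the cross entropy of the k hardest examples in the batch (sketch)."""

    def __init__(self, k):
        super().__init__()
        self.k = k
        self._per_item = CrossEntropyLoss(reduction='none')

    def forward(self, predicted, target):
        # predicted: (batch, num_classes) or (batch, num_classes, seq_len)
        losses = self._per_item(predicted, target)
        # Collapse any sequence dimension so there is one loss value per example.
        per_example = losses.reshape(losses.size(0), -1).mean(dim=1)
        hardest, _ = torch.topk(per_example, self.k)
        return hardest.mean()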
Code Example #8
def Train(model,
          t,
          loader,
          eps_scheduler,
          norm,
          train,
          opt,
          bound_type,
          method='robust'):
    num_class = 10
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
        eps_scheduler.set_epoch_length(
            int((len(loader.dataset) + loader.batch_size - 1) /
                loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training, no need to compute LiRPA bounds
        batch_method = method
        if eps < 1e-20:
            batch_method = "natural"
        if train:
            opt.zero_grad()
        # generate specifications
        c = torch.eye(num_class).type_as(data)[labels].unsqueeze(
            1) - torch.eye(num_class).type_as(data).unsqueeze(0)
        # remove specifications to self
        I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
            labels.data).unsqueeze(0)))
        c = (c[I].view(data.size(0), num_class - 1, num_class))
        # bound input for Linf norm used only
        if norm == np.inf:
            data_max = torch.reshape((1. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_min = torch.reshape((0. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1),
                                data_max)
            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1),
                                data_min)
        else:
            data_ub = data_lb = data

        if list(model.parameters())[0].is_cuda:
            data, labels, c = data.cuda(), labels.cuda(), c.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        # Specify Lp norm perturbation.
        # When using Linf perturbation, we manually set element-wise bound x_L and x_U. eps is not used for Linf norm.
        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        x = BoundedTensor(data, ptb)

        output = model(x)
        regular_ce = CrossEntropyLoss()(
            output, labels)  # regular CrossEntropyLoss used for warming up
        meter.update('CE', regular_ce.item(), x.size(0))
        meter.update(
            'Err',
            torch.sum(
                torch.argmax(output, dim=1) != labels).cpu().detach().numpy() /
            x.size(0), x.size(0))

        if batch_method == "robust":
            if bound_type == "IBP":
                lb, ub = model.compute_bounds(IBP=True, C=c, method=None)
            elif bound_type == "CROWN":
                lb, ub = model.compute_bounds(IBP=False,
                                              C=c,
                                              method="backward",
                                              bound_upper=False)
            elif bound_type == "CROWN-IBP":
                # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method="backward")  # pure IBP bound
                # we use a mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
                factor = (eps_scheduler.get_max_eps() -
                          eps) / eps_scheduler.get_max_eps()
                ilb, iub = model.compute_bounds(IBP=True, C=c, method=None)
                if factor < 1e-5:
                    lb = ilb
                else:
                    clb, cub = model.compute_bounds(IBP=False,
                                                    C=c,
                                                    method="backward",
                                                    bound_upper=False)
                    lb = clb * factor + ilb * (1 - factor)

            # Pad zero at the beginning for each example, and use fake label "0" for all examples
            lb_padded = torch.cat((torch.zeros(
                size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb),
                                  dim=1)
            fake_labels = torch.zeros(size=(lb.size(0), ),
                                      dtype=torch.int64,
                                      device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
        if batch_method == "robust":
            loss = robust_ce
        elif batch_method == "natural":
            loss = regular_ce
        if train:
            loss.backward()
            eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))
        if batch_method != "natural":
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            # For an example, if the lower bounds of the margins are > 0 for all classes, the output is verifiably correct.
            # If any margin is < 0, the example is counted as an error.
            meter.update('Verified_Err',
                         torch.sum((lb < 0).any(dim=1)).item() / data.size(0),
                         data.size(0))
        meter.update('Time', time.time() - start)
        if i % 50 == 0 and train:
            print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))
    print('[{:2d}:{:4d}]: eps={:.8f} {}'.format(t, i, eps, meter))
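
The `c` tensor built near the top of the loop is the specification matrix: for each example `b`, row `j` of `c[b]` is `e_{y_b} - e_j` for every class `j != y_b`, so multiplying it with the logits gives the margins between the true class and every other class. The standalone snippet below (plain PyTorch, no bound computation; the toy `num_class`, `labels`, and `logits` are made up for illustration) reproduces that construction.

import torch

num_class = 3
labels = torch.tensor([2, 0])            # ground-truth classes for a batch of 2
logits = torch.randn(2, num_class)

# c[b, j, :] = e_{labels[b]} - e_j, shape (batch, num_class, num_class)
c = torch.eye(num_class)[labels].unsqueeze(1) - torch.eye(num_class).unsqueeze(0)
# drop the row where j == labels[b] ("remove specifications to self")
I = ~(labels.unsqueeze(1) == torch.arange(num_class).unsqueeze(0))
c = c[I].view(2, num_class - 1, num_class)

# margins[b, k] = logits[b, labels[b]] - logits[b, j_k]; an example is verified
# correct when every margin (or its certified lower bound) is > 0.
margins = torch.bmm(c, logits.unsqueeze(-1)).squeeze(-1)
print(margins)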
Code Example #9
    def do_training(train_fs, train_exs):
        """Runs BERT fine-tuning."""
        # Allows writing to the enclosing variable global_step.
        nonlocal global_step

        # Create the batched training data out of the features.
        train_data = create_tensor_dataset(train_fs)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        # Calculate the number of optimization steps.
        num_train_optimization_steps = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

        # Prepare optimizer.
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

        # Log some information about the training.
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_exs))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        # Set the model to training mode and train for X epochs.
        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            # Iterate over all batches.
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # Get the Logits and calculate the loss.
                logits = model(input_ids, token_type_ids=segment_ids, attention_mask=input_mask)
                loss = CrossEntropyLoss()(logits.view(-1, num_labels), label_ids.view(-1))

                # Scale the loss in gradient accumulation mode.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                # Calculate the gradients.
                loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                # Update the weights every gradient_accumulation_steps steps.
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                    tb_writer.add_scalar('lr', optimizer.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', loss.item(), global_step)
Code Example #10
    def loss(self, sample, supervised_loss_share: float = 0):
        """
        :param supervised_loss_share: share of supervised loss in total loss
        :param sample: {
            "xs": [
                [support_A_1, support_A_2, ...],
                [support_B_1, support_B_2, ...],
                [support_C_1, support_C_2, ...],
                ...
            ],
            "xq": [
                [query_A_1, query_A_2, ...],
                [query_B_1, query_B_2, ...],
                [query_C_1, query_C_2, ...],
                ...
            ]
        }
        :return:
        """
        xs = sample['xs']  # support
        xq = sample['xq']  # query

        n_class = len(xs)
        assert len(xq) == n_class
        n_support = len(xs[0])
        n_query = len(xq[0])

        target_inds = torch.arange(0, n_class).view(n_class, 1, 1).expand(
            n_class, n_query, 1).long()
        target_inds = target_inds.to(device)  # no gradient is needed for the targets

        has_augment = "x_augment" in sample
        if has_augment:
            augmentations = sample["x_augment"]

            n_augmentations_samples = len(sample["x_augment"])
            n_augmentations_per_sample = [
                len(item['tgt_texts']) for item in augmentations
            ]
            assert len(set(n_augmentations_per_sample)) == 1
            n_augmentations_per_sample = n_augmentations_per_sample[0]

            supports = [item["sentence"] for xs_ in xs for item in xs_]
            queries = [item["sentence"] for xq_ in xq for item in xq_]
            augmentations_supports = [[item2 for item2 in item["tgt_texts"]]
                                      for item in sample["x_augment"]]
            augmentation_queries = [
                item["src_text"] for item in sample["x_augment"]
            ]

            # Encode
            x = supports + queries + [
                item2 for item1 in augmentations_supports for item2 in item1
            ] + augmentation_queries
            z = self.encoder.embed_sentences(x)
            z_dim = z.size(-1)

            # Dispatch
            z_support = z[:len(supports)].view(n_class, n_support,
                                               z_dim).mean(dim=[1])
            z_query = z[len(supports):len(supports) + len(queries)]
            z_aug_support = (
                z[len(supports) + len(queries):len(supports) + len(queries) +
                  n_augmentations_per_sample * n_augmentations_samples].view(
                      n_augmentations_samples, n_augmentations_per_sample,
                      z_dim).mean(dim=[1]))
            z_aug_query = z[-len(augmentation_queries):]
        else:
            # When not using augmentations
            supports = [item["sentence"] for xs_ in xs for item in xs_]
            queries = [item["sentence"] for xq_ in xq for item in xq_]

            # Encode
            x = supports + queries
            z = self.encoder.embed_sentences(x)
            z_dim = z.size(-1)

            # Dispatch
            z_support = z[:len(supports)].view(n_class, n_support,
                                               z_dim).mean(dim=[1])
            z_query = z[len(supports):len(supports) + len(queries)]

        if self.metric == "euclidean":
            supervised_dists = euclidean_dist(z_query, z_support)
            if has_augment:
                unsupervised_dists = euclidean_dist(z_aug_query, z_aug_support)
        elif self.metric == "cosine":
            supervised_dists = (-cosine_similarity(z_query, z_support) + 1) * 5
            if has_augment:
                unsupervised_dists = (
                    -cosine_similarity(z_aug_query, z_aug_support) + 1) * 5
        else:
            raise NotImplementedError

        from torch.nn import CrossEntropyLoss
        supervised_loss = CrossEntropyLoss()(-supervised_dists,
                                             target_inds.reshape(-1))
        _, y_hat_supervised = (-supervised_dists).max(1)
        acc_val_supervised = torch.eq(y_hat_supervised,
                                      target_inds.reshape(-1)).float().mean()

        if has_augment:
            # Unsupervised loss
            unsupervised_target_inds = torch.arange(
                n_augmentations_samples, device=device)
            unsupervised_loss = CrossEntropyLoss()(-unsupervised_dists,
                                                   unsupervised_target_inds)
            _, y_hat_unsupervised = (-unsupervised_dists).max(1)
            acc_val_unsupervised = torch.eq(
                y_hat_unsupervised,
                unsupervised_target_inds.reshape(-1)).float().mean()

            # Final loss
            assert 0 <= supervised_loss_share <= 1
            final_loss = (supervised_loss_share) * supervised_loss + (
                1 - supervised_loss_share) * unsupervised_loss

            return final_loss, {
                "metrics": {
                    "supervised_acc": acc_val_supervised.item(),
                    "unsupervised_acc": acc_val_unsupervised.item(),
                    "supervised_loss": supervised_loss.item(),
                    "unsupervised_loss": unsupervised_loss.item(),
                    "supervised_loss_share": supervised_loss_share,
                    "final_loss": final_loss.item(),
                },
                "supervised_dists": supervised_dists,
                "unsupervised_dists": unsupervised_dists,
                "target": target_inds
            }

        return supervised_loss, {
            "metrics": {
                "acc": acc_val_supervised.item(),
                "loss": supervised_loss.item(),
            },
            "dists": supervised_dists,
            "target": target_inds
        }
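
`euclidean_dist` is not defined in this snippet. A common definition, the pairwise squared Euclidean distance used in the original Prototypical Networks code, is sketched below on the assumption that this is what the method relies on.

import torch

def euclidean_dist(x, y):
    """Pairwise squared Euclidean distances between x (n, d) and y (m, d) -> (n, m)."""
    n, m, d = x.size(0), y.size(0), x.size(1)
    assert d == y.size(1)
    x = x.unsqueeze(1).expand(n, m, d)
    y = y.unsqueeze(0).expand(n, m, d)
    return torch.pow(x - y, 2).sum(2)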
Code Example #11
def Train(model,
          t,
          loader,
          eps_scheduler,
          norm,
          train,
          opt,
          bound_type,
          method='robust',
          loss_fusion=True,
          final_node_name=None):
    num_class = 200
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch()
        eps_scheduler.set_epoch_length(
            int((len(loader.dataset) + loader.batch_size - 1) /
                loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()

    exp_module = get_exp_module(model)

    def get_bound_loss(x=None, c=None):
        if loss_fusion:
            bound_lower, bound_upper = False, True
        else:
            bound_lower, bound_upper = True, False

        if bound_type == 'IBP':
            lb, ub = model(method_opt="compute_bounds",
                           x=x,
                           IBP=True,
                           C=c,
                           method=None,
                           final_node_name=final_node_name,
                           no_replicas=True)
        elif bound_type == 'CROWN':
            lb, ub = model(method_opt="compute_bounds",
                           x=x,
                           IBP=False,
                           C=c,
                           method='backward',
                           bound_lower=bound_lower,
                           bound_upper=bound_upper)
        elif bound_type == 'CROWN-IBP':
            # lb, ub = model.compute_bounds(ptb=ptb, IBP=True, x=data, C=c, method='backward')  # pure IBP bound
            # we use a mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
            factor = (eps_scheduler.get_max_eps() -
                      eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()
            ilb, iub = model(method_opt="compute_bounds",
                             x=x,
                             IBP=True,
                             C=c,
                             method=None,
                             final_node_name=final_node_name,
                             no_replicas=True)
            if factor < 1e-50:
                lb, ub = ilb, iub
            else:
                clb, cub = model(method_opt="compute_bounds",
                                 IBP=False,
                                 C=c,
                                 method='backward',
                                 bound_lower=bound_lower,
                                 bound_upper=bound_upper,
                                 final_node_name=final_node_name,
                                 no_replicas=True)
                if loss_fusion:
                    ub = cub * factor + iub * (1 - factor)
                else:
                    lb = clb * factor + ilb * (1 - factor)

        if loss_fusion:
            if isinstance(model, BoundDataParallel):
                max_input = model(get_property=True,
                                  node_class=BoundExp,
                                  att_name='max_input')
            else:
                max_input = exp_module.max_input
            return None, torch.mean(torch.log(ub) + max_input)
        else:
            # Pad zero at the beginning for each example, and use fake label '0' for all examples
            lb_padded = torch.cat((torch.zeros(
                size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb),
                                  dim=1)
            fake_labels = torch.zeros(size=(lb.size(0), ),
                                      dtype=torch.int64,
                                      device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
            return lb, robust_ce

    for i, (data, labels) in enumerate(loader):
        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training, no need to compute LiRPA bounds
        batch_method = method
        if eps < 1e-50:
            batch_method = "natural"
        if train:
            opt.zero_grad()
        # bound input for Linf norm used only
        if norm == np.inf:
            data_max = torch.reshape((1. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_min = torch.reshape((0. - loader.mean) / loader.std,
                                     (1, -1, 1, 1))
            data_ub = torch.min(data + (eps / loader.std).view(1, -1, 1, 1),
                                data_max)
            data_lb = torch.max(data - (eps / loader.std).view(1, -1, 1, 1),
                                data_min)
        else:
            data_ub = data_lb = data

        if list(model.parameters())[0].is_cuda:
            data, labels = data.cuda(), labels.cuda()
            data_lb, data_ub = data_lb.cuda(), data_ub.cuda()

        ptb = PerturbationLpNorm(norm=norm, eps=eps, x_L=data_lb, x_U=data_ub)
        x = BoundedTensor(data, ptb)
        if loss_fusion:
            if batch_method == 'natural' or not train:
                output = model(x, labels)
                regular_ce = torch.mean(torch.log(output))
            else:
                model(x, labels)
                regular_ce = torch.tensor(0., device=data.device)
            meter.update('CE', regular_ce.item(), x.size(0))
            x = (x, labels)
            c = None
        else:
            c = torch.eye(num_class).type_as(data)[labels].unsqueeze(
                1) - torch.eye(num_class).type_as(data).unsqueeze(0)
            # remove specifications to self
            I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(
                labels.data).unsqueeze(0)))
            c = (c[I].view(data.size(0), num_class - 1, num_class))
            x = (x, labels)
            output = model(x, final_node_name=final_node_name)
            regular_ce = CrossEntropyLoss()(
                output, labels)  # regular CrossEntropyLoss used for warming up
            meter.update('CE', regular_ce.item(), x[0].size(0))
            meter.update(
                'Err',
                torch.sum(torch.argmax(output, dim=1) != labels).item() /
                x[0].size(0), x[0].size(0))

        if batch_method == 'robust':
            # print(data.sum())
            lb, robust_ce = get_bound_loss(x=x, c=c)
            loss = robust_ce
        elif batch_method == 'natural':
            loss = regular_ce

        if train:
            loss.backward()

            if args.clip_grad_norm:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), max_norm=args.clip_grad_norm)
                meter.update('grad_norm', grad_norm)

            if isinstance(eps_scheduler, AdaptiveScheduler):
                eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))

        if batch_method != 'natural':
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            if not loss_fusion:
                # For an example, if the lower bounds of the margins are > 0 for all classes, the output is verifiably correct.
                # If any margin is < 0, the example is counted as an error.
                meter.update(
                    'Verified_Err',
                    torch.sum((lb < 0).any(dim=1)).item() / data.size(0),
                    data.size(0))
        meter.update('Time', time.time() - start)

        if (i + 1) % 250 == 0 and train:
            logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(
                t, i + 1, eps, meter))

    logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))
    return meter
Code Example #12
def Train(model, t, loader, eps_scheduler, norm, train, opt, bound_type, method='robust', loss_fusion=True, final_node_name=None):
    meter = MultiAverageMeter()
    if train:
        model.train()
        eps_scheduler.train()
        eps_scheduler.step_epoch(verbose=False)
        eps_scheduler.set_epoch_length(int((len(loader.dataset) + loader.batch_size - 1) / loader.batch_size))
    else:
        model.eval()
        eps_scheduler.eval()
    
    # Used for loss-fusion. Get the exp operation in computational graph.
    exp_module = get_exp_module(model)

    def get_bound_loss(x=None, c=None):
        if loss_fusion:
            # When loss fusion is used, we need the upper bound for the final loss function.
            bound_lower, bound_upper = False, True
        else:
            # When loss fusion is not used, we need the lower bound for the logit layer.
            bound_lower, bound_upper = True, False

        if bound_type == 'IBP':
            lb, ub = model(method_opt="compute_bounds", x=x, C=c, method="IBP", final_node_name=final_node_name, no_replicas=True)
        elif bound_type == 'CROWN':
            lb, ub = model(method_opt="compute_bounds", x=x, C=c, method="backward",
                                          bound_lower=bound_lower, bound_upper=bound_upper)
        elif bound_type == 'CROWN-IBP':
            # we use a mixed IBP and CROWN-IBP bounds, leading to better performance (Zhang et al., ICLR 2020)
            # factor = (eps_scheduler.get_max_eps() - eps_scheduler.get_eps()) / eps_scheduler.get_max_eps()
            ilb, iub = model(method_opt="compute_bounds", x=x, C=c, method="IBP", final_node_name=final_node_name, no_replicas=True)
            lb, ub = model(method_opt="compute_bounds", C=c, method="CROWN-IBP",
                         bound_lower=bound_lower, bound_upper=bound_upper, final_node_name=final_node_name, average_A=True, no_replicas=True)
        if loss_fusion:
            # When loss fusion is enabled, we need to get the common factor before softmax.
            if isinstance(model, BoundDataParallel):
                max_input = model(get_property=True, node_class=BoundExp, att_name='max_input')
            else:
                max_input = exp_module.max_input
            return None, torch.mean(torch.log(ub) + max_input)
        else:
            # Pad zero at the beginning for each example, and use fake label '0' for all examples
            lb_padded = torch.cat((torch.zeros(size=(lb.size(0), 1), dtype=lb.dtype, device=lb.device), lb), dim=1)
            fake_labels = torch.zeros(size=(lb.size(0),), dtype=torch.int64, device=lb.device)
            robust_ce = CrossEntropyLoss()(-lb_padded, fake_labels)
            return lb, robust_ce

    for i, (data, labels) in enumerate(loader):
        # For unit test. We only use a small number of batches
        if args.truncate_data:
            if i >= args.truncate_data:
                break

        start = time.time()
        eps_scheduler.step_batch()
        eps = eps_scheduler.get_eps()
        # For small eps just use natural training, no need to compute LiRPA bounds
        batch_method = method
        if eps < 1e-50:
            batch_method = "natural"
        if train:
            opt.zero_grad()

        if list(model.parameters())[0].is_cuda:
            data, labels = data.cuda(), labels.cuda()

        model.ptb.eps = eps
        x = data
        if loss_fusion:
            if batch_method == 'natural' or not train:
                output = model(x, labels)  # , disable_multi_gpu=True
                regular_ce = torch.mean(torch.log(output))
            else:
                model(x, labels)
                regular_ce = torch.tensor(0., device=data.device)
            meter.update('CE', regular_ce.item(), x.size(0))
            x = (x, labels)
            c = None
        else:
            # Generate the specification matrix (when loss fusion is not used).
            c = torch.eye(num_class).type_as(data)[labels].unsqueeze(1) - torch.eye(num_class).type_as(data).unsqueeze(
                0)
            # remove specifications to self.
            I = (~(labels.data.unsqueeze(1) == torch.arange(num_class).type_as(labels.data).unsqueeze(0)))
            c = (c[I].view(data.size(0), num_class - 1, num_class))
            x = (x, labels)
            output = model(x, final_node_name=final_node_name)
            regular_ce = CrossEntropyLoss()(output, labels)  # regular CrossEntropyLoss used for warming up
            meter.update('CE', regular_ce.item(), x[0].size(0))
            meter.update('Err', torch.sum(torch.argmax(output, dim=1) != labels).item() / x[0].size(0), x[0].size(0))

        if batch_method == 'robust':
            lb, robust_ce = get_bound_loss(x=x, c=c)
            loss = robust_ce
        elif batch_method == 'natural':
            loss = regular_ce

        if train:
            loss.backward()

            if args.clip_grad_norm:
                grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=args.clip_grad_norm)
                meter.update('grad_norm', grad_norm)

            if isinstance(eps_scheduler, AdaptiveScheduler):
                eps_scheduler.update_loss(loss.item() - regular_ce.item())
            opt.step()
        meter.update('Loss', loss.item(), data.size(0))

        if batch_method != 'natural':
            meter.update('Robust_CE', robust_ce.item(), data.size(0))
            if not loss_fusion:
                # For an example, if the lower bounds of the margins are > 0 for all classes, the output is verifiably correct.
                # If any margin is < 0, the example is counted as an error.
                meter.update('Verified_Err', torch.sum((lb < 0).any(dim=1)).item() / data.size(0), data.size(0))
        meter.update('Time', time.time() - start)

        if (i + 1) % 50 == 0 and train:
            logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))

    logger.log('[{:2d}:{:4d}]: eps={:.12f} {}'.format(t, i + 1, eps, meter))
    return meter
Code Example #13
def distill(args,
            output_model_file,
            processor,
            label_list,
            tokenizer,
            device,
            n_gpu,
            tensorboard_logger,
            eval_data=None):
    assert args.kd_policy is not None
    model = args.kd_policy.student
    args.kd_policy.teacher.eval()
    num_labels = len(args.labels)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    save_best_model = eval_data is not None and args.eval_interval > 0

    train_examples = processor.get_train_examples(args.data_dir)
    num_train_steps = int(
        len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)
    optimizer, t_total = get_optimizer(args, model, num_train_steps)

    train_data = prepare(args, processor, label_list, tokenizer, 'train')
    logger.info("***** Running distillation *****")
    logger.info("  Num examples = %d", len(train_examples))
    logger.info("  Batch size = %d", args.train_batch_size)
    logger.info("  Num steps = %d", num_train_steps)

    if args.local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

    train_steps = 0
    best_eval_accuracy = 0
    for epoch in trange(int(args.num_train_epochs), desc="Epoch", dynamic_ncols=True):
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        args.kd_policy.on_epoch_begin(model, None, None)

        for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", dynamic_ncols=True)):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            model.train()
            logits = args.kd_policy.forward(input_ids, segment_ids, input_mask)
            loss = CrossEntropyLoss()(logits.view(-1, num_labels), label_ids.view(-1))
            loss = args.kd_policy.before_backward_pass(model, epoch, None, None, loss, None).overall_loss
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                optimizer.backward(loss)
            else:
                loss.backward()

            train_steps += 1
            tensorboard_logger.add_scalar('distillation_train_loss', loss.item(), train_steps)

            tr_loss += loss.item()
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1
            if (step + 1) % args.gradient_accumulation_steps == 0:
                # modify learning rate with special warm up BERT uses
                lr_this_step = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion)
                for param_group in optimizer.param_groups:
                    param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

            if save_best_model and train_steps % args.eval_interval == 0:
                eval_loss, eval_accuracy, _ = eval(args, model, eval_data, device, verbose=False)
                tensorboard_logger.add_scalar('distillation_dev_loss', eval_loss, train_steps)
                tensorboard_logger.add_scalar('distillation_dev_accuracy', eval_accuracy, train_steps)
                if eval_accuracy > best_eval_accuracy:
                    save_model(model, output_model_file)
                    best_eval_accuracy = eval_accuracy

        args.kd_policy.on_epoch_end(model, None, None)

    if save_best_model:
        eval_loss, eval_accuracy, _ = eval(args, model, eval_data, device, verbose=False)
        if eval_accuracy > best_eval_accuracy:
            save_model(model, output_model_file)
    else:
        save_model(model, output_model_file)

    return global_step, tr_loss / nb_tr_steps
Code Example #14
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help=
        "Bert pre-trained model selected in the list: bert-base-uncased, bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, bert-base-multilingual-cased, bert-base-chinese."
    )
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the training task.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where model predictions and checkpoints will be written.")
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where to store the pre-trained models downloaded from s3.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help="The maximum total input sequence length after WordPiece tokenization. Longer sequences will be truncated, shorter ones padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=64,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=256,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% of training."
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 is set to True. 0 (default value): dynamic loss scaling. Positive power of 2: static loss scaling value.\n"
    )
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    # if args.server_ip and args.server_port:
    #     # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
    #     import ptvsd
    #     print("Waiting for debugger attach")
    #     ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
    #     ptvsd.wait_for_attach()

    processors = {
        # "cola": ColaProcessor,
        # "mnli": MnliProcessor,
        # "mnli-mm": MnliMismatchedProcessor,
        # "mrpc": MrpcProcessor,
        # "sst-2": Sst2Processor,
        # "sts-b": StsbProcessor,
        # "qqp": QqpProcessor,
        # "qnli": QnliProcessor,
        "rte": RteProcessor
        # "wnli": WnliProcessor,
    }

    output_modes = {
        # "cola": "classification",
        # "mnli": "classification",
        # "mrpc": "classification",
        # "sst-2": "classification",
        # "sts-b": "regression",
        # "qqp": "classification",
        # "qnli": "classification",
        "rte": "classification"
        # "wnli": "classification",
    }

    if args.local_rank == -1 or args.no_cuda:  # no GPU specified, or no GPU available
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:  # distributed training
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1  # what about multiple GPUs?
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs  # single GPU, so no distributed training?
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # If GPU memory is limited: suppose the original batch size is 10 and there are 1000 samples in total, giving
    # 100 train steps and 100 gradient updates. To save memory we shrink the batch size: with
    # gradient_accumulation_steps=2 the new batch_size is 10/2=5, so two forward passes are needed to cover 10
    # samples; the number of gradient updates stays at 100, while train_steps becomes 200.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if n_gpu > 0:  # multi-GPU
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()  # RteProcessor
    output_mode = output_modes[task_name]  # "classification"

    label_list = processor.get_labels()  # ["entailment", "not_entailment"]
    num_labels = len(label_list)

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_TRANSFORMERS_CACHE), 'distributed_{}'.format(
            args.local_rank))
    # model = BertForSequenceClassification.from_pretrained(args.bert_model,
    #           cache_dir=cache_dir,
    #           num_labels=num_labels)
    # tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    model = BertForSequenceClassification.from_pretrained(
        'bert-base-uncased', num_labels=num_labels)  # 2 labels
    if args.fp16:
        model.half()
    model.to(device)
    if n_gpu > 1:  # multi-GPU
        model = torch.nn.DataParallel(model)

    tokenizer = BertTokenizer.from_pretrained('bert-base-uncased',
                                              do_lower_case=args.do_lower_case)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']  # parameters excluded from weight_decay
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    # if any no_decay substring nd occurs in the parameter name n, put p into the no-decay group
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
    else:
        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate)

    if args.do_train:
        num_train_steps = None
        # train_examples = processor.get_train_examples_wenpeng('/home/wyin3/Datasets/glue_data/RTE/train.tsv')
        train_examples, seen_types = processor.get_examples_Wikipedia_train(
            '/home/zut_csi/tomding/zs/BenchmarkingZeroShotData/tokenized_wiki2categories.txt',
            100000)
        # /export/home/Dataset/wikipedia/parsed_output/tokenized_wiki/tokenized_wiki2categories.txt', 100000) #train_pu_half_v1.txt
        # seen_classes=[0,2,4,6,8]
        eval_examples, eval_label_list, eval_hypo_seen_str_indicator, eval_hypo_2_type_index = processor.get_examples_emotion_test(
            '/home/zut_csi/tomding/zs/BenchmarkingZeroShot/emotion/dev.txt',
            seen_types)
        # /export/home/Dataset/Stuttgart_Emotion/unify-emotion-datasets-master/zero-shot-split/dev.txt', seen_types)
        test_examples, test_label_list, test_hypo_seen_str_indicator, test_hypo_2_type_index = processor.get_examples_emotion_test(
            '/home/zut_csi/tomding/zs/BenchmarkingZeroShot/emotion/test.txt',
            seen_types)
        # /export/home/Dataset/Stuttgart_Emotion/unify-emotion-datasets-master/zero-shot-split/test.txt', seen_types)

        train_features, eval_features, test_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length,
            tokenizer, output_mode), convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer,
                output_mode), convert_examples_to_features(
                    test_examples, label_list, args.max_seq_length, tokenizer,
                    output_mode)
        all_input_ids, eval_all_input_ids, test_all_input_ids = torch.tensor(
            [f.input_ids for f in train_features],
            dtype=torch.long), torch.tensor(
                [f.input_ids for f in eval_features],
                dtype=torch.long), torch.tensor(
                    [f.input_ids for f in test_features], dtype=torch.long)
        all_input_mask, eval_all_input_mask, test_all_input_mask = torch.tensor(
            [f.input_mask for f in train_features],
            dtype=torch.long), torch.tensor(
                [f.input_mask for f in eval_features],
                dtype=torch.long), torch.tensor(
                    [f.input_mask for f in test_features], dtype=torch.long)
        all_segment_ids, eval_all_segment_ids, test_all_segment_ids = torch.tensor(
            [f.segment_ids for f in train_features],
            dtype=torch.long), torch.tensor(
                [f.segment_ids for f in eval_features],
                dtype=torch.long), torch.tensor(
                    [f.segment_ids for f in test_features], dtype=torch.long)
        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.float)
        eval_all_label_ids, test_all_label_ids = torch.tensor(
            [f.label_id for f in eval_features],
            dtype=torch.long), torch.tensor(
                [f.label_id for f in test_features], dtype=torch.long)
        train_data, eval_data, test_data = TensorDataset(
            all_input_ids, all_input_mask,
            all_segment_ids, all_label_ids), TensorDataset(
                eval_all_input_ids, eval_all_input_mask, eval_all_segment_ids,
                eval_all_label_ids), TensorDataset(test_all_input_ids,
                                                   test_all_input_mask,
                                                   test_all_segment_ids,
                                                   test_all_label_ids)
        train_sampler, eval_sampler, test_sampler = RandomSampler(
            train_data), SequentialSampler(eval_data), SequentialSampler(
                test_data)
        eval_dataloader, test_dataloader, train_dataloader = DataLoader(
            eval_data, sampler=eval_sampler,
            batch_size=args.eval_batch_size), DataLoader(
                test_data,
                sampler=test_sampler,
                batch_size=args.eval_batch_size), DataLoader(
                    train_data,
                    sampler=train_sampler,
                    batch_size=args.train_batch_size)

        # batch_size was already divided by args.gradient_accumulation_steps above; is dividing by it again here correct?
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_steps = num_train_steps // torch.distributed.get_world_size(
            )  # total number of processes across all nodes

        max_test_unseen_acc, max_dev_unseen_acc, max_dev_seen_acc, max_overall_acc = 0.0, 0.0, 0.0, 0.0  #
        logger.info(
            '******************************************************  Running_training  ***************************************************'
        )
        logger.info("Num_examples:{} Batch_size:{} Num_steps:{}".format(
            len(train_examples), args.train_batch_size, num_train_steps))
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            train_loss = 0
            for train_step, batch_data in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()

                batch_data = tuple(b.to(device) for b in batch_data)
                input_ids, input_mask, segment_ids, label_ids = batch_data
                logits = model(input_ids, segment_ids, input_mask,
                               labels=None)[0]
                tmp_train_loss = CrossEntropyLoss()(logits.view(
                    -1, num_labels), label_ids.view(-1))
                if n_gpu > 1:  # multi-GPU
                    tmp_train_loss = tmp_train_loss.mean(
                    )  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    tmp_train_loss = tmp_train_loss / args.gradient_accumulation_steps
                tmp_train_loss.backward()
                train_loss += tmp_train_loss.item()

                optimizer.step()
                optimizer.zero_grad()

                if (train_step + 1
                    ) % 200 == 0:  # start evaluate on dev set after this epoch

                    def et(et_dataloader, max_et_unseen_acc, et_label_list,
                           et_hypo_seen_str_indicator, et_hypo_2_type_index):
                        model.eval()
                        et_loss, et_step, preds = 0, 0, []
                        for input_ids, input_mask, segment_ids, label_ids in et_dataloader:
                            input_ids, input_mask, segment_ids, label_ids = input_ids.to(
                                device), input_mask.to(device), segment_ids.to(
                                    device), label_ids.to(device)
                            with torch.no_grad():
                                logits = model(input_ids,
                                               segment_ids,
                                               input_mask,
                                               labels=None)[0]
                            tmp_et_loss = CrossEntropyLoss()(logits.view(
                                -1, num_labels), label_ids.view(-1))
                            et_loss += tmp_et_loss.mean().item()
                            et_step += 1
                            if len(preds) == 0:
                                preds.append(logits.detach().cpu().numpy())
                                # detach() stops backpropagation at this tensor; gradients will not flow past this point.
                                # cpu() copies the data from the GPU to host memory; the corresponding function is cuda().
                            else:
                                preds[0] = np.append(
                                    preds[0],
                                    logits.detach().cpu().numpy(),
                                    axis=0)
                        et_loss = et_loss / et_step
                        preds = preds[0]
                        '''
                        preds: size*2 (entail, not_entail)
                        wenpeng added a softmax so that each row is a prob vec
                        '''
                        pred_probs = softmax(preds, axis=1)[:, 0]
                        pred_binary_labels_harsh, pred_binary_labels_loose = [], []
                        for i in range(preds.shape[0]):
                            pred_binary_labels_harsh.append(
                                0 if preds[i][0] > preds[i][1] + 0.1 else 1)
                            pred_binary_labels_loose.append(
                                0 if preds[i][0] > preds[i][1] else 1)

                        seen_acc, unseen_acc = evaluate_emotion_zeroshot_TwpPhasePred(
                            pred_probs, pred_binary_labels_harsh,
                            pred_binary_labels_loose, et_label_list,
                            et_hypo_seen_str_indicator, et_hypo_2_type_index,
                            seen_types)
                        # result = compute_metrics('F1', preds, all_label_ids.numpy())
                        loss = train_loss / train_step if args.do_train else None
                        # test_acc = mean_f1#result.get("f1")
                        if unseen_acc > max_et_unseen_acc:
                            max_et_unseen_acc = unseen_acc
                        print(
                            'seen_f1:{} unseen_f1:{} max_unseen_f1:{}'.format(
                                seen_acc, unseen_acc, max_et_unseen_acc))
                        return max_et_unseen_acc

                    # if seen_acc+unseen_acc > max_overall_acc:
                    #     max_overall_acc = seen_acc + unseen_acc
                    # if seen_acc > max_dev_seen_acc:
                    #     max_dev_seen_acc = seen_acc

                    logger.info(
                        '*********************  Running evaluation  *********************'
                    )
                    logger.info("Num_examples:{} Batch_size:{}".format(
                        len(eval_examples), args.eval_batch_size))
                    max_dev_unseen_acc = et(eval_dataloader,
                                            max_dev_unseen_acc,
                                            eval_label_list,
                                            eval_hypo_seen_str_indicator,
                                            eval_hypo_2_type_index)
                    logger.info(
                        '*********************    Running testing   *********************'
                    )
                    logger.info("Num_examples:{} Batch_size:{}".format(
                        len(test_examples), args.eval_batch_size))
                    max_test_unseen_acc = et(test_dataloader,
                                             max_test_unseen_acc,
                                             test_label_list,
                                             test_hypo_seen_str_indicator,
                                             test_hypo_2_type_index)