예제 #1
0
    def train_epoch(self, train_dataloader):
        """Run one pass over ``train_dataloader`` and update the model.

        Every batch is moved to ``self.args.device`` and unpacked as
        ``(input_ids, input_mask, segment_ids, label_ids)``. The loss is
        ``F.binary_cross_entropy_with_logits`` when ``self.args.is_multilabel``
        is set, otherwise ``F.cross_entropy`` against ``argmax(label_ids, 1)``.
        The loss is accumulated into ``self.tr_loss``; ``self.nb_tr_steps`` is
        bumped per batch and ``self.iterations`` per optimizer step.
        """
        args = self.args
        accumulation = args.gradient_accumulation_steps

        progress = tqdm(train_dataloader, desc="Training")
        for step, batch in enumerate(progress):
            self.model.train()
            on_device = tuple(tensor.to(args.device) for tensor in batch)
            input_ids, input_mask, segment_ids, label_ids = on_device
            # The model is called with segment ids before the attention mask.
            logits = self.model(input_ids, segment_ids, input_mask)

            if not args.is_multilabel:
                targets = torch.argmax(label_ids, dim=1)
                loss = F.cross_entropy(logits, targets)
            else:
                loss = F.binary_cross_entropy_with_logits(
                    logits, label_ids.float())

            # Collapse per-GPU losses, then scale for gradient accumulation.
            if args.n_gpu > 1:
                loss = loss.mean()
            if accumulation > 1:
                loss = loss / accumulation

            if args.fp16:
                self.optimizer.backward(loss)
            else:
                loss.backward()

            self.tr_loss += loss.item()
            self.nb_tr_steps += 1

            # Only step the optimizer once a full accumulation window is done.
            if (step + 1) % accumulation != 0:
                continue

            if args.fp16:
                # The fp16 path sets the warmed-up learning rate by hand.
                fraction_done = (self.iterations /
                                 self.num_train_optimization_steps)
                warmed_lr = args.learning_rate * warmup_linear(
                    fraction_done, args.warmup_proportion)
                for group in self.optimizer.param_groups:
                    group['lr'] = warmed_lr
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.iterations += 1
예제 #2
0
def main():
    """Train and/or run inference with a BERT + LSTM + CRF sequence tagger.

    All settings come from ``run_args()``. With ``do_train`` the model is
    fine-tuned on the processor's training features, evaluated on the dev
    features every ``eval_freq`` optimizer steps, and finally saved to
    ``output_dir``. With ``do_infer`` the test features are tagged and the
    decoded tokens and tags are printed to stdout.

    Raises:
        ValueError: if ``gradient_accumulation_steps`` is below 1, if neither
            ``do_train`` nor ``do_infer`` is set, or if training is requested
            and ``output_dir`` already exists and is not empty.
        ImportError: if apex is required (distributed or fp16) but missing.
    """
    args = run_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        # Report no GPUs when CUDA is disabled so a CPU-resident model is
        # never wrapped in DataParallel further down.
        n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # train_batch_size becomes the per-micro-batch size from here on.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_infer:
        raise ValueError(
            "At least one of `do_train` or `do_infer` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    processor = BingjianFrenchProcessor(args.max_seq_length, args.bert_model)

    train_features = None
    num_train_optimization_steps = None
    if args.do_train:
        train_features = processor.get_train_features(args.data_dir)
        num_train_optimization_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = (
                num_train_optimization_steps //
                torch.distributed.get_world_size())

    # Prepare model
    model = BertLstmCrf.from_pretrained(
        args.bert_model, tagset_size=processor.get_tagset_size())
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)
        # model = DataParallelModel(model)  # TODO: multi-GPU load balancing

    # Prepare optimizer: biases and LayerNorm parameters get no weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0

    # Load the vocabulary (only used to display decoded examples).
    # Maps line index -> token.
    vocab_dict = collections.OrderedDict()
    with open(os.path.join(args.bert_model, "vocab.txt"),
              "r",
              encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab_dict[index] = line.strip()

    def _build_dataset(features):
        # Stack the per-example id lists into one long tensor per field.
        return TensorDataset(
            torch.tensor([f.input_ids for f in features], dtype=torch.long),
            torch.tensor([f.input_mask for f in features], dtype=torch.long),
            torch.tensor([f.segment_ids for f in features], dtype=torch.long),
            torch.tensor([f.label_id for f in features], dtype=torch.long))

    def _decode(ids, table, length):
        # Map the first `length` ids through `table`; each item is followed by
        # a single space, matching the printed format.
        return "".join(table.get(idx) + " " for idx in ids[:length])

    def _evaluate(dataloader):
        # Score the model on `dataloader`; returns (masked token accuracy,
        # mean per-batch loss) and prints a few decoded examples per batch.
        model.eval()
        loss_sum = 0
        num_batches = 0
        correct_tokens = 0
        total_tokens = 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device).byte()
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            num_batches += 1

            with torch.no_grad():
                total_loss, _, tag_seq = model(input_ids, segment_ids,
                                               input_mask, label_ids)

            loss_sum += total_loss.mean().item()
            # TODO: derive precision and recall from tag_seq; this is
            # currently computed on the CPU and could be optimized.
            mask_cpu = input_mask.long().cpu()
            correct_mat = (torch.tensor(tag_seq).cpu() == label_ids.cpu())
            correct_tokens += torch.mul(correct_mat.long(),
                                        mask_cpu).sum().tolist()
            total_tokens += mask_cpu.sum().tolist()

            logger.info("***** Eval examples *****")
            # Show at most five rows; the last batch may hold fewer.
            for i in range(min(5, input_ids.size(0))):
                length = input_mask[i].sum().tolist()
                print("句子:\t",
                      _decode(input_ids[i].tolist(), vocab_dict, length))
                print("预测的tags:\t",
                      _decode(tag_seq[i].tolist(), processor.label_dict_inv,
                              length))
                print("正确的tags:\t",
                      _decode(label_ids[i].tolist(), processor.label_dict_inv,
                              length))
                print()

        # Guard against an empty dev set.
        accuracy = correct_tokens / total_tokens if total_tokens else 0.0
        mean_loss = loss_sum / num_batches if num_batches else 0.0
        return accuracy, mean_loss

    if args.do_train:
        eval_features = processor.get_dev_features(args.data_dir)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        train_data = _build_dataset(train_features)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        # The dev set never changes, so build its loader once up front.
        eval_data = _build_dataset(eval_features)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=SequentialSampler(eval_data),
                                     batch_size=args.eval_batch_size)

        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            # Reset the running loss and its step counter together so the
            # logged average stays consistent across epoch boundaries.
            tr_loss = 0
            nb_tr_steps = 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()  # evaluation below leaves the model in eval mode
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                input_mask = input_mask.byte()
                loss, _, _ = model(input_ids, segment_ids, input_mask,
                                   label_ids)

                # TODO: work out how to add multi-GPU load balancing.

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_steps += 1

                # Nothing more to do until an accumulation window completes.
                if (step + 1) % args.gradient_accumulation_steps != 0:
                    continue

                if args.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used that handles this automatically
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / num_train_optimization_steps,
                        args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                # Evaluate only right after an optimizer step; otherwise the
                # same global_step would trigger one evaluation per
                # micro-batch when gradients are accumulated.
                if global_step % args.eval_freq != 0:
                    continue

                logger.info("***** Running evaluation *****")
                logger.info("  Num examples = %d", len(eval_features))
                logger.info("  Batch size = %d", args.eval_batch_size)
                eval_accuracy, eval_loss = _evaluate(eval_dataloader)
                logger.info("***** Eval results *****")
                logger.info("eval_accuracy = %f", eval_accuracy)
                logger.info("train_loss = %f", tr_loss / nb_tr_steps)
                logger.info("eval_loss = %f", eval_loss)

                # Start a fresh running average for the next report.
                tr_loss = 0
                nb_tr_steps = 0

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

    if args.do_infer:
        test_features = processor.get_test_features(args.data_dir)
        logger.info("***** Running inference *****")
        logger.info("  Num examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.infer_batch_size)
        infer_data = _build_dataset(test_features)
        # Run prediction for full data
        infer_dataloader = DataLoader(infer_data,
                                      sampler=SequentialSampler(infer_data),
                                      batch_size=args.infer_batch_size)

        model.eval()
        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                infer_dataloader, desc="Infer"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device).byte()
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                _, _, tag_seq = model(input_ids, segment_ids, input_mask,
                                      label_ids)

            for i in range(len(tag_seq)):
                length = input_mask[i].sum().tolist()
                print(_decode(input_ids[i].tolist(), vocab_dict, length))
                print(
                    _decode(tag_seq[i].tolist(), processor.label_dict_inv,
                            length) + "\n")

        logger.info("***** Infer finished *****")
예제 #3
0
def main():
    """Fine-tune and/or run inference with a BERT sequence classifier.

    All settings come from ``run_args()``; examples and labels come from
    ``ColaProcessor``. With ``do_train`` the model is fine-tuned, evaluated on
    the dev examples every ``eval_freq`` optimizer steps, and saved to
    ``output_dir``. With ``do_infer`` the arg-max class index of every
    inference batch is printed to stdout.

    Raises:
        ValueError: if ``gradient_accumulation_steps`` is below 1, if neither
            ``do_train`` nor ``do_infer`` is set, or if training is requested
            and ``output_dir`` already exists and is not empty.
        ImportError: if apex is required (distributed or fp16) but missing.
    """
    args = run_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        # Report no GPUs when CUDA is disabled so a CPU-resident model is
        # never wrapped in a data-parallel container further down.
        n_gpu = 0 if args.no_cuda else torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # train_batch_size becomes the per-micro-batch size from here on.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_infer:
        raise ValueError(
            "At least one of `do_train` or `do_infer` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    processor = ColaProcessor()
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = (
                num_train_optimization_steps //
                torch.distributed.get_world_size())

    # Prepare model
    if args.upper_model == "CNN":
        model = BertCnn.from_pretrained(args.bert_model,
                                        num_labels=num_labels,
                                        seq_len=args.max_seq_length)
    else:
        # "Linear" and any unrecognised value use the plain classifier head.
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, num_labels=num_labels)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        # model = torch.nn.DataParallel(model)
        model = DataParallelModel(model)

    # Prepare optimizer: biases and LayerNorm parameters get no weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0

    def _build_dataset(features):
        # Stack the per-example fields into one long tensor per field.
        return TensorDataset(
            torch.tensor([f.input_ids for f in features], dtype=torch.long),
            torch.tensor([f.input_mask for f in features], dtype=torch.long),
            torch.tensor([f.segment_ids for f in features], dtype=torch.long),
            torch.tensor([f.label_id for f in features], dtype=torch.long))

    def _evaluate(dataloader, criterion):
        # Score the model on `dataloader`; returns (mean per-batch loss,
        # accumulated accuracy() total divided by the number of examples).
        model.eval()
        loss_sum, accuracy_sum = 0, 0
        num_batches, num_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                eval_preds = model(input_ids, segment_ids, input_mask,
                                   label_ids)

            # Compute the loss. The in-place index loop is kept on purpose:
            # the model's return type is not visible from here.
            # NOTE(review): presumably a mutable list of per-device logits,
            # as produced by DataParallelModel - confirm for n_gpu <= 1.
            for i in range(len(eval_preds)):
                eval_preds[i] = eval_preds[i].view(-1, num_labels)
            loss = criterion(eval_preds, label_ids.view(-1))
            if n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            # The evaluation loss is reported as-is; scaling it by
            # gradient_accumulation_steps only makes sense for backprop.

            logits = torch.cat(eval_preds).detach().cpu().numpy()
            # NOTE(review): accuracy() presumably returns the number of
            # correct predictions in the batch - confirm.
            accuracy_sum += accuracy(logits, label_ids.to('cpu').numpy())
            loss_sum += loss.mean().item()

            num_examples += input_ids.size(0)
            num_batches += 1

        # Guard against an empty dev set.
        mean_loss = loss_sum / num_batches if num_batches else 0.0
        mean_accuracy = accuracy_sum / num_examples if num_examples else 0.0
        return mean_loss, mean_accuracy

    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        train_data = _build_dataset(train_features)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        # The dev set never changes, so build its loader once up front.
        eval_data = _build_dataset(eval_features)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=SequentialSampler(eval_data),
                                     batch_size=args.eval_batch_size)

        # The criterion does not change between steps; create it once.
        loss_fct = CrossEntropyLoss()
        loss_fct_parallel = DataParallelCriterion(loss_fct)

        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_steps = 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                model.train()  # evaluation below leaves the model in eval mode
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                predictions = model(input_ids, segment_ids, input_mask,
                                    label_ids)
                # In-place index loop kept on purpose; see _evaluate.
                for i in range(len(predictions)):
                    predictions[i] = predictions[i].view(-1, num_labels)
                loss = loss_fct_parallel(predictions, label_ids.view(-1))
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_steps += 1

                # Nothing more to do until an accumulation window completes.
                if (step + 1) % args.gradient_accumulation_steps != 0:
                    continue

                if args.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used that handles this automatically
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / num_train_optimization_steps,
                        args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                # Evaluate only right after an optimizer step; otherwise the
                # same global_step would trigger one evaluation per
                # micro-batch when gradients are accumulated.
                if global_step % args.eval_freq != 0:
                    continue

                logger.info("***** Running evaluation *****")
                logger.info("  Num examples = %d", len(eval_examples))
                logger.info("  Batch size = %d", args.eval_batch_size)
                eval_loss, eval_accuracy = _evaluate(eval_dataloader,
                                                     loss_fct_parallel)
                result = {
                    'eval_loss': eval_loss,
                    'eval_accuracy': eval_accuracy,
                    'global_step': global_step,
                    # Running mean of the training loss for this epoch.
                    'loss': tr_loss / nb_tr_steps
                }
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

    if args.do_infer:
        infer_examples = processor.get_infer_examples(args.data_dir)
        infer_features = convert_examples_to_features(infer_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running Inference *****")
        logger.info("  Num examples = %d", len(infer_examples))
        logger.info("  Batch size = %d", args.infer_batch_size)
        infer_data = _build_dataset(infer_features)
        # Run prediction for full data
        infer_dataloader = DataLoader(infer_data,
                                      sampler=SequentialSampler(infer_data),
                                      batch_size=args.infer_batch_size)

        model.eval()

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                infer_dataloader, desc="Inference"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                infer_preds = model(input_ids, segment_ids, input_mask,
                                    label_ids)

            # In-place index loop kept on purpose; see _evaluate.
            for i in range(len(infer_preds)):
                infer_preds[i] = infer_preds[i].view(-1, num_labels)

            infer_preds = torch.cat(
                infer_preds)  # shape: [batch_size, num_labels]
            logits = infer_preds.detach().cpu().numpy()
            outputs = np.argmax(logits, axis=1)
            print(outputs)
        logger.info("***** Infer finished *****")