Пример #1
0
 def test_tokenizer_no_lower(self):
     """Tokenize a cased sentence with ``do_lower_case=False`` and compare the pieces."""
     cased_tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
     pieces = cased_tokenizer.tokenize(u"I was born in 92000, and this is falsé.")

     # The capital "I" is kept; the expected tail is u'se' (no accented piece).
     expected_pieces = [
         SPIECE_UNDERLINE + u'I',
         SPIECE_UNDERLINE + u'was',
         SPIECE_UNDERLINE + u'b',
         u'or',
         u'n',
         SPIECE_UNDERLINE + u'in',
         SPIECE_UNDERLINE + u'',
         u'9',
         u'2',
         u'0',
         u'0',
         u'0',
         u',',
         SPIECE_UNDERLINE + u'and',
         SPIECE_UNDERLINE + u'this',
         SPIECE_UNDERLINE + u'is',
         SPIECE_UNDERLINE + u'f',
         u'al',
         u'se',
         u'.',
     ]
     self.assertListEqual(pieces, expected_pieces)
Пример #2
0
 def test_tokenizer_no_lower(self):
     """Check the pieces produced for a cased sentence when lower-casing is off."""
     tokenizer = XLNetTokenizer(SAMPLE_VOCAB, do_lower_case=False)
     tokens = tokenizer.tokenize("I was born in 92000, and this is falsé.")

     word_start = SPIECE_UNDERLINE
     expected = (
         [word_start + "I", word_start + "was", word_start + "b", "or", "n"]
         + [word_start + "in", word_start + ""]
         # The number is expected to be split into single-digit pieces.
         + list("92000")
         + [","]
         + [word_start + "and", word_start + "this", word_start + "is"]
         + [word_start + "f", "al", "se", "."]
     )
     self.assertListEqual(tokens, expected)
Пример #3
0
def xlnetTokenizer(*args, **kwargs):
    """
    Build an XLNet sentencepiece tokenizer from a pre-trained vocab file.

    All positional and keyword arguments are forwarded unchanged to
    ``XLNetTokenizer.from_pretrained``.

    Peculiarities:
        - requires Google sentencepiece (https://github.com/google/sentencepiece)

    Args:
    pretrained_model_name_or_path: Path to a pretrained model archive
                                   or one of the pre-trained vocab configs below.
                                       * xlnet-large-cased
    Keyword args:
    special_tokens: Special tokens in the vocabulary that are not pretrained
                    Default: None
    max_len: An artificial maximum length to truncate tokenized sequences to;
             the effective maximum length is always the minimum of this
             value (if specified) and the underlying model's
             sequence length.
             Default: None

    Returns:
        The tokenizer returned by ``XLNetTokenizer.from_pretrained``.

    Example:
        >>> import torch
        >>> tokenizer = torch.hub.load('huggingface/pytorch-transformers', 'xlnetTokenizer', 'xlnet-large-cased')

        >>> text = "Who was Jim Henson ?"
        >>> indexed_tokens = tokenizer.encode(text)
    """
    return XLNetTokenizer.from_pretrained(*args, **kwargs)
Пример #4
0
    def test_full_tokenizer(self):
        """Check tokenize / tokens->ids / ids->tokens on the sample vocab with accents kept."""
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

        # Short sentence: pieces and their ids.
        short_pieces = tokenizer.tokenize(u'This is a test')
        self.assertListEqual(short_pieces,
                             [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
        short_ids = tokenizer.convert_tokens_to_ids(short_pieces)
        self.assertListEqual(short_ids, [285, 46, 10, 170, 382])

        # Longer sentence containing digits and an accented character.
        expected_pieces = [
            SPIECE_UNDERLINE + u'I',
            SPIECE_UNDERLINE + u'was',
            SPIECE_UNDERLINE + u'b',
            u'or',
            u'n',
            SPIECE_UNDERLINE + u'in',
            SPIECE_UNDERLINE + u'',
            u'9',
            u'2',
            u'0',
            u'0',
            u'0',
            u',',
            SPIECE_UNDERLINE + u'and',
            SPIECE_UNDERLINE + u'this',
            SPIECE_UNDERLINE + u'is',
            SPIECE_UNDERLINE + u'f',
            u'al',
            u's',
            u'é',
            u'.',
        ]
        expected_ids = [
            8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46, 72,
            80, 6, 0, 4,
        ]
        # Positions 7 (u'9') and 19 (u'é') have expected id 0 and are
        # expected to come back as u'<unk>'.
        expected_roundtrip = list(expected_pieces)
        for unknown_pos in (7, 19):
            expected_roundtrip[unknown_pos] = u'<unk>'

        long_pieces = tokenizer.tokenize(
            u"I was born in 92000, and this is falsé.")
        self.assertListEqual(long_pieces, expected_pieces)

        long_ids = tokenizer.convert_tokens_to_ids(long_pieces)
        self.assertListEqual(long_ids, expected_ids)

        roundtrip = tokenizer.convert_ids_to_tokens(long_ids)
        self.assertListEqual(roundtrip, expected_roundtrip)
Пример #5
0
    def test_sequence_builders(self):
        """Check the ids appended when special tokens are added to one or two sequences."""
        tokenizer = XLNetTokenizer.from_pretrained("xlnet-base-cased")

        first_ids = tokenizer.encode("sequence builders")
        second_ids = tokenizer.encode("multi-sequence build")

        single = tokenizer.add_special_tokens_single_sentence(first_ids)
        pair = tokenizer.add_special_tokens_sentences_pair(first_ids,
                                                           second_ids)

        # Expected layout: ids 4, 3 close the input; a single 4 separates
        # the two sequences of a pair.
        closing_ids = [4, 3]
        between_ids = [4]
        assert single == first_ids + closing_ids
        assert pair == first_ids + between_ids + second_ids + closing_ids
Пример #6
0
def init_params():
    """Build the task processor and tokenizer selected by the module-level ``args``.

    Returns:
        tuple: ``(processor, tokenizer)`` where ``processor`` is an instance of
        the class registered for ``args.task_name`` and ``tokenizer`` matches
        ``args.model_type``.

    Raises:
        ValueError: if ``args.task_name`` is not a registered task, or if
            ``args.model_type`` is neither ``'bert'`` nor ``'xlnet'``.
    """
    processors = {"sentiment_analysis": SentiAnalysisProcessor}
    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))
    processor = processors[task_name]()

    if args.model_type == 'bert':
        tokenizer = BertTokenizer(vocab_file=args.VOCAB_FILE)
    elif args.model_type == 'xlnet':
        tokenizer = XLNetTokenizer.from_pretrained(
            os.path.join(args.ROOT_DIR, args.xlnet_model),
            do_lower_case=args.do_lower_case)
    else:
        # Without this branch ``tokenizer`` would be unbound and the return
        # below would fail with an opaque UnboundLocalError.
        raise ValueError("Model type not supported: %s" % (args.model_type))
    return processor, tokenizer
Пример #7
0
    def test_full_tokenizer(self):
        """Save the fixture tokenizer, run the common checks, then verify pieces and ids."""
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

        simple_text = u"This is a test"
        accented_text = u"I was born in 92000, and this is falsé."

        expected_pieces = [
            SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was',
            SPIECE_UNDERLINE + u'b', u'or', u'n',
            SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
            u'9', u'2', u'0', u'0', u'0', u',',
            SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
            SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f',
            u'al', u's', u'é', u'.',
        ]
        expected_ids = [
            8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66, 46,
            72, 80, 6, 0, 4,
        ]
        # Same as expected_pieces, with u'<unk>' where the expected id is 0.
        expected_back = [
            SPIECE_UNDERLINE + u'I', SPIECE_UNDERLINE + u'was',
            SPIECE_UNDERLINE + u'b', u'or', u'n',
            SPIECE_UNDERLINE + u'in', SPIECE_UNDERLINE + u'',
            u'<unk>', u'2', u'0', u'0', u'0', u',',
            SPIECE_UNDERLINE + u'and', SPIECE_UNDERLINE + u'this',
            SPIECE_UNDERLINE + u'is', SPIECE_UNDERLINE + u'f',
            u'al', u's', u'<unk>', u'.',
        ]

        with TemporaryDirectory() as tmpdirname:
            tokenizer.save_pretrained(tmpdirname)

            # Input and expected output text are the same sentence.
            create_and_check_tokenizer_commons(self, simple_text, simple_text,
                                               XLNetTokenizer, tmpdirname)

            simple_pieces = tokenizer.tokenize(u'This is a test')
            self.assertListEqual(
                simple_pieces, [u'▁This', u'▁is', u'▁a', u'▁t', u'est'])
            self.assertListEqual(
                tokenizer.convert_tokens_to_ids(simple_pieces),
                [285, 46, 10, 170, 382])

            pieces = tokenizer.tokenize(accented_text)
            self.assertListEqual(pieces, expected_pieces)

            piece_ids = tokenizer.convert_tokens_to_ids(pieces)
            self.assertListEqual(piece_ids, expected_ids)

            self.assertListEqual(tokenizer.convert_ids_to_tokens(piece_ids),
                                 expected_back)
Пример #8
0
 def get_tokenizer(self):
     """Load an ``XLNetTokenizer`` from the directory stored in ``self.tmpdirname``."""
     saved_dir = self.tmpdirname
     return XLNetTokenizer.from_pretrained(saved_dir)
Пример #9
0
    def setUp(self):
        """Run the parent setup, then save the fixture tokenizer into ``self.tmpdirname``."""
        super(XLNetTokenizationTest, self).setUp()

        # Build the tokenizer from the SentencePiece test fixture and persist
        # it in one step.
        XLNetTokenizer(SAMPLE_VOCAB,
                       keep_accents=True).save_pretrained(self.tmpdirname)
Пример #10
0
 def get_tokenizer(self, **kwargs):
     """Load an ``XLNetTokenizer`` from ``self.tmpdirname``, forwarding any keyword overrides."""
     saved_dir = self.tmpdirname
     return XLNetTokenizer.from_pretrained(saved_dir, **kwargs)
Пример #11
0
    def test_full_tokenizer(self):
        """Save the fixture tokenizer, run the common checks, then verify pieces and ids."""
        tokenizer = XLNetTokenizer(SAMPLE_VOCAB, keep_accents=True)

        mark = SPIECE_UNDERLINE
        expected_pieces = [
            mark + "I", mark + "was", mark + "b", "or", "n",
            mark + "in", mark + "",
            "9", "2", "0", "0", "0", ",",
            mark + "and", mark + "this", mark + "is",
            mark + "f", "al", "s", "é", ".",
        ]
        expected_ids = [
            8, 21, 84, 55, 24, 19, 7, 0, 602, 347, 347, 347, 3, 12, 66,
            46, 72, 80, 6, 0, 4,
        ]
        # "9" and "é" carry expected id 0 and are expected to decode to "<unk>".
        expected_back = [
            "<unk>" if piece in ("9", "é") else piece
            for piece in expected_pieces
        ]

        with TemporaryDirectory() as tmpdirname:
            tokenizer.save_pretrained(tmpdirname)

            sample = "This is a test"
            create_and_check_tokenizer_commons(self, sample, sample,
                                               XLNetTokenizer, tmpdirname)

            sample_pieces = tokenizer.tokenize("This is a test")
            self.assertListEqual(sample_pieces,
                                 ["▁This", "▁is", "▁a", "▁t", "est"])
            self.assertListEqual(
                tokenizer.convert_tokens_to_ids(sample_pieces),
                [285, 46, 10, 170, 382],
            )

            pieces = tokenizer.tokenize(
                "I was born in 92000, and this is falsé.")
            self.assertListEqual(pieces, expected_pieces)

            piece_ids = tokenizer.convert_tokens_to_ids(pieces)
            self.assertListEqual(piece_ids, expected_ids)

            decoded = tokenizer.convert_ids_to_tokens(piece_ids)
            self.assertListEqual(decoded, expected_back)
Пример #12
0
def preprocess():
    '''
    Predict labels for the test set and write them to a CSV file.

    All configuration is read from the module-level ``args``. Test features
    are loaded from the pickled cache file when that works and are rebuilt
    from the raw examples otherwise. The predictions are stored in a
    ``labels`` column and written, together with ``id``, to
    ``./data/test_final.csv``.

    :return: None
    '''
    model = load_model(os.path.join(args.ROOT_DIR, args.output_dir),
                       args.model_type)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        n_gpu = 1

    model.to(device)

    processors = {"sentiment_analysis": SentiAnalysisProcessor}
    task_name = args.task_name.lower()
    processor = processors[task_name]()
    examples = processor.get_test_examples(args.data_dir)

    tokenizer = XLNetTokenizer.from_pretrained(
        os.path.join(args.ROOT_DIR, args.xlnet_model),
        do_lower_case=args.do_lower_case)
    mode = 'test'
    try:
        # NOTE(security): pickle.load can execute arbitrary code; only point
        # TEST_FEATURE_FILE at cache files this project produced itself.
        with open(os.path.join(args.data_dir, args.TEST_FEATURE_FILE),
                  'rb') as f:
            features = pickle.load(f)
    except Exception:
        # Best-effort cache: on any load failure rebuild the features.
        # ``Exception`` (not a bare ``except``) lets KeyboardInterrupt and
        # SystemExit propagate instead of silently triggering a rebuild.
        features = convert_examples_to_features(
            examples,
            args.max_seq_length,
            args.split_num,
            tokenizer,
            mode=mode,
            cls_token_at_end=bool(args.model_type in ['xlnet']),
            # xlnet has a cls token at the end
            cls_token=tokenizer.cls_token,
            sep_token=tokenizer.sep_token,
            cls_token_segment_id=2 if args.model_type in ['xlnet'] else 0,
            pad_on_left=bool(
                args.model_type in ['xlnet']),  # pad on the left for xlnet
            pad_token_segment_id=4 if args.model_type in ['xlnet'] else 0)

    all_input_ids = torch.tensor(select_field(features, 'input_ids'),
                                 dtype=torch.long)
    all_input_mask = torch.tensor(select_field(features, 'input_mask'),
                                  dtype=torch.long)
    all_segment_ids = torch.tensor(select_field(features, 'segment_ids'),
                                   dtype=torch.long)

    # Dataset
    dataset = TensorDataset(all_input_ids, all_input_mask, all_segment_ids)
    # Sequential (non-shuffled) order: predictions are assigned to the CSV
    # rows positionally below.
    sampler = SequentialSampler(dataset)
    dataloader = DataLoader(dataset,
                            sampler=sampler,
                            batch_size=args.per_gpu_train_batch_size *
                            max(1, n_gpu))
    model.eval()
    y_predicts = []
    for input_ids, input_mask, segment_ids in dataloader:
        input_ids = input_ids.to(device)
        input_mask = input_mask.to(device)
        segment_ids = segment_ids.to(device)

        with torch.no_grad():
            logits = model(input_ids=input_ids,
                           token_type_ids=segment_ids,
                           attention_mask=input_mask)[0]
            predicts = model.predict(logits)
        y_predicts.append(torch.from_numpy(predicts))
    eval_predicted = torch.cat(y_predicts, dim=0).cpu().numpy()

    df = pd.read_csv(os.path.join(args.data_dir, args.TEST_CORPUS_FILE))
    df['labels'] = eval_predicted
    df[['id', 'labels']].to_csv('./data/test_final.csv',
                                sep=',',
                                encoding='utf_8_sig',
                                header=True,
                                index=False)
Пример #13
0
def convert_split_doc_to_feature(doc,
                                 max_seq_length,
                                 split_num,
                                 cls_token_at_end=False,
                                 pad_on_left=False,
                                 cls_token='[CLS]',
                                 sep_token='[SEP]',
                                 pad_token=0,
                                 sequence_a_segment_id=0,
                                 sequence_b_segment_id=1,
                                 cls_token_segment_id=1,
                                 pad_token_segment_id=0,
                                 mask_padding_with_zero=True):
    '''
    Convert one test document into a list of padded model-input features.

    The document is tokenized with the XLNet tokenizer configured through the
    module-level ``args``, cut into ``split_num`` consecutive chunks, and each
    chunk is wrapped with the separator and classification tokens and padded
    to ``max_seq_length``.

    :param doc: text of the document to predict on
    :param max_seq_length: length of every produced sequence
    :param split_num: number of chunks the document is cut into
    :param cls_token_at_end: append the cls token instead of prepending it
    :param pad_on_left: put the padding before the tokens instead of after
    :param cls_token: classification token string
    :param sep_token: separator token string
    :param pad_token: id used to pad ``input_ids``
    :param sequence_a_segment_id: segment id of the chunk tokens and separator
    :param sequence_b_segment_id: not used, there is never a second sequence
    :param cls_token_segment_id: segment id of the cls token
    :param pad_token_segment_id: segment id used for padding positions
    :param mask_padding_with_zero: if True real tokens get mask 1 and padding
        gets 0; otherwise the two values are swapped
    :return: list of ``split_num`` dicts with keys ``input_ids``,
        ``input_mask`` and ``segment_ids``
    '''
    tokenizer = XLNetTokenizer.from_pretrained(
        os.path.join(args.ROOT_DIR, args.xlnet_model),
        do_lower_case=args.do_lower_case)
    doc_tokens = tokenizer.tokenize(doc)

    # Fractional chunk width; the boundaries are truncated to int per chunk.
    chunk_width = len(doc_tokens) / split_num

    if mask_padding_with_zero:
        real_mask, pad_mask = 1, 0
    else:
        real_mask, pad_mask = 0, 1

    # Leave room for the cls and sep tokens.
    token_budget = max_seq_length - 2

    features = []
    for chunk_idx in range(split_num):
        start = int(chunk_idx * chunk_width)
        stop = int((chunk_idx + 1) * chunk_width)
        chunk = doc_tokens[start:stop]
        if len(chunk) > token_budget:
            chunk = chunk[:token_budget]

        tokens = chunk + [sep_token]
        segment_ids = [sequence_a_segment_id] * len(tokens)

        if cls_token_at_end:
            tokens.append(cls_token)
            segment_ids.append(cls_token_segment_id)
        else:
            tokens.insert(0, cls_token)
            segment_ids.insert(0, cls_token_segment_id)

        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        input_mask = [real_mask] * len(input_ids)

        # Pad ids, mask and segment ids up to the sequence length.
        pad_count = max_seq_length - len(input_ids)
        id_padding = [pad_token] * pad_count
        mask_padding = [pad_mask] * pad_count
        segment_padding = [pad_token_segment_id] * pad_count
        if pad_on_left:
            input_ids = id_padding + input_ids
            input_mask = mask_padding + input_mask
            segment_ids = segment_padding + segment_ids
        else:
            input_ids = input_ids + id_padding
            input_mask = input_mask + mask_padding
            segment_ids = segment_ids + segment_padding

        assert len(input_ids) == max_seq_length
        assert len(input_mask) == max_seq_length
        assert len(segment_ids) == max_seq_length

        features.append({
            'input_ids': input_ids,
            'input_mask': input_mask,
            'segment_ids': segment_ids
        })

    return features
Пример #14
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir", default=None, type=str, required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--model_type", default=None, type=str, required=True,
                        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument("--model_name_or_path", default=None, type=str, required=True,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--meta_path", default=None, type=str, required=False,
                        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(ALL_MODELS))
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--config_name", default="", type=str,
                        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument("--tokenizer_name", default="", type=str,
                        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument("--cache_dir", default="", type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length", default=128, type=int,
                        help="The maximum total input sequence length after tokenization. Sequences longer "
                             "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_test", action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval", action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--evaluate_during_training", action='store_true',
                        help="Rul evaluation during training at each logging step.")
    parser.add_argument("--do_lower_case", action='store_true',
                        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--split_num", default=3, type=int,
                        help="text split")
    parser.add_argument("--per_gpu_eval_batch_size", default=8, type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument('--gradient_accumulation_steps', type=int, default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--learning_rate", default=5e-5, type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay", default=0.0, type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon", default=1e-8, type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm", default=1.0, type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--max_steps", default=-1, type=int,
                        help="If > 0: set total number of training steps to perform. Override num_train_epochs.")
    parser.add_argument("--eval_steps", default=-1, type=int,
                        help="")
    parser.add_argument("--train_steps", default=-1, type=int,
                        help="")
    parser.add_argument("--report_steps", default=-1, type=int,
                        help="")
    parser.add_argument("--warmup_steps", default=0, type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--lstm_hidden_size", default=300, type=int,
                        help="")
    parser.add_argument("--lstm_layers", default=2, type=int,
                        help="")
    parser.add_argument("--lstm_dropout", default=0.5, type=float,
                        help="")    
    parser.add_argument('--logging_steps', type=int, default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps', type=int, default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument("--eval_all_checkpoints", action='store_true',
                        help="Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number")
    parser.add_argument("--no_cuda", action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir', action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument('--overwrite_cache', action='store_true',
                        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed', type=int, default=42,
                        help="random seed for initialization")

    parser.add_argument('--fp16', action='store_true',
                        help="Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit")
    parser.add_argument('--fp16_opt_level', type=str, default='O1',
                        help="For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
                             "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank", type=int, default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip', type=str, default='', help="For distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="For distant debugging.")
    args = parser.parse_args()
    
    

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device
    
    
    # Setup logging
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning("Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
                    args.local_rank, device, args.n_gpu, bool(args.local_rank != -1), args.fp16)
    
    # Set seed
    set_seed(args)


    try:
        os.makedirs(args.output_dir)
    except:
        pass
    
    tokenizer = XLNetTokenizer.from_pretrained(args.model_name_or_path, do_lower_case=args.do_lower_case)
    
    
    
    config = XLNetConfig.from_pretrained(args.model_name_or_path, num_labels=3)
    
    # Prepare model
    model = XLNetForSequenceClassification.from_pretrained(args.model_name_or_path,args,config=config)


        
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    if args.do_train:

        # Prepare data loader

        train_examples = read_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
        train_features = convert_examples_to_features(
            train_examples, tokenizer, args.max_seq_length,args.split_num, True)
        all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
        all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
        all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size//args.gradient_accumulation_steps)

        num_train_optimization_steps =  args.train_steps


        # Prepare optimizer

        param_optimizer = list(model.named_parameters())

        # hack to remove pooler, which is not used
        # thus it produce None grad that break apex
        param_optimizer = [n for n in param_optimizer]

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]

        optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=5000)
        
        global_step = 0

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        
        best_acc=0
        model.train()
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0        
        bar = tqdm(range(num_train_optimization_steps),total=num_train_optimization_steps)
        train_dataloader=cycle(train_dataloader)

        
        for step in bar:
            batch = next(train_dataloader)
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            loss = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
            if args.n_gpu > 1:
                loss = loss.mean() # mean() to average on multi-gpu.
            if args.fp16 and args.loss_scale != 1.0:
                loss = loss * args.loss_scale
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            tr_loss += loss.item()
            train_loss=round(tr_loss*args.gradient_accumulation_steps/(nb_tr_steps+1),4)
            bar.set_description("loss {}".format(train_loss))
            nb_tr_examples += input_ids.size(0)
            nb_tr_steps += 1

            if args.fp16:
                optimizer.backward(loss)
            else:

                loss.backward()

            if (nb_tr_steps + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used that handles this automatically
                    lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                scheduler.step()
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1


            if (step + 1) %(args.eval_steps*args.gradient_accumulation_steps)==0:
                tr_loss = 0
                nb_tr_examples, nb_tr_steps = 0, 0 
                logger.info("***** Report result *****")
                logger.info("  %s = %s", 'global_step', str(global_step))
                logger.info("  %s = %s", 'train loss', str(train_loss))


            if args.do_eval and (step + 1) %(args.eval_steps*args.gradient_accumulation_steps)==0:
                for file in ['dev.csv']:
                    inference_labels=[]
                    gold_labels=[]
                    inference_logits=[]
                    eval_examples = read_examples(os.path.join(args.data_dir, file), is_training = True)
                    eval_features = convert_examples_to_features(eval_examples, tokenizer, args.max_seq_length,args.split_num,False)
                    all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
                    all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
                    all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)
                    all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long)                   


                    eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label)
                        
                    logger.info("***** Running evaluation *****")
                    logger.info("  Num examples = %d", len(eval_examples))
                    logger.info("  Batch size = %d", args.eval_batch_size)  
                        
                    # Run prediction for full data
                    eval_sampler = SequentialSampler(eval_data)
                    eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

                    model.eval()
                    eval_loss, eval_accuracy = 0, 0
                    nb_eval_steps, nb_eval_examples = 0, 0
                    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        label_ids = label_ids.to(device)


                        with torch.no_grad():
                            tmp_eval_loss= model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask, labels=label_ids)
                            logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask)

                        logits = logits.detach().cpu().numpy()
                        label_ids = label_ids.to('cpu').numpy()
                        inference_labels.append(np.argmax(logits, axis=1))
                        gold_labels.append(label_ids)
                        inference_logits.append(logits)
                        eval_loss += tmp_eval_loss.mean().item()
                        nb_eval_examples += input_ids.size(0)
                        nb_eval_steps += 1
                        
                    gold_labels=np.concatenate(gold_labels,0) 
                    inference_logits=np.concatenate(inference_logits,0)
                    model.train()
                    eval_loss = eval_loss / nb_eval_steps
                    eval_accuracy = accuracy(inference_logits, gold_labels)

                    result = {'eval_loss': eval_loss,
                              'eval_F1': eval_accuracy,
                              'global_step': global_step,
                              'loss': train_loss}

                    output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
                    with open(output_eval_file, "a") as writer:
                        for key in sorted(result.keys()):
                            logger.info("  %s = %s", key, str(result[key]))
                            writer.write("%s = %s\n" % (key, str(result[key])))
                        writer.write('*'*80)
                        writer.write('\n')
                    if eval_accuracy>best_acc and 'dev' in file:
                        print("="*80)
                        print("Best F1",eval_accuracy)
                        print("Saving Model......")
                        best_acc=eval_accuracy
                        # Save a trained model
                        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
                        torch.save(model_to_save.state_dict(), output_model_file)
                        print("="*80)
                    else:
                        print("="*80)
    if args.do_test:
        del model
        gc.collect()
        args.do_train=False
        model = XLNetForSequenceClassification.from_pretrained(os.path.join(args.output_dir, "pytorch_model.bin"),args,config=config)
        if args.fp16:
            model.half()
        model.to(device)
        if args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

            model = DDP(model)
        elif args.n_gpu > 1:
            model = torch.nn.DataParallel(model)        
        
        
        for file,flag in [('dev.csv','dev'),('test.csv','test')]:
            inference_labels=[]
            gold_labels=[]
            eval_examples = read_examples(os.path.join(args.data_dir, file), is_training = False)
            eval_features = convert_examples_to_features(eval_examples, tokenizer, args.max_seq_length,args.split_num,False)
            all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
            all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
            all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)
            all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long)                           


            eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,all_label)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask).detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                inference_labels.append(logits)
                gold_labels.append(label_ids)
            gold_labels=np.concatenate(gold_labels,0)
            logits=np.concatenate(inference_labels,0)
            print(flag, accuracy(logits, gold_labels))
            if flag=='test':
                df=pd.read_csv(os.path.join(args.data_dir, file))
                df['label_0']=logits[:,0]
                df['label_1']=logits[:,1]
                df['label_2']=logits[:,2]
                df[['id','label_0','label_1','label_2']].to_csv(os.path.join(args.output_dir, "sub.csv"),index=False)