Example #1
    def load_model(epoch):

        nonlocal model

        if output_dir is None:
            return False

        if model_file is None:
            output_model_file = os.path.join(
                output_dir, "pytorch_model_ep{}.bin".format(epoch))
        else:
            output_model_file = os.path.join(output_dir, model_file)

        if not os.path.exists(output_model_file):
            logger.info(
                "Stopping at epoch {} since model file is missing ({}).".
                format(epoch, output_model_file))
            return False

        logger.info("Loading epoch {} from disk...".format(epoch))
        model = BertForSequenceClassification(
            config, num_labels=processor.num_labels())

        # noinspection PyUnresolvedReferences
        model.load_state_dict(
            torch.load(output_model_file,
                       map_location=lambda storage, loc: storage
                       if no_cuda else None))

        # noinspection PyUnresolvedReferences
        model.to(device)

        return True
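
Note on the pattern above: `load_model` rebinds the enclosing `model` variable with `nonlocal` instead of returning the model object. A minimal, self-contained sketch of that rebinding pattern (all names here are illustrative, not part of the example):

def make_counter():
    count = 0

    def bump():
        nonlocal count  # rebind the enclosing variable instead of creating a new local
        count += 1
        return count

    return bump

bump = make_counter()
assert bump() == 1 and bump() == 2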
Example #2
    def initialize(self, context):
        """
        Initialize model. This will be called during model loading time
        :param context: Initial context contains model server system properties.
        :return:
        """
        # ModelHandler.LOGGER.critical("initializing model: %s - %s", context, type(context))
        try:
            MODEL_DIR = "/models/intents"

            # device
            device_name = "cuda" if torch.cuda.is_available() else "cpu"
            device = torch.device(device_name)
            # label encoder
            labelencoder = preprocessing.LabelEncoder()
            labelencoder.classes_ = np.load(
                os.path.join(MODEL_DIR, 'classes.npy'))

            # model config
            config = BertConfig(os.path.join(MODEL_DIR, 'bert_config.json'))

            # model
            model = BertForSequenceClassification(config,
                                                  num_labels=len(
                                                      labelencoder.classes_))
            model.load_state_dict(
                torch.load(os.path.join(MODEL_DIR, 'pytorch_model.bin'),
                           map_location="cpu"))
            model.to(device)
            model.eval()

            self.labelencoder = labelencoder
            self.model = model
            self.device = device
            self.model_batch_size = 32
            self.softmax = torch.nn.Softmax(dim=-1)
            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
            self.batch_size = context.system_properties["batch_size"]
            ModelHandler.STREAM_LOGGER.critical("initialized: %s - %s - %s",
                                                context, device_name,
                                                context.system_properties)
        except Exception as e:
            ModelHandler.STREAM_LOGGER.critical(
                "exception in initialization: %s", str(e))
Example #3
    def init_model(self, load: bool = False) -> BertForSequenceClassification:
        """
        Initialize BertForSequenceClassification model
        :param load: If true, load custom pre-trained model
        :return:
        """
        if load:
            logger.info("\n==> 🚀 Loading model:")
            output_model_file = os.path.join(self.output_dir, WEIGHTS_NAME)
            output_config_file = os.path.join(self.output_dir, CONFIG_NAME)
            config = BertConfig(output_config_file)
            model = BertForSequenceClassification(config,
                                                  num_labels=self.num_labels)
            model.load_state_dict(torch.load(output_model_file))

        else:
            logger.info("\n==> 🚀 Initializing model:")
            model = BertForSequenceClassification.from_pretrained(
                self.bert_model,
                cache_dir=PYTORCH_PRETRAINED_BERT_CACHE,
                num_labels=self.num_labels)

        model.to(self.device)
        return model
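
The `load=True` branch assumes that `WEIGHTS_NAME` and `CONFIG_NAME` files already exist in `self.output_dir`. The save side that produces them (the same pattern appears in Example #7 below) is roughly the following; `WEIGHTS_NAME` and `CONFIG_NAME` are assumed to be imported from pytorch_pretrained_bert as in the examples:

import os
import torch

def save_model(model, output_dir):
    # Unwrap DataParallel/DDP if needed, then persist weights and config.
    model_to_save = model.module if hasattr(model, "module") else model
    torch.save(model_to_save.state_dict(), os.path.join(output_dir, WEIGHTS_NAME))
    with open(os.path.join(output_dir, CONFIG_NAME), "w") as f:
        f.write(model_to_save.config.to_json_string())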
Example #4
def create_model(args, dataset, train=True):
    print("[*] Create model.")

    global model
    if train:
        model = BertForSequenceClassification.from_pretrained(BERT,
                                                              num_labels=5)
    else:
        if BERT == 'bert-large-uncased':
            config = BertConfig.from_json_file("uncase_model")
        else:
            config = BertConfig.from_json_file("case_model")
        model = BertForSequenceClassification(config, num_labels=5)
    # for i in model.bert.named_parameters():
    #     i[1].requires_grad=False

    model = model.to(device)
    # print(model)

    param_optimizer = list(model.named_parameters())

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if train:
        num_train_optimization_steps = int(
            len(dataset["train"]) / args.batch_size /
            args.gradient_accumulation_steps) * args.epochs
        global optimizer
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.lr_rate,
                             warmup=0.1,
                             t_total=num_train_optimization_steps)

    # optimizer = optim.Adam(model.parameters(),
    #                        lr=args.lr_rate) # , betas=(0.9, 0.999), weight_decay=1e-3)
    return
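
A quick worked example for the step count computed above, with hypothetical numbers: 10,000 training examples, batch_size 32, gradient_accumulation_steps 2, and 3 epochs give

num_train_optimization_steps = int(10000 / 32 / 2) * 3  # int(156.25) * 3 == 468

i.e. 468 optimizer updates, which is the `t_total` BertAdam uses to schedule warmup and decay.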
Example #5
class BertPredict(object):
    def __init__(self, args):
        self.args = args

        if self.args.local_rank == -1 or self.args.no_cuda:
            self.device = torch.device("cuda" if torch.cuda.is_available()
                                       and not self.args.no_cuda else "cpu")
            n_gpu = torch.cuda.device_count()
        else:
            torch.cuda.set_device(self.args.local_rank)
            self.device = torch.device("cuda", self.args.local_rank)
            n_gpu = 1
            # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
            torch.distributed.init_process_group(backend='nccl')
        logger.info("device: {} n_gpu: {}, distributed training: {}".format(
            self.device, n_gpu, bool(self.args.local_rank != -1)))

        random.seed(self.args.seed)
        np.random.seed(self.args.seed)
        torch.manual_seed(self.args.seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(self.args.seed)

        processor = OffenseEvalData()
        self.label_list = processor.get_labels()
        self.num_labels = len(self.label_list)

        self.tokenizer = BertTokenizer.from_pretrained(self.args.bert_model,
                                                       do_lower_case=True)

        # Load a trained model and config that you have fine-tuned
        output_model_file = os.path.join(self.args.bert_model_dir,
                                         WEIGHTS_NAME)
        output_config_file = os.path.join(self.args.bert_model_dir,
                                          CONFIG_NAME)
        config = BertConfig(output_config_file)
        self.model = BertForSequenceClassification(config,
                                                   num_labels=self.num_labels)
        self.model.load_state_dict(torch.load(output_model_file))
        self.model.to(self.device)
        self.model.eval()
        self.label_map = {i: label for i, label in enumerate(self.label_list)}

    def predict_one(self, test_input):

        eval_examples = [test_input]
        eval_features = convert_examples_to_features(eval_examples,
                                                     self.label_list,
                                                     self.args.max_seq_length,
                                                     self.tokenizer)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)

        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=1)

        preds = []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(self.device)
            input_mask = input_mask.to(self.device)
            segment_ids = segment_ids.to(self.device)

            with torch.no_grad():
                logits = self.model(input_ids,
                                    segment_ids,
                                    input_mask,
                                    labels=None)

            if len(preds) == 0:
                preds.append(logits.detach().cpu().numpy())
            else:
                preds[0] = np.append(preds[0],
                                     logits.detach().cpu().numpy(),
                                     axis=0)

        preds = preds[0]
        preds = np.argmax(preds, axis=1)
        return self.label_map[preds[0]]
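
A typical call site for the class above; everything here (the argument values, OffenseEvalData, and the directory layout) is project-specific and shown only as a hedged usage sketch:

from argparse import Namespace

args = Namespace(local_rank=-1, no_cuda=True, seed=42,
                 bert_model="bert-base-uncased",
                 bert_model_dir="./fine_tuned_model",  # assumed to contain WEIGHTS_NAME and CONFIG_NAME
                 max_seq_length=128)
predictor = BertPredict(args)
test_input = OffenseEvalData().get_test_examples("./data")[0]  # project-specific processor
print(predictor.predict_one(test_input))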
Example #6
def main():
    parser = make_arg_parser()
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "cli": CLIProcessor,
    }

    num_labels_task = {
        "cli": 7,
    }

    # Check whether bert_model_or_config_file is a file or directory
    if os.path.isdir(args.bert_model_or_config_file):
        pretrained = True
        targets = [WEIGHTS_NAME, CONFIG_NAME, "tokenizer.pkl"]
        for t in targets:
            path = os.path.join(args.bert_model_or_config_file, t)
            if not os.path.exists(path):
                msg = "File '{}' not found".format(path)
                raise ValueError(msg)
        fp = os.path.join(args.bert_model_or_config_file, CONFIG_NAME)
        config = BertConfig(fp)
    else:
        pretrained = False
        config = BertConfig(args.bert_model_or_config_file)

    # What GPUs do we use?
    if args.num_gpus == -1:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        n_gpu = torch.cuda.device_count()
        device_ids = None
    else:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and args.num_gpus > 0 else "cpu")
        n_gpu = args.num_gpus
        if n_gpu > 1:
            device_ids = list(range(n_gpu))
    if args.local_rank != -1:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    # Check some other args
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
    if not args.do_train and not args.do_eval and not args.do_predict:
        raise ValueError(
            "At least one of `do_train`, `do_eval` or `do_predict` must be True."
        )

    # Seed RNGs
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Prepare output directory
    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    # Get training data
    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Make tokenizer
    if pretrained:
        fp = os.path.join(args.bert_model_or_config_file, "tokenizer.pkl")
        with open(fp, "rb") as f:
            tokenizer = pickle.load(f)
    else:
        tokenizer = CuneiformCharTokenizer(
            training_data=[x.text_a for x in train_examples])
        tokenizer.trim_vocab(config.min_freq)
        # Adapt vocab size in config
        config.vocab_size = len(tokenizer.vocab)
    print("Size of vocab: {}".format(len(tokenizer.vocab)))

    # Prepare model
    if pretrained:
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model_or_config_file, num_labels=num_labels)
    else:
        model = BertForSequenceClassification(config, num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=device_ids)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    # Get dev data
    if args.do_eval:
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

    # Prepare for training
    global_step = 0
    nb_tr_steps = 0
    total_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        # Prepare log file
        output_log_file = os.path.join(args.output_dir, "training_log.txt")
        with open(output_log_file, "w") as f:
            if args.do_eval:
                f.write("Steps\tTrainLoss\tValLoss\tValAccuracy\tValFScore\n")
            else:
                f.write("Steps\tTrainLoss\n")

        best_val_score = float("-inf")
        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
            avg_loss = tr_loss / nb_tr_examples
            total_tr_steps += nb_tr_steps
            log_data = [str(total_tr_steps), "{:.5f}".format(avg_loss)]

            # Validate
            if args.do_eval and (args.local_rank == -1
                                 or torch.distributed.get_rank() == 0):
                predictions, eval_loss, eval_accuracy, fscore = evaluate(
                    model, eval_dataloader, device)
                log_data.append("{:.5f}".format(eval_loss))
                log_data.append("{:.5f}".format(eval_accuracy))
                log_data.append("{:.5f}".format(fscore))
                # Check if score has improved
                if fscore > best_val_score:
                    best_val_score = fscore
                    save_model(model, tokenizer, args.output_dir)
            else:
                # If we can't validate, we save model at each epoch
                save_model(model, tokenizer, args.output_dir)

            # Log
            with open(output_log_file, "a") as f:
                f.write("\t".join(log_data) + "\n")

    # Load model
    if args.do_train:
        # Load model we just fine-tuned
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        output_tokenizer_file = os.path.join(args.output_dir, "tokenizer.pkl")
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
        with open(output_tokenizer_file, "rb") as f:
            tokenizer = pickle.load(f)
    else:
        # Load a model you fine-tuned previously
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model_or_config_file, num_labels=num_labels)
    model.to(device)

    # Evaluate model on validation data
    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        predictions, eval_loss, eval_accuracy, fscore = evaluate(
            model, eval_dataloader, device)
        loss = avg_loss if args.do_train else None
        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'eval_fscore': fscore,
            'global_step': global_step,
            'loss': loss
        }

        # Write evaluation results
        output_eval_file = os.path.join(args.output_dir, "dev_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

        # Write predictions
        output_pred_file = os.path.join(args.output_dir, "dev_pred.txt")
        with open(output_pred_file, "w", encoding="utf-8") as writer:
            for label_id in predictions:
                label = label_list[label_id]
                writer.write(label + "\n")

    # Predict labels of test set
    if args.do_predict:
        test_examples = processor.get_test_examples(args.data_dir)
        test_features = convert_examples_to_features(test_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in test_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in test_features],
                                       dtype=torch.long)
        test_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)

        logger.info("***** Running prediction *****")
        logger.info("  Num examples = %d", len(test_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        predictions = predict(model, test_dataloader, device)

        # Write predictions
        output_pred_file = os.path.join(args.output_dir, "test_pred.txt")
        with open(output_pred_file, "w", encoding="utf-8") as writer:
            for label_id in predictions:
                label = label_list[label_id]
                writer.write(label + "\n")
Example #7
def bert():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=5,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}".format(
        device, n_gpu, bool(args.local_rank != -1)))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval and not args.do_test:
        raise ValueError(
            "At least one of `do_train` or `do_eval` or `do_test` must be True."
        )

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    processor = OffenseEvalData()
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=True)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_examples = None
        num_train_optimization_steps = None
        if args.do_train:
            train_examples = processor.get_train_examples(args.data_dir)
            num_train_optimization_steps = int(
                len(train_examples) / args.train_batch_size /
                args.gradient_accumulation_steps) * args.num_train_epochs
            if args.local_rank != -1:
                num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
                )

        # Prepare model
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(
            str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
                args.local_rank))
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, cache_dir=cache_dir, num_labels=num_labels)

        model.to(device)

        # Prepare optimizer
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        train_sampler = RandomSampler(train_data)

        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # define a new function to compute loss values for both output_modes
                logits = model(input_ids, segment_ids, input_mask, labels=None)

                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, num_labels),
                                label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
        model.to(device)

    if args.do_eval or args.do_test:

        # Load a trained model and config that you have fine-tuned
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
        model.to(device)

        eval_examples = processor.get_dev_examples(
            args.data_dir) if args.do_eval else processor.get_test_examples(
                args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)

        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss = 0
        nb_eval_steps = 0
        preds = []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask, labels=None)

            # create eval loss and other metric required by the task
            loss_fct = CrossEntropyLoss()
            tmp_eval_loss = loss_fct(logits.view(-1, num_labels),
                                     label_ids.view(-1))

            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if len(preds) == 0:
                preds.append(logits.detach().cpu().numpy())
            else:
                preds[0] = np.append(preds[0],
                                     logits.detach().cpu().numpy(),
                                     axis=0)

        eval_loss = eval_loss / nb_eval_steps
        preds = preds[0]
        preds = np.argmax(preds, axis=1)
        result = compute_metrics(preds, all_label_ids.numpy())
        loss = tr_loss / nb_tr_steps if args.do_train else None

        result['eval_loss'] = eval_loss
        result['global_step'] = global_step
        result['loss'] = loss

        output_eval_file = os.path.join(
            args.output_dir,
            "eval_results.txt" if args.do_eval else "test_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

        if args.do_test:
            output_eval_file = os.path.join(
                args.output_dir,
                "eval_results.txt" if args.do_eval else "test_submissions.txt")
            with open(output_eval_file, "w") as writer:
                logger.info("***** Test submission file *****")
                label_map = {i: label for i, label in enumerate(label_list)}
                for test, pred in zip(eval_examples, preds):
                    writer.write("%s,%s\n" % (test.guid, label_map[pred]))
Example #8
class Rewarder():
    def __init__(self, args, tokenizer):

        self.args = args

        self.nli_tokenizer = BertTokenizer.from_pretrained(
            args.bert_model,
            do_lower_case=args.do_lower_case,
            cache_dir='.pytorch_pretrained_bert')
        self.output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        self.output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        self.nli_config = BertConfig(self.output_config_file)
        self.nli_model = BertForSequenceClassification(self.nli_config,
                                                       num_labels=3)
        self.nli_model.load_state_dict(
            torch.load(self.output_model_file,
                       map_location=torch.device('cpu')))
        self.nli_model.to(args.device)
        self.nli_model.eval()

        if args.nli_uu_reward or args.nli_allres_reward:
            uu_output_config_file = os.path.join(args.uu_output_dir,
                                                 CONFIG_NAME)
            uu_output_model_file = os.path.join(args.uu_output_dir,
                                                WEIGHTS_NAME)
            self.uu_nli_config = BertConfig(uu_output_config_file)
            self.uu_nli_model = BertForSequenceClassification(
                self.uu_nli_config, num_labels=3)
            self.uu_nli_model.load_state_dict(
                torch.load(uu_output_model_file,
                           map_location=torch.device('cpu')))
            self.uu_nli_model.to(args.device)
            self.uu_nli_model.eval()

        bert_emb_modelpath = "bert-base-uncased"
        self.bert_emb_tokenizer = BertTokenizer.from_pretrained(
            bert_emb_modelpath, cache_dir='.pytorch_pretrained_bert')
        self.bert_emb_model = BertModel.from_pretrained(
            bert_emb_modelpath,
            cache_dir='.pytorch_pretrained_bert').to(args.device)
        self.bert_emb_model.eval()

        self.tokenizer = tokenizer

        if args.lm_reward:
            lm_model_path = 'openai-gpt'
            lm_output_dir = 'language-quality-subreward/gpt_output'
            lm_special_tokens = ['_start_', '_delimiter_', '_classify_']
            # Load pre-trained model (weights)
            with torch.no_grad():
                lm_output_config_file = os.path.join(lm_output_dir,
                                                     CONFIG_NAME)
                lm_config = OpenAIGPTConfig(lm_output_config_file)

                lm_output_model_file = os.path.join(lm_output_dir,
                                                    WEIGHTS_NAME)
                #lm_model_state_dict = torch.load(lm_output_model_file)
                lm_model_state_dict = torch.load(lm_output_model_file,
                                                 map_location='cpu')
                self.lm_model = OpenAIGPTLMHeadModel(lm_config)
                self.lm_model.load_state_dict(lm_model_state_dict)

                # Load pre-trained model tokenizer (vocabulary)
                self.lm_tokenizer = OpenAIGPTTokenizer.from_pretrained(
                    lm_model_path,
                    special_tokens=lm_special_tokens,
                    cache_dir='.pytorch_pretrained_bert')

            self.special_tokens_ids = list(
                self.lm_tokenizer.convert_tokens_to_ids(token)
                for token in lm_special_tokens)
            self.lm_model.to(args.device)
            self.lm_model.eval()

    def persona_rewarder(self, response, rl_train_personas_org):
        # concatenate all the personas
        '''
        personas_org_chain = [''.join(rl_train_personas_org)]
        reward = nli_engine(response, personas_org_chain, nli_tokenizer, nli_model)[0]
        '''
        scores = nli_engine(response, rl_train_personas_org,
                            self.nli_tokenizer, self.nli_model)
        current_persona_reward_0 = (
            (sum(scores) / len(rl_train_personas_org)) + 2) / 3
        current_persona_reward = current_persona_reward_0 * self.args.nli_weight
        logger.info('persona_reward before/after weighting = %f/%f' %
                    (current_persona_reward_0, current_persona_reward))
        return current_persona_reward

    def nli_allres_rewarder(self, response, history):
        # history_chain = list(chain(*history))
        # history_text = tokenizer.decode(history_chain, skip_special_tokens=True, clean_up_tokenization_spaces=False)
        pre_responses = []
        for i in range(-len(history), 0):
            if i % 2 == 0:
                current_text = self.tokenizer.decode(
                    history[i],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False)
                pre_responses.append(current_text)
        response_scores = nli_engine(response, pre_responses,
                                     self.nli_tokenizer, self.nli_model)
        if response_scores == []:
            current_response_reward = 0.5  # TODO: test if single allres will work
        else:
            current_response_reward = sum(response_scores) / len(
                response_scores)
        current_response_reward_0 = (current_response_reward + 2) / 3
        current_response_reward = current_response_reward_0 * self.args.nli_allres_weight
        logger.info('allres_reward before/after weighting = %f/%f' %
                    (current_response_reward_0, current_response_reward))
        return current_response_reward

    def cos_sim_bert_rewarder(self, response, history):
        pre_utt = history[-1]
        pre_utt_text = self.tokenizer.decode(
            pre_utt,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False)
        pre_utt_vec = bert_vector(pre_utt_text, self.bert_emb_tokenizer,
                                  self.bert_emb_model, self.args)
        response_vec = bert_vector(response, self.bert_emb_tokenizer,
                                   self.bert_emb_model, self.args)
        cos_sim_bert_score = cosine_similarity(pre_utt_vec.reshape(1, -1),
                                               response_vec.reshape(1,
                                                                    -1))[0][0]
        current_cos_sim_bert_reward = cos_sim_bert_score * self.args.cos_sim_bert_weight
        logger.info('cos_sim_bert before/after weighting = %f/%f' %
                    (cos_sim_bert_score, current_cos_sim_bert_reward))
        return current_cos_sim_bert_reward

    def intern_rep_rewarder(self, response):
        # response = 'i\'m 16 years years years years years old bye bye.'
        # intrep_word
        response_tok = response.split()
        intrep_1gram = intrep_frac(response_tok)
        # intrep_2gram
        response_tok_2gram = get_ngrams(response, 2)
        intrep_2gram = intrep_frac(response_tok_2gram)
        # intrep_3gram
        response_tok_3gram = get_ngrams(response, 3)
        intrep_3gram = intrep_frac(response_tok_3gram)
        current_intern_rep_reward = (
            1 - intrep_1gram
        ) * self.args.intern_rep_weight  # TODO: How to design this reward?
        logger.info('intern_rep before/after weighting = %f/%f' %
                    ((1 - intrep_1gram), current_intern_rep_reward))
        return current_intern_rep_reward

    def extern_rep_rewarder(self, response, history):
        pre_responses = []
        for i in range(-len(history), 0):
            if i % 2 == 0:
                current_text = self.tokenizer.decode(
                    history[i],
                    skip_special_tokens=True,
                    clean_up_tokenization_spaces=False)
                pre_responses.append(current_text)

        # extrep_word
        response_tok = response.split()
        prev_tok = [s.split() for s in pre_responses]  # list of list of strings
        prev_tok = list(set(flatten(prev_tok)))  # list of strings, no duplicates
        extrep_1gram = extrep_frac(response_tok, prev_tok)
        # extrep_2gram
        response_tok_2gram = get_ngrams(response, 2)
        prev_2grams = [get_ngrams(prev, 2)
                       for prev in pre_responses]  # list of list of strings
        prev_2grams = list(set(
            flatten(prev_2grams)))  # list of strings, no duplicates
        extrep_2gram = extrep_frac(response_tok_2gram, prev_2grams)
        # extrep_3gram
        response_tok_3gram = get_ngrams(response, 3)
        prev_3grams = [get_ngrams(prev, 3)
                       for prev in pre_responses]  # list of list of strings
        prev_3grams = list(set(
            flatten(prev_3grams)))  # list of strings, no duplicates
        extrep_3gram = extrep_frac(response_tok_3gram, prev_3grams)

        current_extern_rep_reward = 0  # TODO: How to design this reward?
        logger.info('extern_rep before/after weighting = %f/%f' %
                    (current_extern_rep_reward, current_extern_rep_reward))
        return current_extern_rep_reward

    def lm_rewarder(self, response):
        lm_tokenize_input = self.lm_tokenizer.tokenize(response)
        # lm_tensor_input = torch.tensor([lm_tokenizer.convert_tokens_to_ids(lm_tokenize_input)]).to(args.device)
        lm_tensor_input = torch.tensor(
            [[self.special_tokens_ids[0]] +
             self.lm_tokenizer.convert_tokens_to_ids(lm_tokenize_input) +
             [self.special_tokens_ids[-1]]]).to(self.args.device)
        lm_loss = self.lm_model(lm_tensor_input, lm_labels=lm_tensor_input)
        # lm_ppl = math.exp(lm_loss.item())
        nll = -lm_loss.item()
        if nll < -4:
            nll = -4
        current_lm_score = (nll + 4) / 4
        current_lm_reward = current_lm_score * self.args.lm_weight  # TODO: 1/lm_ppl?
        logger.info('lm_reward before/after weighting = %f/%f' %
                    (current_lm_score, current_lm_reward))
        return current_lm_reward

    def qback_rewarder(self, response):
        response_tok = response.split()
        num_in_list = len([w for w in response_tok if w in QN_WORDS])
        current_qback_reward = (num_in_list /
                                len(response_tok)) * self.args.qback_weight
        logger.info('qback_reward before/after weighting = %f/%f' %
                    ((num_in_list / len(response_tok)), current_qback_reward))
        return current_qback_reward

    def get_reward(self, response, rl_train_personas_org, history):

        R = {
            'reward': 0,
            'persona_reward': 0,
            'response_reward': 0,
            'uu_reward': 0,
            'cos_sim_bert_reward': 0,
            'intern_rep_reward': 0,
            'extern_rep_reward': 0,
            'lm_reward': 0,
            'qback_reward': 0,
            'f1_reward': 0,
            'bleu_reward': 0
        }

        if self.args.nli_reward:
            R['persona_reward'] = self.persona_rewarder(
                response, rl_train_personas_org)

        if self.args.nli_allres_reward:
            R['response_reward'] = self.nli_allres_rewarder(response, history)

        if self.args.nli_uu_reward:
            R['uu_reward'] = self.nli_uu_rewarder(response, history)

        if self.args.cos_sim_bert_reward:
            R['cos_sim_bert_reward'] = self.cos_sim_bert_rewarder(
                response, history)

        if self.args.intern_rep_reward:
            R['intern_rep_reward'] = self.intern_rep_rewarder(response)

        if self.args.extern_rep_reward:
            R['extern_rep_reward'] = self.extern_rep_rewarder(
                response, history)

        if self.args.lm_reward:
            R['lm_reward'] = self.lm_rewarder(response)

        if self.args.qback_reward:
            R['qback_reward'] = self.qback_rewarder(response)

        R['reward'] = R['persona_reward'] + \
             R['response_reward'] + \
             R['uu_reward'] + \
             R['cos_sim_bert_reward']+ \
             R['intern_rep_reward'] + \
             R['extern_rep_reward'] + \
             R['lm_reward'] + \
             R['qback_reward'] + \
             R['f1_reward'] + \
             R['bleu_reward']

        return R
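
The `(score + 2) / 3` normalization used by `persona_rewarder` and `nli_allres_rewarder` apparently maps an average NLI score assumed to lie in [-2, 1] onto [0, 1] before the weight is applied; a tiny worked example with a hypothetical score:

avg_score = -0.5                   # hypothetical average NLI score over the persona set
normalized = (avg_score + 2) / 3   # 0.5, then multiplied by e.g. args.nli_weight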
Example #9
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_predict",
                        action='store_true',
                        help="Whether to run predict on the test set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--predict_batch_size",
                        default=1,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "ques_cate": QuescateProcessor,
    }

    num_labels_task = {
        "cola": 2,
        "sst-2": 2,
        "mnli": 3,
        "mrpc": 2,
        "ques_cate": 3,
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_eval` or `do_predict` must be True."
        )

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()
    """
    Before adding a new classification task, register its name in both the processors and num_labels_task dicts above.
    """
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()  # data processor for this task
    num_labels = num_labels_task[task_name]  # number of label categories
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    # train_examples = None
    num_train_optimization_steps = None

    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model: load pre-trained weights (downloaded from S3 and cached if necessary)
    if args.do_train or args.do_eval:
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(
            str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
                args.local_rank))
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    if args.do_predict:
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
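        # Note: --do_predict assumes a fine-tuned checkpoint (WEIGHTS_NAME and
        # CONFIG_NAME) is already present in output_dir, e.g. from an earlier
        # --do_train run of this script.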

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
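    # The first group (weight_decay=0.01) holds every parameter except biases
    # and LayerNorm weights; the second group (weight_decay=0.0) holds those,
    # since they are conventionally excluded from weight decay.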
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))

        model.to(device)
    elif not args.do_train and not args.do_predict:
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, num_labels=num_labels)
        model.to(device)
    """
    To evaluation
    """
    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': loss
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
    """
    To predict,
        one by one to predict, i.e., one time only has one sample.
    """
    if args.do_predict and (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
        predict_examples = processor.get_test_examples(args.data_dir)
        num_actual_predict_examples = len(predict_examples)
        """
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_id=label_id
        """
        predict_features = convert_examples_to_features(
            predict_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(predict_examples))
        logger.info("  Batch size = %d", args.predict_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in predict_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in predict_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor(
            [f.segment_ids for f in predict_features], dtype=torch.long)
        # all_label_ids = torch.tensor([f.label_id for f in predict_features], dtype=torch.long)
        predict_data = TensorDataset(all_input_ids, all_input_mask,
                                     all_segment_ids)
        # Run prediction for full data
        predict_sampler = SequentialSampler(predict_data)
        predict_dataloader = DataLoader(predict_data,
                                        sampler=predict_sampler,
                                        batch_size=args.predict_batch_size)

        model.eval()
        predict = []
        for input_ids, input_mask, segment_ids in tqdm(predict_dataloader,
                                                       desc="Predicting"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            """
            batch_size=8
            
            type(logits) = <class 'numpy.ndarray'>
            logits:
                [[-0.69838923  0.27036643  0.5943373 ]
                 [-0.84512466  0.23943791  0.5472788 ]
                 [-0.4465914  -0.60343146 -0.8313097 ]
                 [-0.52020323 -0.475485   -0.8743459 ]
                 [-0.66284615  0.30615643  0.62117684]
                 [-0.6683669   0.27725238  0.572317  ]
                 [-0.7646524   0.26856643  0.5333996 ]
                 [-0.73449135  0.259271    0.5099745 ]]
            
            softmax over the logits to obtain class probabilities
                >>> a=np.array([[0.334,0.889,-0.123],[0.332,0.976,-0.543]])
                >>> 
                >>> aa=torch.tensor(a)
                >>> aa
                tensor([[ 0.3340,  0.8890, -0.1230],
                        [ 0.3320,  0.9760, -0.5430]], dtype=torch.float64)
                >>> 
                >>> print(torch.nn.functional.softmax(aa, dim=1))
                tensor([[0.2963, 0.5161, 0.1876],
                        [0.3011, 0.5734, 0.1255]], dtype=torch.float64)
                >>> print(torch.nn.functional.softmax(aa, dim=0))
                tensor([[0.5005, 0.4783, 0.6035],
                        [0.4995, 0.5217, 0.3965]], dtype=torch.float64)
                >>> print(torch.nn.functional.softmax(aa, dim=-1))
                tensor([[0.2963, 0.5161, 0.1876],
                        [0.3011, 0.5734, 0.1255]], dtype=torch.float64)
                >>> aa.shape
                torch.Size([2, 3])
            
            To pick the most probable label:
                >>> c=["yes", "no", "depends"]
                >>> i
                tensor([0.2963, 0.5161, 0.1876], dtype=torch.float64)
                >>> 
                >>> c[np.argmax(i)]
                'no'
                >>> c[torch.argmax(i)]
                'no'
                >>> type(c[torch.argmax(i)])
                <class 'str'>

            """

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask)
                probabilities = torch.nn.functional.softmax(logits, dim=-1)

                for prediction in probabilities:  # one example per batch, so probabilities has length 1
                    pred_label = label_list[np.argmax(prediction)]

            predict.append(pred_label)

        output_predict_file = os.path.join(args.output_dir,
                                           "predict_results.txt")

        with open(output_predict_file, "w") as writer:
            logger.info("***** Predict results *****")
            num_written_lines = 0
            for i in predict:
                num_written_lines += 1
                writer.write(i + "\n")

        assert num_written_lines == num_actual_predict_examples
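
The manual learning-rate adjustment in the fp16 branch above relies on a warmup_linear helper that is not shown in these excerpts (the next example calls it too). A minimal sketch of the linear warmup-then-decay schedule shipped with older pytorch_pretrained_bert releases (treat it as an approximation of what the scripts import, not a verified copy):

    def warmup_linear(x, warmup=0.002):
        # Linearly ramp the learning-rate multiplier up during the first
        # `warmup` fraction of training, then decay it linearly.
        if x < warmup:
            return x / warmup
        return 1.0 - x
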
Example #10
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--model",
        default="bert",
        type=str,
        required=True,
        help="The model used for pretraining. Currently support bert or electra"
    )
    parser.add_argument(
        "--config_file",
        "--cf",
        help="pointer to the configuration file of the experiment",
        type=str,
        required=True)
    parser.add_argument(
        "--config_file_path",
        default=None,
        type=str,
        required=True,
        help="The blob storage directory where config file is located.")
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--checkpoint_file",
        default=None,
        type=str,
        help=
        "The path to checkpoint file which will be used to initializ the model parameters."
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--optimize_on_cpu',
        default=False,
        action='store_true',
        help=
        "Whether to perform optimization and keep the optimizer averages on CPU"
    )
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=128,
        help=
        'Loss scaling, positive power of 2 values can improve fp16 convergence.'
    )
    parser.add_argument('--step_per_log',
                        type=int,
                        default=5,
                        help='Number of updates steps to log metrics.')
    parser.add_argument(
        "--process_count_per_node",
        default=1,
        type=int,
        help="Total number of process count to launch per node.")

    args = parser.parse_args()

    #run = Run.get_context()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "qqp": QQPProcessor,
        "qnli": QNLIProcessor,
        "sst2": SST2Processor,
        "stsb": STSBProcessor,
        "rte": RTEProcessor,
    }

    comm = DistributedCommunicator(
        accumulation_step=args.gradient_accumulation_steps)
    rank = comm.rank
    local_rank = comm.local_rank
    world_size = comm.world_size
    is_master = rank == 0

    # Prepare logger
    job_id = rutils.get_current_time()
    logger = rutils.FileLogging('%s_bert_fine_tune_%d' % (job_id, local_rank))
    logger.info("job id: %s" % job_id)
    logger.info(rutils.parser_args_to_dict(args))

    logger.info(
        "world size: {}, local rank: {}, global rank: {}, fp16: {}".format(
            world_size, local_rank, rank, args.fp16))

    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    hostname = socket.gethostname()
    n_gpu = torch.cuda.device_count()
    logger.info("host: {}, device: {}, n_gpu: {}".format(
        hostname, device, n_gpu))

    # extract config
    job_config = BertJobConfiguration(
        config_file_path=os.path.join(args.config_file_path, args.config_file))

    #if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
    #    raise ValueError("Output directory () already exists and is not empty.")
    #os.makedirs(args.output_dir, exist_ok=True)
    output_model_file = os.path.join(args.output_dir,
                                     job_id + "_pytorch_model_fine_tune.bin")

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    if local_rank == -1:
        args.train_batch_size = int(args.train_batch_size /
                                    args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    task_name = args.task_name.lower()

    is_master = (local_rank == -1 or rank == 0)
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(job_config.get_token_file_type(),
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)
    num_labels = len(processor.get_labels())

    # Prepare model
    model_name = args.model
    model_config = job_config.get_model_config()
    if model_name == 'bert':
        config = BertConfig(**model_config)
        config.vocab_size = len(tokenizer.vocab)
        model = BertForSequenceClassification(config, num_labels=num_labels)
    elif model_name == 'electra':
        config = ElectraConfig(**model_config)
        config.vocab_size = len(tokenizer.vocab)
        model = ElectraForSequenceClassification(config, num_labels=num_labels)

    #model = BertForSequenceClassification.from_pretrained(args.bert_model,
    #            cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(local_rank), num_labels=num_labels)

    # Load checkpoint if specified
    #import pdb;pdb.set_trace()
    if os.path.exists(str(args.checkpoint_file)):
        state_dict = torch.load(args.checkpoint_file)
        if model_name == 'bert':
            model.bert.load_state_dict(state_dict)
        elif model_name == 'electra':
            model.electra.load_state_dict(state_dict)
        logger.info("Set the model parameter from the checkpoint %s" %
                    args.checkpoint_file)

    if args.fp16:
        model.half()
    model.to(device)
    comm.register_model(model, args.fp16)

    if args.do_train:

        param_optimizer = list(model.named_parameters())

        # Hack: drop the pooler parameters, which are not used here and
        # would otherwise produce None grads that break apex.
        param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        t_total = num_train_steps // world_size

        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to run this."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)
        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=t_total)
        if is_master:
            logger.info('lr: {}'.format(float(args.learning_rate)))

        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, logger)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d" % (len(train_examples)))
        logger.info("  Batch size = %d" % (args.train_batch_size))
        logger.info("  Num steps = %d" % (num_train_steps))
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if local_rank != -1 and world_size > 1:
            train_sampler = DistributedSampler(train_data)
        else:
            train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        global_step, tr_loss = 0, 0
        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            for _, batch in enumerate(tqdm(train_dataloader,
                                           desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                loss = loss / args.gradient_accumulation_steps
                loss.backward()
                global_step += 1
                tr_loss += loss.item()
                if comm.synchronize():
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    model.zero_grad()
                if is_master and (global_step + 1) % args.step_per_log == 0:
                    logger.info('train_loss: {}'.format(
                        float(tr_loss / args.step_per_log)))
                    tr_loss = 0
        if is_master:
            # Save a trained model
            torch.save(model.state_dict(), output_model_file)
            logger.info('model checkpoint saved at %s' % output_model_file)
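            # Unlike Example #9, only the raw state_dict is saved here; the
            # model configuration lives in the external --config_file, so no
            # config JSON is written next to the checkpoint.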

    if args.do_eval and is_master:
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer, logger)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d" % len(eval_examples))
        logger.info("  Batch size = %d" % args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)
            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)
            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)
            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy
            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s" % (key, str(result[key])))
Example #11
File: main.py  Project: byew/n2c2_track1
def main():
    args = parse_args()

    # specifies the path where the biobert or clinical bert model is saved
    if args.bert_model == 'biobert' or args.bert_model == 'clinical_bert':
        args.bert_model = args.model_loc

    print(f"Using bert model: {args.bert_model}")

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(f"device: {device} n_gpu: {n_gpu}")

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    processor = N2c2ClsProcessor(args.fold_id)
    num_labels = 13
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    print('TRAIN')
    train = processor.get_train_examples(args.data_dir)
    print([(train[i].text_a, train[i].text_b, train[i].label)
           for i in range(3)])
    print('DEV')
    dev = processor.get_dev_examples(args.data_dir)
    print([(dev[i].text_a, dev[i].text_b, dev[i].label) for i in range(3)])
    print('TEST')
    test = processor.get_test_examples(args.data_dir)
    print([(test[i].text_a, test[i].text_b, test[i].label) for i in range(3)])

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) /
            args.train_batch_size) * args.num_train_epochs

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else PYTORCH_PRETRAINED_BERT_CACHE
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    model.to(device)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        train_sampler = RandomSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    else:
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, num_labels=num_labels)
    model.to(device)

    if args.do_eval:
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        pred = []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)
                logits = torch.softmax(logits, 1)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)
            pred += logits.tolist()

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None

        pred = {f.guid: p for f, p in zip(eval_features, pred)}

        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': loss
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

        output_pred_file = os.path.join(args.output_dir, "pred_results.txt")
        with open(output_pred_file, 'w') as writer:
            logger.info("***** Writing Eval predictions *****")
            for id, p in pred.items():
                writer.write(f"{id}:{p}\n")

    if args.do_test and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        test_examples = processor.get_test_examples(args.data_dir)
        test_features = convert_examples_to_features(test_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running testing *****")
        logger.info("  Num examples = %d", len(test_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in test_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in test_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in test_features],
                                     dtype=torch.long)
        test_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        test_loss, test_accuracy = 0, 0
        nb_test_steps, nb_test_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                test_dataloader, desc="Testing"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_test_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_test_accuracy = accuracy(logits, label_ids)

            test_loss += tmp_test_loss.mean().item()
            test_accuracy += tmp_test_accuracy

            nb_test_examples += input_ids.size(0)
            nb_test_steps += 1

        test_loss = test_loss / nb_test_steps
        test_accuracy = test_accuracy / nb_test_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {
            'test_loss': test_loss,
            'test_accuracy': test_accuracy,
            'global_step': global_step,
            'loss': loss
        }

        output_test_file = os.path.join(args.output_dir, "test_results.txt")
        with open(output_test_file, "w") as writer:
            logger.info("***** Test results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
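
Each of these scripts calls an accuracy(out, labels) helper defined elsewhere in its repository. In the original run_classifier-style examples it simply counts argmax matches; a minimal sketch under that assumption:

    import numpy as np

    def accuracy(out, labels):
        # Number of examples whose argmax prediction equals the gold label.
        outputs = np.argmax(out, axis=1)
        return np.sum(outputs == labels)

Returning a count rather than a ratio matches the way the scripts later divide the accumulated eval_accuracy by nb_eval_examples.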
Example #12
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor,
        "adlhw2": MyTaskProcessor
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification",
        "adlhw2": "classification"
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

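    # Shrink the per-step batch so that one optimizer update (after gradient_accumulation_steps
    # forward/backward passes) still covers train_batch_size examples in total.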
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model: instead of starting from the stock pre-trained weights, load the
    # configuration and weights of a previously fine-tuned checkpoint from output_dir.
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
    logger.info("Loading config from %s and weights from %s", output_config_file, output_model_file)
    config = BertConfig(output_config_file)
    model = BertForSequenceClassification(config, num_labels=num_labels)
    model.load_state_dict(torch.load(output_model_file))
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
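    # Follow the usual BERT fine-tuning recipe: apply weight decay to all parameters except biases and LayerNorm weights.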
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)

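        # Classification labels are long indices (for CrossEntropyLoss); regression targets are floats (for MSELoss).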
        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.float)

        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # define a new function to compute loss values for both output_modes
                logits = model(input_ids, segment_ids, input_mask, labels=None)

                if output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
                elif output_mode == "regression":
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    else:
        # The fine-tuned checkpoint was already loaded before training, so there is nothing more to restore here.
        pass
    model.to(device)
    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        #print("eval_examples: ", eval_examples) #a list of <__main__.InputExample object at 0x7f61b67bef28>
        #input()
        Ids = [e.guid for e in eval_examples]
        #print("Ids: ", Ids) #a list of ids
        #input()
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        #print("eval_features: ", eval_features) #a list of <__main__.InputFeatures object at 0x7f86f15cb400>
        #input()
        logger.info("***** Running testing *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.float)

        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        eval_loss = 0
        nb_eval_steps = 0
        preds = []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)
            #print("input_ids: ", input_ids)
            #print("input_mask: ", input_mask)
            #print("segment_ids: ", segment_ids)
            print("labels_ids: ", label_ids) #labels_ids:  tensor([1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0'
            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask, labels=None)

            print("logits: ", logits.shape)
            print("logits: ", logits)
            input()
            # create eval loss and other metric required by the task
            if output_mode == "classification":
                loss_fct = CrossEntropyLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
            elif output_mode == "regression":
                loss_fct = MSELoss()
                tmp_eval_loss = loss_fct(logits.view(-1), label_ids.view(-1))
            
            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            if len(preds) == 0:
                preds.append(logits.detach().cpu().numpy())
            else:
                # numpy.append concatenates this batch's logits onto the running array:
                # https://docs.scipy.org/doc/numpy/reference/generated/numpy.append.html
                preds[0] = np.append(
                    preds[0], logits.detach().cpu().numpy(), axis=0)

        eval_loss = eval_loss / nb_eval_steps
        preds = preds[0]
        if output_mode == "classification":
            # argmax gives the class index; the label written to the CSV below is index + 1.
            preds = np.argmax(preds, axis=1)
        elif output_mode == "regression":
            preds = np.squeeze(preds)
        
        
        result = compute_metrics(task_name, preds, all_label_ids.numpy())
        loss = tr_loss/nb_tr_steps if args.do_train else None

        result['eval_loss'] = eval_loss
        result['global_step'] = global_step
        result['loss'] = loss

        output_eval_file = os.path.join(args.output_dir, "test.csv")
        output_test_file_1 = os.path.join(args.output_dir, "test_results.txt")

        # Writing the metric summary to test_results.txt is disabled; only the prediction CSV below is produced.
        with open(output_eval_file, "w") as f:
            logger.info("***** Test results *****")
            writer = csv.DictWriter(f, fieldnames=['Id', 'label'])
            writer.writeheader()
            writer.writerows(
               [{'Id': Id, 'label': p + 1} for Id, p in zip(Ids, preds)])


        # hack for MNLI-MM
        if task_name == "mnli":
            task_name = "mnli-mm"
            processor = processors[task_name]()

            if os.path.exists(args.output_dir + '-MM') and os.listdir(args.output_dir + '-MM') and args.do_train:
                raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
            if not os.path.exists(args.output_dir + '-MM'):
                os.makedirs(args.output_dir + '-MM')

            eval_examples = processor.get_dev_examples(args.data_dir)
            eval_features = convert_examples_to_features(
                eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
            logger.info("***** Running evaluation *****")
            logger.info("  Num examples = %d", len(eval_examples))
            logger.info("  Batch size = %d", args.eval_batch_size)
            all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

            eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

            model.eval()
            eval_loss = 0
            nb_eval_steps = 0
            preds = []

            for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    logits = model(input_ids, segment_ids, input_mask, labels=None)
            
                loss_fct = CrossEntropyLoss()
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels), label_ids.view(-1))
            
                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_steps += 1
                if len(preds) == 0:
                    preds.append(logits.detach().cpu().numpy())
                else:
                    preds[0] = np.append(
                        preds[0], logits.detach().cpu().numpy(), axis=0)

            eval_loss = eval_loss / nb_eval_steps
            preds = preds[0]
            preds = np.argmax(preds, axis=1)
            result = compute_metrics(task_name, preds, all_label_ids.numpy())
            loss = tr_loss/nb_tr_steps if args.do_train else None

            result['eval_loss'] = eval_loss
            result['global_step'] = global_step
            result['loss'] = loss

            output_eval_file = os.path.join(args.output_dir + '-MM', "eval_results.txt")
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
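
A minimal inference sketch that reuses the checkpoint saved by the script above. It is not part of the original code: the file names under output_dir, the label count, and the example call are assumptions for illustration only.

import os
import torch
from pytorch_pretrained_bert import BertConfig, BertTokenizer, BertForSequenceClassification


def predict_single(text, output_dir, bert_model="bert-base-uncased",
                   num_labels=2, max_seq_length=128, lower_case=True):
    """Hypothetical helper: score one sentence with a fine-tuned checkpoint (pytorch_model.bin / bert_config.json)."""
    config = BertConfig(os.path.join(output_dir, "bert_config.json"))
    model = BertForSequenceClassification(config, num_labels=num_labels)
    model.load_state_dict(torch.load(os.path.join(output_dir, "pytorch_model.bin"), map_location="cpu"))
    model.eval()

    tokenizer = BertTokenizer.from_pretrained(bert_model, do_lower_case=lower_case)
    tokens = ["[CLS]"] + tokenizer.tokenize(text)[:max_seq_length - 2] + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    padding = [0] * (max_seq_length - len(input_ids))
    input_ids = torch.tensor([input_ids + padding], dtype=torch.long)
    input_mask = torch.tensor([[1] * len(tokens) + padding], dtype=torch.long)
    segment_ids = torch.zeros_like(input_ids)  # single-sentence input: every token belongs to segment 0

    with torch.no_grad():
        logits = model(input_ids, segment_ids, input_mask, labels=None)
    return int(torch.argmax(logits, dim=1).item())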
Example #13
class Bert_trained_model:
    def __init__(self):
        parser = argparse.ArgumentParser()

        parser.add_argument(
            "--bert_model",
            default=None,
            type=str,
            required=True,
            help=
            "Bert pre-trained model selected in the list: bert-base-uncased, "
            "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
            "bert-base-multilingual-cased, bert-base-chinese.")
        parser.add_argument("--task_name",
                            default='mnli',
                            type=str,
                            required=True,
                            help="The name of the task to train.")
        parser.add_argument(
            "--output_dir",
            default=None,
            type=str,
            required=True,
            help=
            "The output directory where the model predictions and checkpoints will be written."
        )

        ## Other parameters
        parser.add_argument(
            "--cache_dir",
            default="",
            type=str,
            help=
            "Where do you want to store the pre-trained models downloaded from s3"
        )
        parser.add_argument(
            "--max_seq_length",
            default=128,
            type=int,
            help=
            "The maximum total input sequence length after WordPiece tokenization. \n"
            "Sequences longer than this will be truncated, and sequences shorter \n"
            "than this will be padded.")
        parser.add_argument("--do_train",
                            action='store_true',
                            help="Whether to run training.")
        parser.add_argument("--do_eval",
                            action='store_true',
                            help="Whether to run eval on the dev set.")
        parser.add_argument(
            "--do_lower_case",
            action='store_true',
            help="Set this flag if you are using an uncased model.")
        parser.add_argument("--train_batch_size",
                            default=32,
                            type=int,
                            help="Total batch size for training.")
        parser.add_argument("--eval_batch_size",
                            default=1,
                            type=int,
                            help="Total batch size for eval.")
        parser.add_argument("--learning_rate",
                            default=5e-5,
                            type=float,
                            help="The initial learning rate for Adam.")
        parser.add_argument("--num_train_epochs",
                            default=3.0,
                            type=float,
                            help="Total number of training epochs to perform.")
        parser.add_argument(
            "--warmup_proportion",
            default=0.1,
            type=float,
            help=
            "Proportion of training to perform linear learning rate warmup for. "
            "E.g., 0.1 = 10%% of training.")
        parser.add_argument("--no_cuda",
                            action='store_true',
                            help="Whether not to use CUDA when available")
        parser.add_argument("--local_rank",
                            type=int,
                            default=-1,
                            help="local_rank for distributed training on gpus")
        parser.add_argument('--seed',
                            type=int,
                            default=42,
                            help="random seed for initialization")
        parser.add_argument(
            '--gradient_accumulation_steps',
            type=int,
            default=1,
            help=
            "Number of updates steps to accumulate before performing a backward/update pass."
        )
        parser.add_argument(
            '--fp16',
            action='store_true',
            help="Whether to use 16-bit float precision instead of 32-bit")
        parser.add_argument(
            '--loss_scale',
            type=float,
            default=0,
            help=
            "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
            "0 (default value): dynamic loss scaling.\n"
            "Positive power of 2: static loss scaling value.\n")
        parser.add_argument('--server_ip',
                            type=str,
                            default='',
                            help="Can be used for distant debugging.")
        parser.add_argument('--server_port',
                            type=str,
                            default='',
                            help="Can be used for distant debugging.")
        parser.add_argument('--gpu_id',
                            type=str,
                            default='',
                            help="GPU to use")

        self.args = parser.parse_args()
        if self.args.server_ip and self.args.server_port:
            # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
            import ptvsd
            print("Waiting for debugger attach")
            ptvsd.enable_attach(address=(self.args.server_ip,
                                         self.args.server_port),
                                redirect_output=True)
            ptvsd.wait_for_attach()

        self.processors = {"mnli": MnliProcessor}

        self.num_labels_task = {"mnli": 3}
        os.environ["CUDA_VISIBLE_DEVICES"] = self.args.gpu_id
        if self.args.local_rank == -1 or self.args.no_cuda:
            self.device = torch.device("cuda" if torch.cuda.is_available()
                                       and not self.args.no_cuda else "cpu")
            n_gpu = torch.cuda.device_count()
        else:
            torch.cuda.set_device(self.args.local_rank)
            self.device = torch.device("cuda", self.args.local_rank)
            n_gpu = 1
            # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
            torch.distributed.init_process_group(backend='nccl')
        logger.info(
            "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}"
            .format(self.device, n_gpu, bool(self.args.local_rank != -1),
                    self.args.fp16))

        if self.args.gradient_accumulation_steps < 1:
            raise ValueError(
                "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
                .format(self.args.gradient_accumulation_steps))

        self.args.train_batch_size = self.args.train_batch_size // self.args.gradient_accumulation_steps

        random.seed(self.args.seed)
        np.random.seed(self.args.seed)
        torch.manual_seed(self.args.seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(self.args.seed)

        if not self.args.do_train and not self.args.do_eval:
            raise ValueError(
                "At least one of `do_train` or `do_eval` must be True.")

        if os.path.exists(self.args.output_dir) and os.listdir(
                self.args.output_dir) and self.args.do_train:
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(self.args.output_dir))
        if not os.path.exists(self.args.output_dir):
            os.makedirs(self.args.output_dir)

        task_name = self.args.task_name.lower()

        if task_name not in self.processors:
            raise ValueError("Task not found: %s" % (task_name))

        self.processor = self.processors[task_name]()
        self.num_labels = self.num_labels_task[task_name]
        self.label_list = self.processor.get_labels()

        train_examples = None
        num_train_optimization_steps = None
        if self.args.do_train:
            train_examples = self.processor.get_train_examples(self.args.data_dir)
            num_train_optimization_steps = int(
                len(train_examples) / self.args.train_batch_size / self.args.
                gradient_accumulation_steps) * self.args.num_train_epochs
            if self.args.local_rank != -1:
                num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
                )

        # Prepare model
        cache_dir = self.args.cache_dir if self.args.cache_dir else os.path.join(
            str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
                self.args.local_rank))
        self.model = BertForSequenceClassification.from_pretrained(
            self.args.bert_model,
            cache_dir=cache_dir,
            num_labels=self.num_labels)
        if self.args.fp16:
            self.model.half()
        self.model.to(self.device)
        if self.args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            self.model = DDP(self.model)
        elif n_gpu > 1:
            self.model = torch.nn.DataParallel(self.model)

        # Prepare optimizer
        param_optimizer = list(self.model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        if self.args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=self.args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if self.args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(
                    optimizer, static_loss_scale=self.args.loss_scale)

        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=self.args.learning_rate,
                                 warmup=self.args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

        global_step = 0
        nb_tr_steps = 0
        tr_loss = 0

        self.tokenizer = BertTokenizer.from_pretrained(
            self.args.bert_model, do_lower_case=self.args.do_lower_case)

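        # Discard the pre-trained weights initialized above and restore the fine-tuned checkpoint from output_dir instead.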
        output_model_file = os.path.join(self.args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(self.args.output_dir, CONFIG_NAME)
        config = BertConfig(output_config_file)
        self.model = BertForSequenceClassification(config,
                                                   num_labels=self.num_labels)
        self.model.load_state_dict(torch.load(output_model_file))
        self.model.to(self.device)

    def predict(self, s1, s2):
        eval_examples = []
        # label is dummy
        eval_examples.append(
            InputExample(guid=1, text_a=s1, text_b=s2, label='entailment'))
        eval_examples.append(
            InputExample(guid=2, text_a=s2, text_b=s1, label='entailment'))
        eval_features = convert_examples_to_features(eval_examples,
                                                     self.label_list,
                                                     self.args.max_seq_length,
                                                     self.tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", self.args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=self.args.eval_batch_size)

        self.model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        odds = []
        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(self.device)
            input_mask = input_mask.to(self.device)
            segment_ids = segment_ids.to(self.device)
            label_ids = label_ids.to(self.device)

            with torch.no_grad():
                logits = self.model(input_ids, segment_ids, input_mask)

            label_ids = label_ids.to('cpu').numpy()
            m = torch.nn.Softmax(dim=1)
            # Move the softmax probabilities to the host so the odds can be computed with numpy.
            outputs = m(logits).detach().cpu().numpy()

            # Odds of class index 1 for each ordering of the sentence pair: p / (1 - p).
            p = outputs[:, 1] / (1 - outputs[:, 1])
            odds.extend(p)

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        return odds
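
A minimal usage sketch for the class above (not in the original source). It assumes the required command-line flags (--bert_model, --task_name, --output_dir) were passed and that output_dir holds the fine-tuned pytorch_model.bin / bert_config.json pair loaded in __init__.

if __name__ == "__main__":
    clf = Bert_trained_model()
    # predict() returns the odds of class index 1 (entailment under the usual MNLI label order)
    # for both orderings of the sentence pair.
    odds = clf.predict("A man is playing a guitar.", "Someone is making music.")
    print(odds)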
Example #14
def run(args):
    print(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}".format(device, n_gpu))
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)
    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")
    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    processor = MnliProcessor()
    label_list = processor.get_labels()
    num_labels = len(label_list)
    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)
    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.train_dataset)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )
    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_1')
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    model.to(device)
    if n_gpu > 1:
        model = torch.nn.DataParallel(model)
    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    optimizer = BertAdam(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_optimization_steps)
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    else:
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        model_state_dict = torch.load(output_model_file)
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model,
            num_labels=num_labels,
            state_dict=model_state_dict)
    model.to(device)
    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.eval_dataset)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': loss
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
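
A sketch of the argument namespace that run(args) above expects; only the attribute names are taken from the function body, and every value here is a placeholder.

from argparse import Namespace

args = Namespace(
    bert_model="bert-base-uncased", do_lower_case=True,
    train_dataset="train.tsv", eval_dataset="dev.tsv",  # placeholder data paths
    output_dir="./mnli_output", cache_dir="",
    do_train=True, do_eval=True,
    max_seq_length=128, train_batch_size=32, eval_batch_size=8,
    learning_rate=5e-5, num_train_epochs=3.0, warmup_proportion=0.1,
    gradient_accumulation_steps=1, seed=42, local_rank=-1, no_cuda=False)
run(args)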
Example #15
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_config_file",
        default=None,
        type=str,
        required=True,
        help=
        "The config json file corresponding to the pre-trained BERT model. \n"
        "This specifies the model architecture.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--vocab_file",
        default=None,
        type=str,
        required=True,
        help="The vocabulary file that the BERT model was trained on.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--init_checkpoint",
        default=None,
        type=str,
        help="Initial checkpoint (usually from a pre-trained BERT model).")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--save_checkpoints_steps",
                        default=1000,
                        type=int,
                        help="How often to save the model checkpoint.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    args = parser.parse_args()

    processors = {
        "dream": dreamProcessor,
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device %s n_gpu %d distributed training %r", device, n_gpu,
                bool(args.local_rank != -1))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    bert_config = BertConfig.from_json_file(args.bert_config_file)

    if args.max_seq_length > bert_config.max_position_embeddings:
        raise ValueError(
            "Cannot use sequence length {} because the BERT model was only trained up to sequence length {}"
            .format(args.max_seq_length, bert_config.max_position_embeddings))

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        if args.do_train:
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".
                format(args.output_dir))
    else:
        os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()

    tokenizer = tokenization.FullTokenizer(vocab_file=args.vocab_file,
                                           do_lower_case=args.do_lower_case)

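    # NOTE: n_class (the number of answer choices per question) is used below but never set in this
    # function; it is assumed to be defined as a module-level constant in the original script.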
    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_steps = int(
            len(train_examples) / n_class / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    model = BertForSequenceClassification(
        bert_config, 1 if n_class > 1 else len(label_list))
    if args.init_checkpoint is not None:
        model.bert.load_state_dict(
            torch.load(args.init_checkpoint, map_location='cpu'))
    model.to(device)

    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[args.local_rank], output_device=args.local_rank)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    no_decay = ['bias', 'gamma', 'beta']
    # Match the no_decay entries as substrings of parameter names; exact-name matching would never exclude anything.
    optimizer_parameters = [{
        'params': [p for n, p in model.named_parameters()
                   if not any(nd in n for nd in no_decay)],
        'weight_decay_rate': 0.01
    }, {
        'params': [p for n, p in model.named_parameters()
                   if any(nd in n for nd in no_decay)],
        'weight_decay_rate': 0.0
    }]

    optimizer = BERTAdam(optimizer_parameters,
                         lr=args.learning_rate,
                         warmup=args.warmup_proportion,
                         t_total=num_train_steps)

    global_step = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

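        # Each training example expands into n_class candidate sequences (one per answer choice),
        # grouped so the model scores all choices of a question together.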
        input_ids = []
        input_mask = []
        segment_ids = []
        label_id = []
        for f in train_features:
            input_ids.append([])
            input_mask.append([])
            segment_ids.append([])
            for i in range(n_class):
                input_ids[-1].append(f[i].input_ids)
                input_mask[-1].append(f[i].input_mask)
                segment_ids[-1].append(f[i].segment_ids)
            label_id.append([f[0].label_id])

        all_input_ids = torch.tensor(input_ids, dtype=torch.long)
        all_input_mask = torch.tensor(input_mask, dtype=torch.long)
        all_segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        all_label_ids = torch.tensor(label_id, dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss, _ = model(input_ids, segment_ids, input_mask, label_ids,
                                n_class)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()  # We have accumulated enough gradients
                    model.zero_grad()
                    global_step += 1

        torch.save(model.state_dict(), os.path.join(args.output_dir,
                                                    "model.pt"))

    model.load_state_dict(torch.load(os.path.join(args.output_dir,
                                                  "model.pt")))

    if args.do_eval:
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)

        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

        input_ids = []
        input_mask = []
        segment_ids = []
        label_id = []

        for f in eval_features:
            input_ids.append([])
            input_mask.append([])
            segment_ids.append([])
            for i in range(n_class):
                input_ids[-1].append(f[i].input_ids)
                input_mask[-1].append(f[i].input_mask)
                segment_ids[-1].append(f[i].segment_ids)
            label_id.append([f[0].label_id])

        all_input_ids = torch.tensor(input_ids, dtype=torch.long)
        all_input_mask = torch.tensor(input_mask, dtype=torch.long)
        all_segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        all_label_ids = torch.tensor(label_id, dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            eval_sampler = SequentialSampler(eval_data)
        else:
            eval_sampler = DistributedSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        logits_all = []
        for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss, logits = model(input_ids, segment_ids,
                                              input_mask, label_ids, n_class)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            for i in range(len(logits)):
                logits_all += [logits[i]]

            tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples

        if args.do_train:
            result = {
                'eval_loss': eval_loss,
                'eval_accuracy': eval_accuracy,
                'global_step': global_step,
                'loss': tr_loss / nb_tr_steps
            }
        else:
            result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}

        output_eval_file = os.path.join(args.output_dir,
                                        "eval_results_dev.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
        output_eval_file = os.path.join(args.output_dir, "logits_dev.txt")
        with open(output_eval_file, "w") as f:
            for i in range(len(logits_all)):
                for j in range(len(logits_all[i])):
                    f.write(str(logits_all[i][j]))
                    if j == len(logits_all[i]) - 1:
                        f.write("\n")
                    else:
                        f.write(" ")

        eval_examples = processor.get_test_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)

        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

        input_ids = []
        input_mask = []
        segment_ids = []
        label_id = []

        for f in eval_features:
            input_ids.append([])
            input_mask.append([])
            segment_ids.append([])
            for i in range(n_class):
                input_ids[-1].append(f[i].input_ids)
                input_mask[-1].append(f[i].input_mask)
                segment_ids[-1].append(f[i].segment_ids)
            label_id.append([f[0].label_id])

        all_input_ids = torch.tensor(input_ids, dtype=torch.long)
        all_input_mask = torch.tensor(input_mask, dtype=torch.long)
        all_segment_ids = torch.tensor(segment_ids, dtype=torch.long)
        all_label_ids = torch.tensor(label_id, dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            eval_sampler = SequentialSampler(eval_data)
        else:
            eval_sampler = DistributedSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        logits_all = []
        for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss, logits = model(input_ids, segment_ids,
                                              input_mask, label_ids, n_class)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            for i in range(len(logits)):
                logits_all += [logits[i]]

            tmp_eval_accuracy = accuracy(logits, label_ids.reshape(-1))

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples

        if args.do_train:
            result = {
                'eval_loss': eval_loss,
                'eval_accuracy': eval_accuracy,
                'global_step': global_step,
                'loss': tr_loss / nb_tr_steps
            }
        else:
            result = {'eval_loss': eval_loss, 'eval_accuracy': eval_accuracy}

        output_eval_file = os.path.join(args.output_dir,
                                        "eval_results_test.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
        output_eval_file = os.path.join(args.output_dir, "logits_test.txt")
        with open(output_eval_file, "w") as f:
            for i in range(len(logits_all)):
                for j in range(len(logits_all[i])):
                    f.write(str(logits_all[i][j]))
                    if j == len(logits_all[i]) - 1:
                        f.write("\n")
                    else:
                        f.write(" ")
Example #16
0
class TextSentiment(nn.Module):
    softmax = nn.Softmax(dim=-1)

    def __init__(self, vocab_size=1308844, embed_dim=32, num_class=4):
        super().__init__()
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=True)
        self.fc = nn.Linear(embed_dim, num_class)
        self.init_weights()

    def init_weights(self):
        MODEL_DIR = "/models/intents"
        device_name = "cuda" if torch.cuda.is_available() else "cpu"
        print(device_name)
        self.device = torch.device(device_name)
        self.labelencoder = preprocessing.LabelEncoder()
        self.labelencoder.classes_ = np.load(os.path.join(MODEL_DIR, 'classes.npy'))
        config = BertConfig(os.path.join(MODEL_DIR, 'bert_config.json'))
        self.model = BertForSequenceClassification(config, num_labels=len(self.labelencoder.classes_))
        self.model.load_state_dict(torch.load(os.path.join(MODEL_DIR, 'pytorch_model.bin'), map_location="cpu"))
        self.model.to(self.device)
        self.model.eval()
        tokenizer_class, pretrained_weights = BertTokenizer, 'bert-base-uncased'
        self.tokenizer = tokenizer_class.from_pretrained(pretrained_weights)
        self.batch_size = 30
        self.dataloader_num_workers = 0

    def forward(self, requests):
        ids = []
        segment_ids = []
        input_masks = []

        print(requests)
        for sen in [requests]:
            text_tokens = self.tokenizer.tokenize(sen)
            tokens = ["[CLS]"] + text_tokens + ["[SEP]"]
            temp_ids = self.tokenizer.convert_tokens_to_ids(tokens)
            input_mask = [1] * len(temp_ids)
            segment_id = [0] * len(temp_ids)
            padding = [0] * (MAX_LEN - len(temp_ids))  # MAX_LEN is assumed to be defined at module level (maximum sequence length)

            temp_ids += padding
            input_mask += padding
            segment_id += padding

            ids.append(temp_ids)
            input_masks.append(input_mask)
            segment_ids.append(segment_id)

        ## Convert input list to Torch Tensors
        ids = torch.tensor(ids)
        segment_ids = torch.tensor(segment_ids)
        input_masks = torch.tensor(input_masks)
        validation_data = TensorDataset(ids, input_masks, segment_ids)
        validation_sampler = SequentialSampler(validation_data)
        validation_dataloader = DataLoader(validation_data, sampler=validation_sampler, batch_size=self.batch_size,
                                           num_workers=self.dataloader_num_workers)

        responses = []
        for batch in validation_dataloader:
            # Add batch to GPU
            batch = tuple(t.to(self.device) for t in batch)
            # Unpack the inputs from our dataloader
            b_input_ids, b_input_mask, b_segment_ids = batch  # third tensor holds segment ids; there are no labels at inference time
            with torch.no_grad():
                # Forward pass, calculate logit predictions
                logits = self.model(b_input_ids, token_type_ids=None, attention_mask=b_input_mask)

                for i in range(logits.size(0)):
                    label_idx = [self.__class__.softmax(logits[i]).detach().cpu().numpy().argmax()]
                    label_str = self.labelencoder.inverse_transform(label_idx)[0]
                    responses.append(label_str)

            _t1 = datetime.now()
        return responses[0]

# torch-model-archiver --model-name bert --version 1.0 --model-file ~/work/serve/examples/bert/models.py --serialized-file /models/intents/pytorch_model.bin --extra-files /models/intents/bert_config.json --handler text
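# Assumed follow-up (not part of the original snippet): once the .mar archive is built,
# it can typically be served with TorchServe, e.g.:
#   torchserve --start --model-store model_store --models bert=bert.mar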
class TransformerAgent(Agent):
    @staticmethod
    def add_cmdline_args(argparser):
        agent_args = argparser.add_argument_group('Agent parameters')
        agent_args.add_argument("--model_checkpoint", type=str, default="./runs/Sep10_22-10-31_krusty/", help="Path, url or short name of the model")   # "./runs/Jun03_00-25-57_krusty/"   All empty model: Aug17_00-03-04_krusty
        agent_args.add_argument("--eval_type", type=str, default="f1", help="hits@1, ppl or f1")   # please don't change this parameter
        # agent_args.add_argument("--model", type=str, default="openai-gpt", help="Model type (gpt or gpt2)")
        agent_args.add_argument("--max_history", type=int, default=2, help="Number of previous utterances to keep in history")
        agent_args.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", help="Device (cuda or cpu)")
        agent_args.add_argument("--no_sample", action='store_true')
        agent_args.add_argument("--max_length", type=int, default=20)
        agent_args.add_argument("--min_length", type=int, default=1)
        agent_args.add_argument("--seed", type=int, default=42)   # 0
        agent_args.add_argument("--temperature", type=float, default=0.7)
        agent_args.add_argument("--top_k", type=int, default=0)   # 20
        agent_args.add_argument("--top_p", type=float, default=0.9)  # del

        # NLI
        agent_args.add_argument("--do_lower_case", type=bool, default=True,
                            help="Set this flag if you are using an uncased model.")
        agent_args.add_argument("--output_dir", default='nli_output/', type=str,
                            help="The output directory where the model predictions and checkpoints will be written.")
        agent_args.add_argument("--bert_model", default='bert-base-uncased', type=str,
                            help="Bert pre-trained model selected in the list: bert-base-uncased, "
                                 "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                                 "bert-base-multilingual-cased, bert-base-chinese.")
        # LM
        agent_args.add_argument("--lm_model_path", type=str, default='openai-gpt', help="Path of language model.")
        agent_args.add_argument("--lm_output_dir", type=str, default='lm_models/gpt_output', help="Output dir of language model.")

        return argparser

    def __init__(self, opt, shared=None):
        super(TransformerAgent, self).__init__(opt, shared)

        args = AttrDict(opt)  # to keep most commands identical to the interact.py script
        self.args = args

        logging.basicConfig(level=logging.INFO)
        self.logger = logging.getLogger(__file__)
        self.logger.info(pformat(args))

        random.seed(args.seed)
        torch.random.manual_seed(args.seed)
        torch.cuda.manual_seed(args.seed)

        if shared is None:
            self.logger.info("Get pretrained model and tokenizer")
            if args.model_checkpoint == "":
                args.model_checkpoint = download_pretrained_model()

            self.tokenizer = OpenAIGPTTokenizer.from_pretrained(args.model_checkpoint)
            if self.args.eval_type == "hits@1":
                self.model_checkpoint = OpenAIGPTDoubleHeadsModel.from_pretrained(args.model_checkpoint)
            else:
                self.model_checkpoint = OpenAIGPTLMHeadModel.from_pretrained(args.model_checkpoint)
            self.model_checkpoint.to(args.device)
            self.model_checkpoint.eval()

            self.logger.info("Build BPE prefix dictionary")
            convai_dict = build_dict()
            assert len(convai_dict) == 19304
            self.prefix2words = self.get_prefix2words(convai_dict)
        else:
            self.model_checkpoint = shared['model']
            self.tokenizer = shared['tokenizer']
            self.prefix2words = shared['prefix2words']

        self.special_tokens_ids = self.tokenizer.convert_tokens_to_ids(SPECIAL_TOKENS)

        self.persona = []
        self.history = []
        self.labels = []

        self.reward = []
        self.nli_scores = np.array([0, 0, 0])
        self.reward_scores = 0   # reward function
        self.c_scores = 0   # C score
        self.cnm = 0   # C_new
        self.sample_num = 0   # sample number
        self.con_en = np.array([0, 0, 0])   # whether the persona contains a contradicted/entailed profile (not applied)
        self.intrep_scores = 0   # internal repetition score
        self.lm_ppl_scores = 0   # fine-tuned GPT-based language model
        self.bleu_scores = 0   # BLEU-2 score

        # Loading NLI models
        reset_seed(args.seed)
        self.nli_tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        # print('config_file:', output_config_file)
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        # print('model_file:', output_model_file)

        nli_config = BertConfig(output_config_file)
        self.nli_model = BertForSequenceClassification(nli_config, num_labels=3)
        self.nli_model.load_state_dict(torch.load(output_model_file))
        self.nli_model.to(args.device)
        self.nli_model.eval()

        # Loading LM models
        reset_seed(args.seed)
        self.lm_special_tokens = ['_start_', '_delimiter_', '_classify_']   # special tokens for LM
        # Load pre-trained model (weights)
        with torch.no_grad():
            lm_output_config_file = os.path.join(args.lm_output_dir, CONFIG_NAME)
            lm_config = OpenAIGPTConfig(lm_output_config_file)
            print(type(lm_config))
            if not isinstance(lm_config, OpenAIGPTConfig):
                print('NOT')
            lm_output_model_file = os.path.join(args.lm_output_dir, WEIGHTS_NAME)
            lm_model_state_dict = torch.load(lm_output_model_file)
            self.lm_model = OpenAIGPTLMHeadModel(lm_config)
            self.lm_model.load_state_dict(lm_model_state_dict)

            # Load pre-trained model tokenizer (vocabulary)
            self.lm_tokenizer = OpenAIGPTTokenizer.from_pretrained(args.lm_model_path, special_tokens=self.lm_special_tokens)
        self.special_tokens_ids = list(self.lm_tokenizer.convert_tokens_to_ids(token) for token in self.lm_special_tokens)
        self.lm_model.to(args.device)
        self.lm_model.eval()

        reset_seed(args.seed)

        self.reset()

    def observe(self, observation):
        if self.episode_done:
            self.reset()

        if self.labels:
            # Add the previous response to the history
            self.history.append(self.labels)

        if 'labels' in observation or 'eval_labels' in observation:
            text = observation.get('labels', observation.get('eval_labels', [[]]))[0]
            self.labels = self.tokenizer.encode(text)

        if 'text' in observation:
            text = observation['text']
            for subtext in text.split('\n'):
                subtext = subtext.strip()
                if subtext.startswith('your persona:'):
                    subtext = subtext.replace('your persona:', '').strip()
                    self.persona.append(self.tokenizer.encode(subtext))
                else:
                    self.history.append(self.tokenizer.encode(subtext))

        self.history = self.history[-(2*self.args.max_history+1):]
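        # keep only the last max_history exchanges (two utterances each) plus the newest message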

        candidates = []
        if 'label_candidates' in observation:
            for candidate in observation['label_candidates']:
                candidates.append((self.tokenizer.encode(candidate), candidate))
        self.candidates = candidates

        self.episode_done = observation['episode_done']
        self.observation = observation
        return observation

    def act(self):
        reply = {}

        if self.args.eval_type == "hits@1" and len(self.candidates) > 0:
            instances = defaultdict(list)
            for candidate, _ in self.candidates:
                instance, _ = build_input_from_segments(self.persona, self.history, candidate, self.tokenizer)
                for input_name, input_array in instance.items():
                    instances[input_name].append(input_array)

            inputs = pad_dataset(instances, padding=self.special_tokens_ids[-1])

            tensor_inputs = {}
            for input_name in ["input_ids", "mc_token_ids", "token_type_ids"]:
                tensor = torch.tensor(inputs[input_name], device=self.args.device)
                tensor = tensor.view((-1, len(self.candidates)) + tensor.shape[1:])
                tensor_inputs[input_name] = tensor

            with torch.no_grad():
                _, mc_logits = self.model_checkpoint(**tensor_inputs)

            val, ind = torch.sort(mc_logits[0], descending=True)

            ypred = self.candidates[ind[0].item()][1] # match
            tc = []
            for j in range(len(self.candidates)):
                tc.append(self.candidates[ind[j].item()][1])
            reply = {'text': ypred, 'text_candidates': tc}
        else:
            # We are in interactive or f1 evaluation mode => just sample
            with torch.no_grad():
                out_ids = sample_sequence(self.persona, self.history, self.tokenizer, self.model_checkpoint, self.args)   # YW: TODO: out_ids, _?
            # Get a generated response
            out_text = self.tokenizer.decode(out_ids, skip_special_tokens=True,
                                             clean_up_tokenization_spaces=(self.args.eval_type != 'f1'))
            out_text_org = out_text
            out_text = out_text.replace(' \' ', '\'')   # TODO: tbd
            out_text = out_text.replace(' \'', '\'')
            # persona NLI
            profiles = []
            for profile in self.persona:
                profile_text = self.tokenizer.decode(profile, skip_special_tokens=True, clean_up_tokenization_spaces=False)
                profile_text = profile_text.replace(' \' ', '\'')   # TODO: tbd
                profile_text = profile_text.replace(' \'', '\'')
                profiles.append(profile_text)
            nli_score, reward_score, c_score, current_con_en = nli_engine(out_text, profiles, self.nli_tokenizer, self.nli_model, eval=True)
            self.nli_scores += nli_score   # persona NLI
            self.reward_scores += reward_score   # reward function
            self.c_scores += c_score   # C score
            self.sample_num += 1
            self.con_en += current_con_en   # whether this persona contains a contradicted/entailed profile (not applied)

            # internal repetition
            response_tok = out_text_org.split()
            intrep_1gram = intrep_frac(response_tok)
            # if 2-gram or 3-gram are going to be used:
            '''
            # intrep_2gram
            response_tok_2gram = get_ngrams(out_text, 2)
            intrep_2gram = intrep_frac(response_tok_2gram)
            # intrep_3gram
            response_tok_3gram = get_ngrams(out_text, 3)
            intrep_3gram = intrep_frac(response_tok_3gram)
            '''
            intern_rep_reward = intrep_1gram
            self.intrep_scores += intern_rep_reward

            # bleu
            label_text = self.tokenizer.decode(self.labels, skip_special_tokens=True, clean_up_tokenization_spaces=False)
            current_bleu = bleu_rewarder(out_text_org, label_text)
            self.bleu_scores += current_bleu

            # fine-tuned GPT-based language model
            lm_tokenize_input = self.lm_tokenizer.tokenize(out_text)
            # lm_tensor_input = torch.tensor([lm_tokenizer.convert_tokens_to_ids(lm_tokenize_input)]).to(args.device)
            lm_tensor_input = torch.tensor([[self.special_tokens_ids[0]] + self.lm_tokenizer.convert_tokens_to_ids(lm_tokenize_input) + [self.special_tokens_ids[-1]]]).to(self.args.device)
            lm_loss = self.lm_model(lm_tensor_input, lm_labels=lm_tensor_input)
            lm_ppl = math.exp(lm_loss.item())
            self.lm_ppl_scores += lm_ppl

            print('out_text:', out_text)
            print('current nli:', self.nli_scores)
            print('current score:', self.reward_scores / self.sample_num)
            print('current c_score_macro:', self.c_scores / self.sample_num)
            current_c_score_micro = (self.nli_scores[1] - self.nli_scores[0]) / sum(self.nli_scores)
            cn_res = nli_score[1] - nli_score[0]   # cn: C_new (persona level)
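            # assumption: index 0 of nli_score counts contradictions and index 1 entailments,
            # so a positive difference means the response is, on balance, consistent with the persona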
            # C_new calculation
            if cn_res > 0:
                current_cn = 1
            elif cn_res < 0:
                current_cn = -1
            else:
                current_cn = 0
            self.cnm += current_cn
            print('current c_new:', self.cnm / self.sample_num)
            print('current c_score_micro:', current_c_score_micro)
            print('current con_en:', self.con_en)
            print('current intrep score:', self.intrep_scores / self.sample_num)
            print('current BLEU:', self.bleu_scores / self.sample_num)
            print('current PPL:', self.lm_ppl_scores / self.sample_num)
            reply = {'text': out_text}

        return reply

    def next_word_probability(self, partial_out):
        """Return probability distribution over next words given an input and
        partial true output. This is used to calculate the per-word perplexity.
        """
        partial_out_ids = self.tokenizer.encode(' '.join(partial_out))
        instance, _ = build_input_from_segments(self.persona, self.history, partial_out_ids,
                                             self.tokenizer, with_eos=False)

        input_ids = torch.tensor(instance["input_ids"], device=self.args.device).unsqueeze(0)
        token_type_ids = torch.tensor(instance["token_type_ids"], device=self.args.device).unsqueeze(0)

        with torch.no_grad():
            logits = self.model_checkpoint(input_ids, token_type_ids=token_type_ids)

        probs = F.softmax(logits[0, -1], dim=0)

        dist = {}
        for prefix_id, words in self.prefix2words.items():
            for word, ratio in words.items():
                dist[word] = probs[prefix_id].item() * ratio
        return dist
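
    # A rough sketch (illustration only, not part of the original agent) of how
    # per-word perplexity could be derived from this distribution; target_words
    # is a hypothetical list of reference tokens:
    #   log_probs = [math.log(dist.get(w, 1e-12)) for w in target_words]
    #   ppl = math.exp(-sum(log_probs) / len(log_probs))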

    def get_prefix2words(self, convai_dict, smoothing_freq=5):
        """ map BPE-prefix => dict(full_words beginning with BPE-prefix, associated words_counts) """
        prefix2words = defaultdict(dict)
        for i in trange(len(convai_dict)):
            word = convai_dict[i]
            freq = convai_dict.freq[word] + smoothing_freq
            bpe_tokens = self.tokenizer.bpe(word).split(' ')
            prefix_id = self.tokenizer.convert_tokens_to_ids(bpe_tokens[0])
            prefix2words[prefix_id].update(dict([(word, freq)]))

        for prefix_id, words in prefix2words.items():
            total_counts = sum(words.values())
            prefix2words[prefix_id] = dict((word, count/total_counts) for word, count in words.items())

        return prefix2words
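
    # Example of the resulting structure (hypothetical values):
    #   prefix2words = {prefix_id: {'hello': 0.8, 'help': 0.2}, ...}
    # i.e. for each BPE prefix token id, a normalized distribution over the
    # full ConvAI vocabulary words that start with that prefix.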

    def share(self):
        shared = super(TransformerAgent, self).share()
        shared['tokenizer'] = self.tokenizer
        shared['model'] = self.model_checkpoint
        shared['prefix2words'] = self.prefix2words
        return shared

    def reset(self):
        self.persona = []
        self.history = []
        self.labels = []
        self.candidates = []
        self.episode_done = True
        self.observation = None

def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument(
        "--state_dir",
        default="",
        type=str,
        help=
        "Where to load state dict instead of using Google pre-trained model")
    parser.add_argument(
        "--config_path",
        default="",
        type=str,
        help="Where to load the config file when not using pretrained model")
    parser.add_argument("--teacher_model",
                        default="",
                        type=str,
                        help="teacher model bin file path")
    parser.add_argument("--teacher_config",
                        default="",
                        type=str,
                        help="teacher model config path")
    parser.add_argument("--kd_ratio",
                        default=1.0,
                        type=float,
                        help="Knowledge distillation loss ratio")
    parser.add_argument("--eval_every_epoch",
                        action='store_true',
                        help="Whether to evaluate for every epoch")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "qqp": QqpProcessor,
    }

    num_labels_task = {"cola": 2, "sst-2": 2, "mnli": 3, "mrpc": 2, "qqp": 2}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    config_file = args.teacher_config
    model_file = args.teacher_model
    config = BertConfig(config_file)
    teacher_model = BertForSequenceClassification(config,
                                                  num_labels=num_labels)
    teacher_model.load_state_dict(torch.load(model_file))
    # Prepare model
    if args.state_dir:
        config = BertConfig(args.config_path)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        state_dict = torch.load(args.state_dir)
        if 'model' in state_dict:
            state_dict = state_dict['model']
        model.load_state_dict(state_dict, strict=False)
    else:
        cache_dir = args.cache_dir if args.cache_dir else os.path.join(
            str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
                args.local_rank))
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    teacher_model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
        teacher_model = DDP(teacher_model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)
        teacher_model = torch.nn.DataParallel(teacher_model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
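    # Standard BERT fine-tuning convention: biases and LayerNorm parameters are
    # excluded from weight decay; all other parameters use weight_decay=0.01.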
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        teacher_model.eval()
        ratio = args.kd_ratio
        for ep in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                nll_loss = model(input_ids, segment_ids, input_mask, label_ids)
                logits = model(input_ids, segment_ids, input_mask)
                with torch.no_grad():
                    gt = F.softmax(
                        teacher_model(input_ids, segment_ids, input_mask),
                        dim=-1)
                kd_loss = -F.log_softmax(logits, dim=-1) * gt
                kd_loss = kd_loss.mean()
                nll_loss = nll_loss.mean()
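                # kd_ratio interpolates between the hard-label NLL loss and the
                # soft-label distillation loss (no temperature scaling is applied here)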
                loss = (1 - ratio) * nll_loss + ratio * kd_loss
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
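                        # In older pytorch_pretrained_bert releases, warmup_linear(x, warmup)
                        # ramps up roughly as x/warmup during warmup and then decays linearly;
                        # check the installed version, since the exact formula has changed over time.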
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
            if args.eval_every_epoch:
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model itself
                output_model_file = os.path.join(args.output_dir,
                                                 WEIGHTS_NAME + str(ep))
                torch.save(model_to_save.state_dict(), output_model_file)
                output_config_file = os.path.join(args.output_dir,
                                                  CONFIG_NAME + str(ep))
                with open(output_config_file, 'w') as f:
                    f.write(model_to_save.config.to_json_string())
                # Load a trained model and config that you have fine-tuned
                config = BertConfig(output_config_file)
                model_eval = BertForSequenceClassification(
                    config, num_labels=num_labels)
                model_eval.load_state_dict(torch.load(output_model_file))
                model_eval.to(device)

                eval_examples = processor.get_dev_examples(args.data_dir)
                eval_features = convert_examples_to_features(
                    eval_examples, label_list, args.max_seq_length, tokenizer)
                logger.info("***** Running evaluation *****")
                logger.info("  Num examples = %d", len(eval_examples))
                logger.info("  Batch size = %d", args.eval_batch_size)
                all_input_ids = torch.tensor(
                    [f.input_ids for f in eval_features], dtype=torch.long)
                all_input_mask = torch.tensor(
                    [f.input_mask for f in eval_features], dtype=torch.long)
                all_segment_ids = torch.tensor(
                    [f.segment_ids for f in eval_features], dtype=torch.long)
                all_label_ids = torch.tensor(
                    [f.label_id for f in eval_features], dtype=torch.long)
                eval_data = TensorDataset(all_input_ids, all_input_mask,
                                          all_segment_ids, all_label_ids)
                # Run prediction for full data
                eval_sampler = SequentialSampler(eval_data)
                eval_dataloader = DataLoader(eval_data,
                                             sampler=eval_sampler,
                                             batch_size=args.eval_batch_size)

                model_eval.eval()
                eval_loss, eval_accuracy = 0, 0
                nb_eval_steps, nb_eval_examples = 0, 0

                for input_ids, input_mask, segment_ids, label_ids in tqdm(
                        eval_dataloader, desc="Evaluating"):
                    input_ids = input_ids.to(device)
                    input_mask = input_mask.to(device)
                    segment_ids = segment_ids.to(device)
                    label_ids = label_ids.to(device)

                    with torch.no_grad():
                        tmp_eval_loss = model_eval(input_ids, segment_ids,
                                                   input_mask, label_ids)
                        logits = model_eval(input_ids, segment_ids, input_mask)

                    logits = logits.detach().cpu().numpy()
                    label_ids = label_ids.to('cpu').numpy()
                    tmp_eval_accuracy = accuracy(logits, label_ids)

                    eval_loss += tmp_eval_loss.mean().item()
                    eval_accuracy += tmp_eval_accuracy

                    nb_eval_examples += input_ids.size(0)
                    nb_eval_steps += 1

                eval_loss = eval_loss / nb_eval_steps
                eval_accuracy = eval_accuracy / nb_eval_examples
                loss = tr_loss / nb_tr_steps if args.do_train else None
                result = {
                    'eval_loss': eval_loss,
                    'eval_accuracy': eval_accuracy,
                    'global_step': global_step,
                    'loss': loss
                }

                output_eval_file = os.path.join(args.output_dir,
                                                "eval_results.txt" + str(ep))
                with open(output_eval_file, "w") as writer:
                    logger.info("***** Eval results *****")
                    for key in sorted(result.keys()):
                        logger.info("  %s = %s", key, str(result[key]))
                        writer.write("%s = %s\n" % (key, str(result[key])))

    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    else:
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, num_labels=num_labels)
    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': loss
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .csv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    #if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        #raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = read_swag_examples(os.path.join(args.data_dir, 'train.csv'), is_training = True)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    cache_dir = os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
    predictor = BertForMultipleChoice.from_pretrained(args.bert_model,
        cache_dir=cache_dir,
        num_choices=4)
    # Use independently trained adversary
    output_model_file = os.path.join(args.output_dir, 'adversary_' + WEIGHTS_NAME)
    output_config_file = os.path.join(args.output_dir, 'adversary_' + CONFIG_NAME)
    config = BertConfig(output_config_file)
    adversary = BertForSequenceClassification(config, num_labels=2)
    adversary.load_state_dict(torch.load(output_model_file))  
    
    if args.fp16:
        predictor.half()
        adversary.half()
    predictor.to(device)
    adversary.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        predictor = DDP(predictor)
        adversary = DDP(adversary)
    elif n_gpu > 1:
        predictor = torch.nn.DataParallel(predictor)
        adversary = torch.nn.DataParallel(adversary)

    # Prepare optimizer
    param_optimizer_pred = list(predictor.named_parameters())
    
    # hack to remove pooler, which is not used
    # thus it produces None grads that break apex
    param_optimizer_pred = [n for n in param_optimizer_pred if 'pooler' not in n[0]]
    
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters_pred = [
        {'params': [p for n, p in param_optimizer_pred if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer_pred if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer_pred = FusedAdam(optimizer_grouped_parameters_pred,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)           
        if args.loss_scale == 0:
            optimizer_pred = FP16_Optimizer(optimizer_pred, dynamic_loss_scale=True)
        else:
            optimizer_pred = FP16_Optimizer(optimizer_pred, static_loss_scale=args.loss_scale)
    else:
        optimizer_pred = BertAdam(optimizer_grouped_parameters_pred,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    alpha = 1
    global_step = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, tokenizer, args.max_seq_length, True)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor(select_field(train_features, 'input_ids'), dtype=torch.long)
        all_input_mask = torch.tensor(select_field(train_features, 'input_mask'), dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(train_features, 'segment_ids'), dtype=torch.long)
        all_label = torch.tensor([f.label for f in train_features], dtype=torch.long)
        all_vp_input_ids = torch.tensor(select_field(train_features, 'vp_input_ids'), dtype=torch.long)
        all_vp_input_mask = torch.tensor(select_field(train_features, 'vp_input_mask'), dtype=torch.long)
        all_protected_attr = torch.tensor([f.protected_attr for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label, all_vp_input_ids, all_vp_input_mask, all_protected_attr)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        training_history = []
        predictor.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss_pred, tr_loss_adv = 0, 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids, vp_input_ids, vp_input_mask, protected_attr_ids = batch
                loss_pred, logits_pred = predictor(input_ids, segment_ids, input_mask, label_ids)
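                # softmax over the per-choice logits gives the predictor's distribution over the answer choices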
                softmax = torch.nn.functional.softmax(logits_pred, dim=1)

                # flatten vp ids and mask
                batch_size, num_choices = vp_input_ids.shape[0], vp_input_ids.shape[1]
                vp_input_ids = vp_input_ids.view([batch_size * num_choices, -1])
                vp_input_mask = vp_input_mask.view([batch_size * num_choices, -1])
                # repeat the protected attribute once per answer choice
                protected_attr_ids_ = protected_attr_ids.repeat(num_choices, 1).t()
                protected_attr_ids_ = protected_attr_ids_.reshape(-1)
                _, logits_adv = adversary(vp_input_ids, None, vp_input_mask, protected_attr_ids_)
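                # column 1 of the adversary's logits: the positive-class score for each candidate verb phrase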
                pos_probs = logits_adv.view([batch_size, num_choices, -1])[:,:,1]
                # perform a batch-wise dot product between positive probabilities and softmax vector
                dot_prod = torch.bmm(pos_probs.view([batch_size, 1, num_choices]), softmax.view([batch_size, num_choices, 1])).view([batch_size, 1])
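                # the dot product is the expected positive-class score under the predictor's choice
                # distribution; it is stacked into a two-class score vector and compared against the
                # true protected attribute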
                loss_adv = torch.nn.CrossEntropyLoss()(torch.cat([1 - dot_prod, dot_prod], dim=1), protected_attr_ids.view([-1]))
                
                if n_gpu > 1:
                    loss_pred = loss_pred.mean() # mean() to average on multi-gpu.
                    loss_adv = loss_adv.mean()
                if args.fp16 and args.loss_scale != 1.0:
                    # rescale loss for fp16 training
                    # see https://docs.nvidia.com/deeplearning/sdk/mixed-precision-training/index.html
                    loss_pred = loss_pred * args.loss_scale
                    loss_adv = loss_adv * args.loss_scale
                if args.gradient_accumulation_steps > 1:
                    loss_pred = loss_pred / args.gradient_accumulation_steps
                    loss_adv = loss_adv / args.gradient_accumulation_steps
                tr_loss_pred += loss_pred.item()
                tr_loss_adv += loss_adv.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                
                training_history.append([loss_pred.item(), loss_adv.item()])
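                # combined objective: improve the task prediction while degrading the fixed
                # adversary's ability to recover the protected attribute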
                loss = loss_pred - alpha * loss_adv
                if args.fp16:
                    optimizer_pred.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with the special warm up BERT uses;
                        # when args.fp16 is False, BertAdam handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(global_step / num_train_optimization_steps, args.warmup_proportion)
                        for param_group in optimizer_pred.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer_pred.step()
                    optimizer_pred.zero_grad()
                    global_step += 1
        # persist the per-step training losses for later analysis
        with open(os.path.join(args.output_dir, "train_results.csv"), "w") as history_file:
            writer = csv.writer(history_file, delimiter=",")
            writer.writerow(["pred_loss", "adv_loss"])
            for row in training_history:
                writer.writerow(row)
    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = predictor.module if hasattr(predictor, 'module') else predictor  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, 'predictor_' + WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, 'predictor_' + CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())
            
        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        predictor = BertForMultipleChoice(config, num_choices=4)
        predictor.load_state_dict(torch.load(output_model_file))
        
        # Do the same for adversary
        model_to_save = adversary.module if hasattr(adversary, 'module') else adversary  # Only save the model itself
        output_model_file = os.path.join(args.output_dir, 'adversary_' + WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, 'adversary_' + CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())
            
        config = BertConfig(output_config_file)
        adversary = BertForSequenceClassification(config, num_labels=2)
        adversary.load_state_dict(torch.load(output_model_file))

    else:
        output_model_file = os.path.join(args.output_dir, 'predictor_' + WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, 'predictor_' + CONFIG_NAME)
        config = BertConfig(output_config_file)
        predictor = BertForMultipleChoice(config, num_choices=4)
        predictor.load_state_dict(torch.load(output_model_file))
        
        output_model_file = os.path.join(args.output_dir, 'adversary_' + WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, 'adversary_' + CONFIG_NAME)
        config = BertConfig(output_config_file)
        adversary = BertForSequenceClassification(config, num_labels=2)
        adversary.load_state_dict(torch.load(output_model_file))
    predictor.to(device)
    adversary.to(device)


    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = read_swag_examples(os.path.join(args.data_dir, 'val.csv'), is_training=True)
        eval_features = convert_examples_to_features(
            eval_examples, tokenizer, args.max_seq_length, True)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor(select_field(eval_features, 'input_ids'), dtype=torch.long)
        all_input_mask = torch.tensor(select_field(eval_features, 'input_mask'), dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(eval_features, 'segment_ids'), dtype=torch.long)
        all_label = torch.tensor([f.label for f in eval_features], dtype=torch.long)
        all_vp_input_ids = torch.tensor(select_field(eval_features, 'vp_input_ids'), dtype=torch.long)
        all_vp_input_mask = torch.tensor(select_field(eval_features, 'vp_input_mask'), dtype=torch.long)
        all_protected_attr = torch.tensor([f.protected_attr for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label, all_vp_input_ids, all_vp_input_mask, all_protected_attr)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        predictor.eval()
        adversary.eval()
        eval_loss_pred, eval_accuracy_pred = 0, 0
        eval_loss_adv, eval_accuracy_adv = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for input_ids, input_mask, segment_ids, label_ids, vp_input_ids, vp_input_mask, protected_attr_ids in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)
            vp_input_ids = vp_input_ids.to(device)
            vp_input_mask = vp_input_mask.to(device)
            protected_attr_ids = protected_attr_ids.to(device)
            
            with torch.no_grad():
                tmp_eval_loss_pred, logits_pred = predictor(input_ids, segment_ids, input_mask, label_ids)
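            # at eval time, the adversary only sees the verb phrase of the predictor's chosen answer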
            predicted_vps = torch.argmax(logits_pred, dim=1)    
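            # expand the chosen-choice index to shape [batch, 1, seq_len] so gather can select along the choice dimension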
            predicted_vps = predicted_vps.view(-1, 1).repeat(1, vp_input_ids.size(2)).view([-1, 1, vp_input_ids.size(2)])
            vp_input_ids = torch.gather(vp_input_ids, dim=1, index=predicted_vps)
            vp_input_ids = vp_input_ids.view([vp_input_ids.size(0), -1])
            vp_input_mask = torch.gather(vp_input_mask, dim=1, index=predicted_vps)
            vp_input_mask = vp_input_mask.view([vp_input_mask.size(0), -1])
            with torch.no_grad():
                tmp_eval_loss_adv, logits_adv = adversary(vp_input_ids, None, vp_input_mask, protected_attr_ids)
            
            # print("logits_adv", logits_adv)
            tmp_eval_accuracy_pred = accuracy(logits_pred, label_ids)
            tmp_eval_accuracy_adv = accuracy(logits_adv, protected_attr_ids)
            

            eval_loss_pred += tmp_eval_loss_pred.mean().item()
            eval_accuracy_pred += tmp_eval_accuracy_pred.item()
            eval_loss_adv += tmp_eval_loss_adv.mean().item()
            eval_accuracy_adv += tmp_eval_accuracy_adv.item()

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss_pred /= nb_eval_steps
        eval_accuracy_pred /= nb_eval_examples
        eval_loss_adv /= nb_eval_steps
        eval_accuracy_adv /= nb_eval_examples
        
        if args.do_train:
            result = {'eval_loss_pred': eval_loss_pred,
                      'eval_accuracy_pred': eval_accuracy_pred,
                      'eval_loss_adv': eval_loss_adv,
                      'eval_accuracy_adv': eval_accuracy_adv,
                      'global_step': global_step,
                      'loss_pred': tr_loss_pred/nb_tr_steps,
                      'loss_adv': tr_loss_adv/nb_tr_steps}
        else:
            result = {'eval_loss_pred': eval_loss_pred,
                      'eval_accuracy_pred': eval_accuracy_pred,
                      'eval_loss_adv': eval_loss_adv,
                      'eval_accuracy_adv': eval_accuracy_adv}

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))