def main():
    """Train (or resume training of) a BERT masked-LM on a text corpus.

    Parses the command line, selects the device(s), builds the tokenizer,
    ``BERTDataset``, ``BertForMaskedLM`` model and optimizer, then runs the
    training loop. With ``--continue_training`` the newest
    ``checkpoints_<step>.pt`` in ``--output_dir`` is loaded first and
    training restarts at the epoch stored in it.

    After each optimizer update it logs the accumulated loss when
    ``global_step % 1000 == 999`` (logger and a ``SummaryWriter`` rooted at
    ``$HOME``) and writes ``checkpoints_<step>.pt`` every 10000 updates. The
    model weights are saved as ``pytorch_model.bin_<epoch>`` after every
    epoch and as ``pytorch_model.bin`` at the end.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1, ``--do_train``
            is not set, ``--output_dir`` is non-empty without
            ``--continue_training``, or ``--continue_training`` is set but no
            ``checkpoints_<step>.pt`` file exists.
        ImportError: if distributed or fp16 training is requested and apex
            is not installed.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--on_memory",
        action='store_true',
        help="Whether to load train samples into memory or use disk")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--hybrid_attention",
                        action='store_true',
                        help="Whether to use hybrid attention")
    parser.add_argument("--continue_training",
                        action='store_true',
                        help="Continue training from a checkpoint")

    args = parser.parse_args()

    # Device selection: single process (possibly multi-GPU via
    # DataParallel) unless a local_rank is given, in which case one GPU per
    # process with the NCCL backend.
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # --train_batch_size is the effective batch; each forward pass gets a
    # 1/gradient_accumulation_steps slice of it.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError(
            "Training is currently the only implemented execution option. Please set `do_train`."
        )

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and not args.continue_training:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset", args.train_file)
        train_dataset = BERTDataset(args.train_file,
                                    tokenizer,
                                    seq_len=args.max_seq_length,
                                    corpus_lines=None,
                                    on_memory=args.on_memory)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    model = BertForMaskedLM.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: biases and LayerNorm parameters get no weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    if args.hybrid_attention:
        # Fixed attention pattern per slice of the first dimension
        # (1 = keep, 0 = mask).
        # NOTE(review): the leading size 12 presumably equals the number of
        # attention heads of the model — confirm against the model config.
        max_seq_length = args.max_seq_length
        attention_mask = torch.ones(12,
                                    max_seq_length,
                                    max_seq_length,
                                    dtype=torch.long)
        # left attention: keep entries [i, j] with j <= i
        attention_mask[:2, :, :] = torch.tril(
            torch.ones(max_seq_length, max_seq_length, dtype=torch.long))
        # right attention: keep entries [i, j] with j >= i
        attention_mask[2:4, :, :] = torch.triu(
            torch.ones(max_seq_length, max_seq_length, dtype=torch.long))
        # local attention, window size = 3 (|i - j| <= 1)
        attention_mask[4:6, :, :] = torch.triu(
            torch.tril(
                torch.ones(max_seq_length, max_seq_length, dtype=torch.long),
                1), -1)
        # local attention, window size = 5 (|i - j| <= 2)
        attention_mask[6:8, :, :] = torch.triu(
            torch.tril(
                torch.ones(max_seq_length, max_seq_length, dtype=torch.long),
                2), -2)
        # Slices 8-11 keep full attention (all ones).
        # NOTE(review): the mask is tiled 8 times along a new leading dim;
        # this presumably must match the per-forward batch size — confirm.
        attention_mask = torch.cat(
            [attention_mask.unsqueeze(0) for _ in range(8)])
        attention_mask = attention_mask.to(device)
    else:
        attention_mask = None

    global_step = 0
    epoch_start = 0
    if args.do_train:
        if args.continue_training:
            # Find the newest checkpoint written by this script. Only
            # "checkpoints_<step>.pt" counts: the per-epoch
            # "pytorch_model.bin_<epoch>" files also contain "_<digits>" and
            # must not be mistaken for checkpoints.
            checkpoint_re = re.compile(r'^checkpoints_(\d+)\.pt$')
            steps = []
            if os.path.exists(args.output_dir):
                for file_name in os.listdir(args.output_dir):
                    match = checkpoint_re.match(file_name)
                    if match:
                        steps.append(int(match.group(1)))
            if not steps:
                raise ValueError(
                    "No existing checkpoint. Please do not use --continue_training."
                )
            max_step = max(steps)
            # load checkpoint
            checkpoint = torch.load(
                os.path.join(args.output_dir,
                             'checkpoints_' + str(max_step) + '.pt'))
            logger.info("***** Loading checkpoint *****")
            logger.info("  Num steps = %d", checkpoint['global_step'])
            logger.info("  Num epoch = %d", checkpoint['epoch'])
            # The stored losses are floats; %d would truncate them.
            logger.info("  Loss = %f, %f", checkpoint['loss'],
                        checkpoint['loss_now'])
            # Checkpoints hold the unwrapped model's state dict, so load
            # into the underlying module when DataParallel/DDP is in use and
            # into the model itself otherwise.
            model_to_load = model.module if hasattr(model,
                                                    'module') else model
            model_to_load.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            global_step = checkpoint['global_step']
            epoch_start = checkpoint['epoch']

        writer = SummaryWriter(log_dir=os.environ['HOME'])
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            #TODO: check if this works with current data generator from disk that relies on next(file)
            # (it doesn't return item back by index)
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        tr_loss_1000 = 0
        for ep in trange(epoch_start, int(args.num_train_epochs),
                         desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids = batch
                loss = model(input_ids,
                             segment_ids,
                             input_mask,
                             lm_label_ids,
                             hybrid_mask=attention_mask)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                loss_value = loss.item()
                tr_loss += loss_value
                tr_loss_1000 += loss_value
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1

                if (step + 1) % args.gradient_accumulation_steps != 0:
                    # Still accumulating gradients; no optimizer update yet.
                    continue

                if args.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used that handles this automatically
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / num_train_optimization_steps,
                        args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                optimizer.step()
                optimizer.zero_grad()
                global_step += 1

                # Logging and checkpointing are keyed on global_step, so they
                # are only evaluated right after an optimizer update.
                # Otherwise, with gradient accumulation, they would fire on
                # every micro-batch sharing the same global_step, including
                # a bogus checkpoint at step 0.

                # log the training loss for every 1000 steps
                if global_step % 1000 == 999:
                    writer.add_scalar('data/loss', tr_loss_1000 / 1000,
                                      global_step)
                    logger.info("training steps: %s", global_step)
                    logger.info("training loss per 1000: %s",
                                tr_loss_1000 / 1000)
                    tr_loss_1000 = 0
                # save the checkpoint for every 10000 steps
                if global_step % 10000 == 0:
                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model it-self
                    output_file = os.path.join(
                        args.output_dir,
                        "checkpoints_" + str(global_step) + ".pt")
                    checkpoint = {
                        'model': model_to_save.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': ep,
                        'global_step': global_step,
                        'loss': tr_loss / nb_tr_steps,
                        'loss_now': tr_loss_1000
                    }
                    torch.save(checkpoint, output_file)

            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(args.output_dir,
                                             "pytorch_model.bin_" + str(ep))
            torch.save(model_to_save.state_dict(), output_model_file)
            # Guard against an empty dataloader (no batches this epoch).
            logger.info("training loss: %s", tr_loss / max(nb_tr_steps, 1))

        # Save a trained model
        logger.info("** ** * Saving fine - tuned model ** ** * ")
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
        torch.save(model_to_save.state_dict(), output_model_file)
# Example #2
def main():
    """Pivot-based masked-LM pre-training of BERT on a source/target domain pair.

    Parses the command line and loads the pickled pivot list from
    ``--pivot_path``. Id 0 is reserved for ``'NONE'`` and pivots get ids
    ``1..len(pivot_list)``. It then builds a ``BERTDataset`` over
    ``--src_domain`` and ``--trg_domain``, and a ``BertForMaskedLM`` created
    with ``output_dim=len(pivot2id_dict)``.

    The embeddings and all but the last ``--num_of_unfrozen_bert_layers``
    encoder layers are frozen. The pivots decoder is trainable only with
    ``--train_output_embeds``. Weights are saved to ``--output_dir`` as
    ``pytorch_model<epoch>.bin`` every ``--save_every_num_epochs`` epochs and
    as ``pytorch_model.bin`` at the end.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1, training is
            disabled, or ``--output_dir`` already exists and is non-empty.
        ImportError: if distributed or fp16 training is requested and apex
            is not installed.
    """

    def _str2bool(value):
        """Parse a command-line boolean.

        argparse's ``type=bool`` treats every non-empty string, including
        "False", as True, so flags could never be switched off.
        """
        if isinstance(value, bool):
            return value
        lowered = value.strip().lower()
        if lowered in ('true', 't', 'yes', 'y', '1'):
            return True
        if lowered in ('false', 'f', 'no', 'n', '0'):
            return False
        raise argparse.ArgumentTypeError(
            "Boolean value expected, got {!r}.".format(value))

    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--src_domain",
        default=None,
        type=str,
        required=True,
        help="The src train corpus. One of: books, dvd, elctronics, kitchen.")
    parser.add_argument(
        "--trg_domain",
        default=None,
        type=str,
        required=True,
        help="The trg corpus. One of: books, dvd, elctronics, kitchen.")
    parser.add_argument(
        "--bert_model",
        default='bert-base-uncased',
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--output_dir",
        default='models/books_to_electronics',
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )
    parser.add_argument("--pivot_path",
                        default=None,
                        type=str,
                        required=True,
                        help="The directory where needed pivots are."
                        "(as data/kitchen_to_books/pivots/40_bigram)")
    parser.add_argument("--pivot_prob",
                        default=0.5,
                        type=float,
                        required=True,
                        help="Probability to mask a pivot.")
    parser.add_argument("--non_pivot_prob",
                        default=0.1,
                        type=float,
                        required=True,
                        help="Probability to mask a non-pivot.")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=True,
                        type=_str2bool,
                        help="Whether to run training.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=100.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--save_every_num_epochs",
                        default=20.0,
                        type=float,
                        help="After how many epochs to save weights.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--on_memory",
        default=True,
        type=_str2bool,
        help="Whether to load train samples into memory or use disk")
    parser.add_argument(
        "--do_lower_case",
        default=True,
        type=_str2bool,
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument(
        '--num_of_unfrozen_bert_layers',
        type=int,
        default=8,
        help="Number of trainable BERT layers during pretraining.")
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--init_output_embeds',
        action='store_true',
        help="Whether to initialize pivots decoder with BERT embedding or not."
    )
    parser.add_argument('--train_output_embeds',
                        action='store_true',
                        help="Whether to train pivots decoder or not.")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    logger.info(args)

    # Device selection: single process (possibly multi-GPU via
    # DataParallel) unless a local_rank is given, in which case one GPU per
    # process with the NCCL backend.
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    print("---- Pivots Path:", args.pivot_path)
    # NOTE(review): pickle.load can execute arbitrary code; only point
    # --pivot_path at trusted files.
    with open(args.pivot_path, "rb") as pickle_in:
        pivot_list = pickle.load(pickle_in)
    # Id 0 is reserved for "not a pivot"; pivots are numbered from 1.
    pivot2id_dict = {'NONE': 0}
    id2pivot_dict = {0: 'NONE'}
    for pivot_id, feature in enumerate(pivot_list, start=1):
        pivot2id_dict[feature] = pivot_id
        id2pivot_dict[pivot_id] = feature

    # --train_batch_size is the effective batch; each forward pass gets a
    # 1/gradient_accumulation_steps slice of it.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train:
        raise ValueError(
            "Training is currently the only implemented execution option. Please set `do_train`."
        )

    dir_for_save = args.output_dir

    if os.path.exists(dir_for_save) and os.listdir(dir_for_save):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                dir_for_save))
    if not os.path.exists(dir_for_save):
        # makedirs (not mkdir) so nested paths such as the default
        # 'models/books_to_electronics' work when the parent is missing.
        os.makedirs(dir_for_save)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    num_train_optimization_steps = None
    if args.do_train:
        print("Loading Train Dataset from", args.src_domain, "and from",
              args.trg_domain)
        train_dataset = BERTDataset(args.src_domain,
                                    args.trg_domain,
                                    tokenizer,
                                    seq_len=args.max_seq_length,
                                    pivot2id_dict=pivot2id_dict,
                                    corpus_lines=None,
                                    on_memory=args.on_memory,
                                    pivot_prob=args.pivot_prob,
                                    non_pivot_prob=args.non_pivot_prob)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    model = BertForMaskedLM.from_pretrained(args.bert_model,
                                            output_dim=len(pivot2id_dict),
                                            init_embed=args.init_output_embeds,
                                            src=args.src_domain,
                                            trg=args.trg_domain)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Freeze the embeddings and every encoder layer except the last
    # `num_of_unfrozen_bert_layers`. DataParallel/DDP wrap the model, so
    # operate on the underlying module when present.
    base_model = model.module if hasattr(model, 'module') else model
    for param in base_model.bert.embeddings.parameters():
        param.requires_grad = False
    encoder_layers = base_model.bert.encoder.layer
    num_frozen_layers = max(
        0,
        len(encoder_layers) - args.num_of_unfrozen_bert_layers)
    for layer in encoder_layers[:num_frozen_layers]:
        for param in layer.parameters():
            param.requires_grad = False
    for param in base_model.cls.predictions.pivots_decoder.parameters():
        param.requires_grad = args.train_output_embeds

    # Prepare optimizer: biases and LayerNorm parameters get no weight decay.
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)
            warmup_linear = WarmupLinearSchedule(
                warmup=args.warmup_proportion,
                t_total=num_train_optimization_steps)

        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            # TODO: check if this works with current data generator from disk that relies on next(file)
            # (it doesn't return item back by index)
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for cnt in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, lm_label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        # get_lr(step) already divides by t_total; its second
                        # positional parameter is a `nowarn` flag, not the
                        # warmup proportion (which the schedule already holds).
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(
                            global_step)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
            if ((cnt + 1) % args.save_every_num_epochs) == 0:
                # Save a trained model
                logger.info("** ** * Saving fine - tuned model ** ** * ")
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(
                    dir_for_save, "pytorch_model" + str(cnt + 1) + ".bin")
                torch.save(model_to_save.state_dict(), output_model_file)
        logger.info("** ** * Saving fine - tuned model ** ** * ")
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(dir_for_save,
                                         "pytorch_model" + ".bin")
        torch.save(model_to_save.state_dict(), output_model_file)
# Example #3
def _glue_tensor_dataset(features, output_mode, with_labels=True):
    """Pack converted GLUE features into a ``TensorDataset``.

    Tensors are ordered ``(input_ids, input_mask, segment_ids[, label_ids])``.
    Label ids are ``float`` for regression tasks and ``long`` otherwise.
    """
    tensors = [
        torch.tensor([f.input_ids for f in features], dtype=torch.long),
        torch.tensor([f.input_mask for f in features], dtype=torch.long),
        torch.tensor([f.segment_ids for f in features], dtype=torch.long),
    ]
    if with_labels:
        label_dtype = torch.float if output_mode == "regression" else torch.long
        tensors.append(torch.tensor([f.label_id for f in features], dtype=label_dtype))
    return TensorDataset(*tensors)


def _glue_loss(logits, label_ids, output_mode, num_labels):
    """Return the batch loss: MSE for regression tasks, cross-entropy otherwise."""
    if output_mode == "regression":
        return MSELoss()(logits.view(-1), label_ids.view(-1))
    return CrossEntropyLoss()(logits.view(-1, num_labels), label_ids.view(-1))


def _logits_to_preds(logits, output_mode):
    """Turn stacked logits into predictions (arg-max class ids, or flat regression scores)."""
    if output_mode == "classification":
        return np.argmax(logits, axis=1)
    # reshape instead of squeeze so a single example still yields a 1-D array
    return logits.reshape(-1)


def _run_inference(model, dataset, batch_size, device, desc, loss_fn=None):
    """Run ``model`` over ``dataset`` in order; return ``(stacked_logits, mean_loss)``.

    When ``loss_fn`` is given, the 4th tensor of each batch is treated as the
    labels and ``loss_fn(logits, labels)`` is averaged over batches; otherwise
    ``mean_loss`` is None.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    dataloader = DataLoader(dataset, sampler=SequentialSampler(dataset), batch_size=batch_size)
    model.eval()
    logit_chunks = []
    total_loss = 0.0
    nb_steps = 0
    for batch in tqdm(dataloader, desc=desc):
        batch = tuple(t.to(device) for t in batch)
        input_ids, input_mask, segment_ids = batch[:3]
        with torch.no_grad():
            logits = model(input_ids, segment_ids, input_mask, labels=None)
            if loss_fn is not None:
                total_loss += loss_fn(logits, batch[3]).mean().item()
        nb_steps += 1
        logit_chunks.append(logits.detach().cpu().numpy())
    if not logit_chunks:
        raise ValueError("No examples to run inference on.")
    mean_loss = total_loss / nb_steps if loss_fn is not None else None
    return np.concatenate(logit_chunks, axis=0), mean_loss


def _evaluate_glue(model, examples, metric_task_name, label_list, output_mode,
                   tokenizer, device, args):
    """Score labelled ``examples``; return ``compute_metrics`` output plus ``eval_loss``."""
    features = convert_examples_to_features(
        examples, label_list, args.max_seq_length, tokenizer, output_mode)
    logger.info("***** Running evaluation *****")
    logger.info("  Num examples = %d", len(examples))
    logger.info("  Batch size = %d", args.eval_batch_size)
    eval_data = _glue_tensor_dataset(features, output_mode)
    num_labels = len(label_list)
    logits, eval_loss = _run_inference(
        model, eval_data, args.eval_batch_size, device, "Evaluating",
        loss_fn=lambda out, labels: _glue_loss(out, labels, output_mode, num_labels))
    preds = _logits_to_preds(logits, output_mode)
    # The last tensor of the dataset holds the gold label ids.
    result = compute_metrics(metric_task_name, preds, eval_data.tensors[-1].numpy())
    result['eval_loss'] = eval_loss
    return result


def _write_eval_results(result, output_eval_file):
    """Log every metric in ``result`` and write them as ``key = value`` lines."""
    with open(output_eval_file, "w") as writer:
        logger.info("***** Eval results *****")
        for key in sorted(result.keys()):
            logger.info("  %s = %s", key, str(result[key]))
            writer.write("%s = %s\n" % (key, str(result[key])))


def main():
    """Fine-tune, evaluate and/or predict with BERT on a GLUE task (CLI entry point).

    Depending on ``--do_train`` / ``--do_eval`` / ``--do_test`` this:

    * trains ``BertForSequenceClassification`` and saves weights, config and vocab
      to ``--output_dir``;
    * evaluates on the dev split and writes ``eval_results.txt`` (for MNLI, also on
      the mismatched dev split into ``<output_dir>-MM``);
    * predicts the dev split without labels and writes an ``index\\tprediction`` TSV
      to ``--test_output_file`` (default: ``<output_dir>/QNLI.tsv``).

    Raises:
        ValueError: on invalid arguments, unknown task, or a non-empty output
            directory when training.
        ImportError: if apex is required (fp16 / distributed) but missing.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run test on the test set.")
    parser.add_argument("--test_output_file",
                        default=None,
                        type=str,
                        help="Where --do_test writes its predictions TSV. "
                             "Defaults to <output_dir>/QNLI.tsv.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port), redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mnli-mm": MnliMismatchedProcessor,
        "mrpc": MrpcProcessor,
        "sst-2": Sst2Processor,
        "sts-b": StsbProcessor,
        "qqp": QqpProcessor,
        "qnli": QnliProcessor,
        "rte": RteProcessor,
        "wnli": WnliProcessor,
    }

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification",
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval and not args.do_test:
        raise ValueError("At least one of `do_train` or `do_eval`  or 'do_test' must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
    model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          cache_dir=cache_dir,
                                                          num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
            {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
            ]
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
            warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
                                                 t_total=num_train_optimization_steps)

        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

    global_step = 0
    # Sum of per-step losses over ALL epochs; divided by global_step when reported.
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        train_data = _glue_tensor_dataset(train_features, output_mode)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            # tr_loss is deliberately NOT reset here: the reported loss divides it
            # by the all-epoch global_step, so it must cover every epoch.
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # Compute the loss ourselves so both output modes are supported.
                logits = model(input_ids, segment_ids, input_mask, labels=None)
                loss = _glue_loss(logits, label_ids, output_mode, num_labels)

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned
        model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
        tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)
    else:
        model = BertForSequenceClassification.from_pretrained(args.output_dir, num_labels=num_labels)
    model.to(device)

    # Mean training loss per optimizer step (None when nothing was trained).
    train_loss = tr_loss / global_step if args.do_train and global_step > 0 else None

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        result = _evaluate_glue(model, processor.get_dev_examples(args.data_dir), task_name,
                                label_list, output_mode, tokenizer, device, args)
        result['global_step'] = global_step
        result['loss'] = train_loss
        _write_eval_results(result, os.path.join(args.output_dir, "eval_results.txt"))

        # hack for MNLI-MM: also score the mismatched dev split, into <output_dir>-MM.
        # Local names are used so task_name/processor stay intact for --do_test.
        if task_name == "mnli":
            mm_output_dir = args.output_dir + '-MM'
            if os.path.exists(mm_output_dir) and os.listdir(mm_output_dir) and args.do_train:
                raise ValueError("Output directory ({}) already exists and is not empty.".format(mm_output_dir))
            if not os.path.exists(mm_output_dir):
                os.makedirs(mm_output_dir)

            mm_processor = processors["mnli-mm"]()
            mm_result = _evaluate_glue(model, mm_processor.get_dev_examples(args.data_dir), "mnli-mm",
                                       label_list, output_mode, tokenizer, device, args)
            mm_result['global_step'] = global_step
            mm_result['loss'] = train_loss
            _write_eval_results(mm_result, os.path.join(mm_output_dir, "eval_results.txt"))

    if args.do_test and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # NOTE(review): this reads the *dev* split (as before); switch to a test-split
        # loader if the processor provides one.
        test_examples = processor.get_dev_examples(args.data_dir)
        test_features = convert_examples_to_features(
            test_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        logger.info("***** Running test prediction *****")
        logger.info("  Num examples = %d", len(test_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        test_data = _glue_tensor_dataset(test_features, output_mode, with_labels=False)
        logits, _ = _run_inference(model, test_data, args.eval_batch_size, device, "Testing")
        preds = _logits_to_preds(logits, output_mode)

        test_output_file = args.test_output_file or os.path.join(args.output_dir, "QNLI.tsv")
        with open(test_output_file, 'w') as outfile:
            outfile.write('index\tprediction\n')
            for idx, pred in enumerate(preds):
                # QNLI-style label names: class id 0 -> entailment, anything else -> not_entailment.
                label = 'entailment' if pred == 0 else 'not_entailment'
                outfile.write(str(idx) + '\t' + label + '\n')
def main(model_path):
    """Predict labels for ``--testing_file`` with a fine-tuned checkpoint and save them.

    Args:
        model_path: path to a ``state_dict`` (loaded with ``torch.load``) for a
            ``BertForSequenceClassification`` with ``--num_labels`` outputs.

    Other options are read from the command line. Arg-max class predictions are
    passed to ``save_predictions`` together with sequential ids 20001..22210 and
    ``--predict_file``.

    Raises:
        ValueError: on invalid ``--gradient_accumulation_steps`` or unknown task.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--bert_model", default="bert-base-uncased", type=str,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default="MRPC",
                        type=str,
                        help="The name of the task to train.")
    parser.add_argument("--testing_file", type=str)
    parser.add_argument("--predict_file", type=str)

    ## Other parameters
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_eval",
                        default=True,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        default=True,
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=2,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=4,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=2e-4,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=20.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument("--num_labels",
                        default=5,
                        type=int,
                        help="Number of output classes of the checkpoint loaded from model_path.")
    args = parser.parse_args()

    processors = {"mrpc": MrpcProcessor}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = "classification"

    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    # Build the classifier directly from the fine-tuned weights. The previous code
    # first built (and DDP/DataParallel-wrapped) a throw-away model plus unused
    # optimizer groups, then replaced it here with a hard-coded 'bert-base-uncased'.
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank))
    state_dict = torch.load(model_path, map_location=device)
    model = BertForSequenceClassification.from_pretrained(args.bert_model,
                                                          state_dict=state_dict,
                                                          cache_dir=cache_dir,
                                                          num_labels=args.num_labels)
    model.to(device)

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.testing_file)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer, output_mode)
        logger.info("***** Running Testing *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # Run prediction for full data, in file order.
        eval_dataloader = DataLoader(eval_data,
                                     sampler=SequentialSampler(eval_data),
                                     batch_size=args.eval_batch_size)

        model.eval()

        print('[Start prediction!]')
        batch_preds = []
        for input_ids, input_mask, segment_ids, _ in tqdm(eval_dataloader, desc="Testing"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask, labels=None)

            batch_preds.append(torch.argmax(logits, 1))

        result_series = list(torch.cat(batch_preds).cpu().numpy()) if batch_preds else []

        # Ids of the prediction rows are a fixed consecutive range.
        first_id, last_id = 20001, 22210
        ids = list(range(first_id, last_id + 1))
        if len(ids) != len(result_series):
            logger.warning("Id range %d-%d has %d ids but %d predictions were produced.",
                           first_id, last_id, len(ids), len(result_series))
        save_predictions(ids, result_series, args.predict_file)
Example #5
0
def main():
    """Fine-tune BERT for extractive question answering (SQuAD) and/or predict.

    Parses command-line flags and sets up the device(s): single process,
    ``torch.nn.DataParallel`` when several GPUs are visible, or apex DDP when
    ``--local_rank`` is given.

    With ``--do_train`` it:
      * trains ``BertForQuestionAnswering`` on ``--train_file``, caching the
        tokenized features in a pickle next to that file;
      * saves weights, config and vocabulary to ``--output_dir``;
      * reloads the model from there.

    With ``--do_predict`` it writes ``predictions.json``,
    ``nbest_predictions.json`` and ``null_odds.json`` for ``--predict_file``
    into ``--output_dir``.

    Raises:
        ValueError: invalid ``--gradient_accumulation_steps``; neither
            ``--do_train`` nor ``--do_predict`` given; the input file for a
            selected mode is missing; or training into a non-empty
            ``--output_dir``.
        ImportError: distributed or fp16 mode requested without apex.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--output_dir", default=None, type=str, required=True,
                        help="The output directory where the model checkpoints and predictions will be written.")

    ## Other parameters
    parser.add_argument("--train_file", default=None, type=str, help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument("--predict_file", default=None, type=str,
                        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json")
    parser.add_argument("--max_seq_length", default=384, type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. Sequences "
                             "longer than this will be truncated, and sequences shorter than this will be padded.")
    parser.add_argument("--doc_stride", default=128, type=int,
                        help="When splitting up a long document into chunks, how much stride to take between chunks.")
    parser.add_argument("--max_query_length", default=64, type=int,
                        help="The maximum number of tokens for the question. Questions longer than this will "
                             "be truncated to this length.")
    parser.add_argument("--do_train", action='store_true', help="Whether to run training.")
    parser.add_argument("--do_predict", action='store_true', help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size", default=32, type=int, help="Total batch size for training.")
    parser.add_argument("--predict_batch_size", default=8, type=int, help="Total batch size for predictions.")
    parser.add_argument("--learning_rate", default=5e-5, type=float, help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs", default=3.0, type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion", default=0.1, type=float,
                        help="Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
                             "of training.")
    parser.add_argument("--n_best_size", default=20, type=int,
                        help="The total number of n-best predictions to generate in the nbest_predictions.json "
                             "output file.")
    parser.add_argument("--max_answer_length", default=30, type=int,
                        help="The maximum length of an answer that can be generated. This is needed because the start "
                             "and end predictions are not conditioned on one another.")
    parser.add_argument("--verbose_logging", action='store_true',
                        help="If true, all of the warnings related to data processing will be printed. "
                             "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Whether to lower case the input text. True for uncased models, False for cased models.")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--version_2_with_negative',
                        action='store_true',
                        help='If true, the SQuAD examples contain some that do not have an answer.')
    parser.add_argument('--null_score_diff_threshold',
                        type=float, default=0.0,
                        help="If null_score - best_non_null is greater than the threshold predict null.")
    parser.add_argument('--server_ip', type=str, default='', help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='', help="Can be used for distant debugging.")
    args = parser.parse_args()
    print(args)

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        # --server_port is parsed as a string, but a socket address needs an
        # integer port.
        ptvsd.enable_attach(address=(args.server_ip, int(args.server_port)), redirect_output=True)
        ptvsd.wait_for_attach()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    # NOTE(review): the master process (-1/0) logs at ERROR and the workers at
    # WARN, which hides every logger.info() below. This looks intentional
    # (quiet output), but confirm.
    logging.basicConfig(format = '%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
                        datefmt = '%m/%d/%Y %H:%M:%S',
                        level = logging.ERROR if args.local_rank in [-1, 0] else logging.WARN)

    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    # The per-forward batch is the total batch split across accumulation steps.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError("At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = read_squad_examples(
            input_file=args.train_file, is_training=True, version_2_with_negative=args.version_2_with_negative)
        # NOTE(review): t_total comes from the number of original examples.
        # Long contexts are split into several features (see "Num split
        # examples" below), so training can run past t_total — confirm whether
        # that is acceptable.
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    # Prepare model
    model = BertForQuestionAnswering.from_pretrained(args.bert_model,
                cache_dir=os.path.join(str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(args.local_rank)))

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # hack to remove pooler, which is not used
    # thus it produce None grad that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]

    if args.do_train:
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
            # The schedule is applied by hand in the training loop, because
            # FusedAdam does not warm up by itself.
            warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
                                                 t_total=num_train_optimization_steps)
        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

    global_step = 0
    if args.do_train:
        # The cache key encodes the model name and the tokenization settings,
        # so changing any of them produces a new cache file.
        cached_train_features_file = args.train_file+'_{0}_{1}_{2}_{3}'.format(
            list(filter(None, args.bert_model.split('/'))).pop(), str(args.max_seq_length), str(args.doc_stride), str(args.max_query_length))
        train_features = None
        try:
            # NOTE(security): pickle.load executes arbitrary code. Only use
            # cache files this script wrote itself.
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except Exception:
            # Cache missing or unreadable: rebuild the features, and save them
            # from a single process.
            train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s", cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)
        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_start_positions = torch.tensor([f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor([f.end_position for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                                   all_start_positions, all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration", disable=args.local_rank not in [-1, 0])):
                if n_gpu == 1:
                    batch = tuple(t.to(device) for t in batch) # multi-gpu does scattering it-self
                input_ids, input_mask, segment_ids, start_positions, end_positions = batch
                loss = model(input_ids, segment_ids, input_mask, start_positions, end_positions)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used and handles this automatically
                        # get_lr() expects the absolute step and divides by
                        # t_total itself. Passing an already-normalized
                        # progress would keep the LR stuck near zero.
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(global_step)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    if args.do_train and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned
        model = BertForQuestionAnswering.from_pretrained(args.output_dir)
        tokenizer = BertTokenizer.from_pretrained(args.output_dir, do_lower_case=args.do_lower_case)

    model.to(device)

    if args.do_predict and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = read_squad_examples(
            input_file=args.predict_file, is_training=False, version_2_with_negative=args.version_2_with_negative)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(eval_examples))
        logger.info("  Num split examples = %d", len(eval_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        # Each row carries its own index so its logits can be mapped back to
        # the matching feature (and that feature's unique_id).
        all_example_index = torch.arange(all_input_ids.size(0), dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_example_index)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start evaluating")
        for input_ids, input_mask, segment_ids, example_indices in tqdm(eval_dataloader, desc="Evaluating", disable=args.local_rank not in [-1, 0]):
            if len(all_results) % 1000 == 0:
                logger.info("Processing example: %d" % (len(all_results)))
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(input_ids, segment_ids, input_mask)
            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                unique_id = int(eval_feature.unique_id)
                all_results.append(RawResult(unique_id=unique_id,
                                             start_logits=start_logits,
                                             end_logits=end_logits))
        output_prediction_file = os.path.join(args.output_dir, "predictions.json")
        output_nbest_file = os.path.join(args.output_dir, "nbest_predictions.json")
        output_null_log_odds_file = os.path.join(args.output_dir, "null_odds.json")
        write_predictions(eval_examples, eval_features, all_results,
                          args.n_best_size, args.max_answer_length,
                          args.do_lower_case, output_prediction_file,
                          output_nbest_file, output_null_log_odds_file, args.verbose_logging,
                          args.version_2_with_negative, args.null_score_diff_threshold)
Example #6
0
def main():
    """Train and/or evaluate a BERT sequence classifier with only its head trainable.

    Reads a JSON config (``--config``) via ``argconf`` and picks a task
    processor by ``task_name``. It builds ``BertForSequenceClassification``
    from ``model_file`` and optimizes only the ``classifier`` weight and bias.

    Depending on the config flags it can:
      * train (``do_train``) and save to ``workspace``;
      * reload a saved checkpoint (``do_test_only``);
      * export train-set logits (``export``) and return;
      * evaluate on the dev or test split (``do_eval``), writing
        ``eval_results.txt`` to ``workspace``.

    Raises:
        ValueError: invalid ``gradient_accumulation_steps``; neither
            ``do_train`` nor ``do_eval`` set; or an unknown ``task_name``.
        ImportError: fp16 (or distributed) mode requested without apex.
    """
    def evaluate(dataloader, export=None):
        """Run the model over ``dataloader`` and return averaged metrics.

        Reads ``model``, ``device``, ``is_float``, ``tr_loss``,
        ``nb_tr_steps``, ``global_step`` and ``args`` from the enclosing
        scope. If ``export`` is given, the per-batch logits (numpy arrays)
        are saved there with ``torch.save``. When labels are float
        (regression), it also prints Pearson and Spearman correlations.
        """
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        logits_list = []
        iter_idx = 0
        corr_x = []
        corr_y = []
        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                # Two forward passes: the model returns the loss when labels
                # are passed, and the logits otherwise.
                tmp_eval_loss = model(input_ids,
                                      segment_ids,
                                      input_mask,
                                      label_ids,
                                      mse=is_float)
                logits = model(input_ids,
                               segment_ids,
                               input_mask,
                               mse=is_float)

            logits = logits.detach().cpu().numpy()
            if export is not None:
                logits_list.append(logits)
            label_ids = label_ids.to('cpu').numpy()
            if is_float:
                corr_x.extend(logits.flatten())
                corr_y.extend(label_ids.flatten())
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1
            iter_idx += 1
        if export is not None:
            torch.save(logits_list, export)

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        # Mean training loss of the last epoch. None when not training, or
        # when no training step ran (e.g. zero epochs), which would otherwise
        # divide by zero.
        loss = tr_loss / nb_tr_steps if args.do_train and nb_tr_steps else None
        if is_float:
            print(pearsonr(corr_x, corr_y))
            print(spearmanr(corr_x, corr_y))
        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            'loss': loss
        }
        return result

    local_rank = -1
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", "-c", type=str, required=True)
    args, _ = parser.parse_known_args()
    options = argconf.options_from_json("confs/options.json")
    config = argconf.config_from_json(args.config)
    args = edict(argconf.parse_args(options, config))
    print(f"Using config: {args}")
    set_seed(args.seed)
    args.do_train = args.do_train and not args.do_test_only

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "sst2": SST2Processor,
        'qnli': QnliProcessor,
        'rte': RteProcessor,
        "imdb": IMDBSentenceProcessor,
        "raw_single":
        IMDBSentenceProcessor,  # This is not a mistake, just poor naming.
        "qqp": QuoraProcessor,
        "sts": STSProcessor,
        "raw_sts_pair": RawSTSPairProcessor,
        "raw_pair": RawPairProcessor
    }

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # The per-forward batch is the total batch split across accumulation steps.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.workspace):
        os.makedirs(args.workspace)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = args.n_labels
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.model_file,
                                              do_lower_case=args.uncased)

    train_examples = processor.get_train_examples(args.data_dir)
    num_train_optimization_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps) * args.num_train_epochs
    if local_rank != -1:
        num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
        )

    # Prepare model
    cache_dir = os.path.join(PYTORCH_PRETRAINED_BERT_CACHE,
                             'distributed_{}'.format(local_rank))
    model = BertForSequenceClassification.from_pretrained(
        args.model_file, cache_dir=cache_dir, num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: only the classification head is optimized.
    # Parameter names get a "module." prefix only when the model is wrapped in
    # DataParallel/DDP, so accept both spellings. Otherwise a single-GPU or CPU
    # run would give the optimizer no parameters and nothing would train.
    head_param_names = ("module.classifier.weight", "module.classifier.bias",
                        "classifier.weight", "classifier.bias")
    param_optimizer = [(n, p) for n, p in model.named_parameters()
                       if n in head_param_names]
    print(len(param_optimizer))
    # Biases and LayerNorm parameters are excluded from weight decay.
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    # Train features are always built, because the export path below also
    # needs them.
    train_features = convert_examples_to_features(train_examples, label_list,
                                                  args.max_seq_length,
                                                  tokenizer)
    # Float labels mean regression (MSE loss); integer labels mean classification.
    is_float = isinstance(train_features[0].label_id, float)
    all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                 dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                  dtype=torch.long)
    all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                   dtype=torch.long)
    all_label_ids = torch.tensor([f.label_id for f in train_features],
                                 dtype=torch.float if is_float else torch.long)
    train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                               all_label_ids)
    if local_rank == -1:
        train_sampler = RandomSampler(train_data)
    else:
        train_sampler = DistributedSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids,
                             segment_ids,
                             input_mask,
                             label_ids,
                             mse=is_float)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        # NOTE(review): `warmup_linear` is not defined in this
                        # function. This assumes a module-level schedule
                        # function (progress, warmup) -> multiplier — confirm.
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    output_model_file = os.path.join(args.workspace, WEIGHTS_NAME)
    if args.do_train:
        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.workspace, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))
    elif args.do_test_only:
        # Adapt the saved state-dict keys to the current (un)wrapped model.
        convert = convert_single_to_dp if isinstance(
            model, torch.nn.DataParallel) else convert_dp_to_single
        model.load_state_dict(convert(torch.load(output_model_file)))
    else:
        model = BertForSequenceClassification.from_pretrained(
            args.model_file, num_labels=num_labels)
    model.to(device)

    if args.export:
        # Dump the train-set logits in dataset order, then stop.
        model.eval()
        train_dataloader = DataLoader(train_data,
                                      batch_size=args.eval_batch_size,
                                      shuffle=False)
        with torch.no_grad():
            evaluate(train_dataloader, export=args.export)
        return

    if args.visualize:
        # NOTE(review): this branch currently only creates an empty
        # viz_results.csv; the visualization writer is unimplemented.
        model.eval()
        train_dataloader = DataLoader(train_data,
                                      batch_size=args.eval_batch_size,
                                      shuffle=False)
        with open(os.path.join(args.workspace, "viz_results.csv"), "w") as f:
            writer = None

    if args.do_eval and (local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_test_examples(
            args.data_dir
        ) if args.do_test_only else processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor(
            [f.label_id for f in eval_features],
            dtype=torch.long
            if isinstance(eval_features[0].label_id, int) else torch.float)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        result = evaluate(eval_dataloader)

        output_eval_file = os.path.join(args.workspace, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Example #7
0
def main():
    """Fine-tune and/or evaluate a BERT sequence classifier on the "nsp" task.

    Driven entirely by command-line arguments. The task's processor supplies
    train/dev examples, which are converted to fixed-length features and fed
    to ``BertForSequenceClassification`` loaded from ``--bert_model``.

    Side effects:
        * Training (``--do_train``) writes the fine-tuned weights to
          ``<output_dir>/nsp_model.pt`` (from the main process only, with any
          DataParallel/DDP wrapper removed).
        * Evaluation (``--do_eval``) writes ``<output_dir>/eval_results.txt``.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` is < 1, if neither
            ``--do_train`` nor ``--do_eval`` is given, or if ``--task_name``
            is not a registered task.
        ImportError: if distributed or fp16 training is requested and apex
            is not installed.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    # NOTE(review): --bert_original is parsed but not read anywhere in this function.
    parser.add_argument("--bert_original",
                        action='store_true',
                        help="To run for original BERT")
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    args = parser.parse_args()

    processors = {
        "nsp": NSPProcessor,
    }

    num_labels_task = {
        "nsp": 2,
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    # Validate before dividing by it below (0 would raise ZeroDivisionError,
    # negatives would yield a nonsensical batch size).
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    # Per-step batch shrinks so that accumulating over
    # `gradient_accumulation_steps` steps reproduces the requested batch size.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()
    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=True)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size / args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            # Each distributed worker only sees its shard of the data.
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size()

    cache_dir = args.cache_dir if args.cache_dir else os.path.join(PYTORCH_PRETRAINED_BERT_CACHE, 'distributed_{}'.format(args.local_rank))
    model = BertForSequenceClassification.from_pretrained(args.bert_model,
            cache_dir=cache_dir,
            num_labels = num_labels)
    print('BERT original model loaded')

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: biases and LayerNorm parameters are excluded from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(
            train_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features], dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data, sampler=train_sampler, batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                # Passing labels makes the model call return the loss.
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(global_step/num_train_optimization_steps, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save only after actual training (an eval-only run must not overwrite
        # a previous checkpoint with the untouched pre-trained weights), only
        # once across distributed workers, and without the DataParallel/DDP
        # wrapper so the keys carry no `module.` prefix.
        if args.local_rank == -1 or torch.distributed.get_rank() == 0:
            model_to_save = model.module if hasattr(model, 'module') else model
            torch.save(model_to_save.state_dict(), os.path.join(args.output_dir, 'nsp_model.pt'))

    if args.do_eval and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                # With labels the model returns the loss; without, the logits.
                tmp_eval_loss = model(input_ids, segment_ids, input_mask, label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            # `accuracy` is summed over batches and divided by the example
            # count below, so it presumably returns a per-batch correct count.
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        # Mean training loss of the last epoch; None when training did not
        # run or ran zero steps (e.g. num_train_epochs < 1).
        loss = tr_loss/nb_tr_steps if args.do_train and nb_tr_steps > 0 else None
        result = {'eval_loss': eval_loss,
                  'eval_accuracy': eval_accuracy,
                  'global_step': global_step,
                  'loss': loss}

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Example #8
0
def main():
    """Fine-tune an ERNIE (entity-enhanced BERT) classifier on TACRED-style data.

    Driven entirely by command-line arguments. Training examples and the label
    list come from ``TacredProcessor``; the number of labels is taken from that
    label list. Entity ids in each feature are looked up in a pre-trained
    entity-embedding table read from ``kg_embed/entity2vec.vec`` (one
    tab-separated float vector per line; an all-zero row 0 is prepended).

    Side effects (training only, main process only for checkpoints):
        * ``<output_dir>/loss``: one per-step (accumulation-scaled) loss per line.
        * ``<output_dir>/pytorch_model.bin_<global_step>``: checkpoint per epoch.
        * ``<output_dir>/pytorch_model.bin``: final weights.

    NOTE(review): ``--do_eval`` is accepted but this function performs no
    evaluation itself.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` is < 1, if neither
            ``--do_train`` nor ``--do_eval`` is given, or if training into a
            non-empty ``--output_dir``.
        ImportError: if distributed or fp16 training is requested and apex
            is not installed.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--ernie_model",
                        default=None,
                        type=str,
                        required=True,
                        help="Ernie pre-trained model")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    # Forwarded to convert_examples_to_features; its meaning is defined there.
    parser.add_argument('--threshold', type=float, default=.3)

    args = parser.parse_args()

    processors = TacredProcessor

    # NOTE(review): unused; the label count actually comes from label_list below.
    num_labels_task = 80

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    # Only one process should write checkpoints; get_rank() is safe here
    # because the process group was initialized above whenever local_rank != -1.
    is_main_process = args.local_rank == -1 or torch.distributed.get_rank() == 0

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # Per-step batch shrinks so accumulation reproduces the requested batch size.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    processor = processors()

    tokenizer = BertTokenizer.from_pretrained(args.ernie_model,
                                              do_lower_case=args.do_lower_case)

    # The training examples are always loaded because they also provide the
    # label list that fixes the classifier's output size.
    train_examples, label_list = processor.get_train_examples(args.data_dir)
    num_labels = len(label_list)

    num_train_steps = int(
        len(train_examples) / args.train_batch_size /
        args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    model, _ = BertForSequenceClassification.from_pretrained(
        args.ernie_model,
        cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
        'distributed_{}'.format(args.local_rank),
        num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    # Parameters in the last layer's entity branch are left out of the
    # optimizer entirely, so they are never updated.
    no_grad = [
        'bert.encoder.layer.11.output.dense_ent',
        'bert.encoder.layer.11.output.LayerNorm_ent'
    ]
    param_optimizer = [(n, p) for n, p in param_optimizer
                       if not any(nd in n for nd in no_grad)]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)
    global_step = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer,
                                                      args.threshold)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        # Row 0 is an all-zero vector; entity ids are shifted by +1 at lookup
        # time, so id -1 (presumably "no entity") maps onto it.
        vecs = []
        vecs.append([0] * 100)
        with open("kg_embed/entity2vec.vec", 'r') as fin:
            for line in fin:
                vec = line.strip().split('\t')
                vec = [float(x) for x in vec]
                vecs.append(vec)
        embed = torch.FloatTensor(vecs)
        # Kept on the CPU (never moved to `device`); lookups are moved afterwards.
        embed = torch.nn.Embedding.from_pretrained(embed)

        logger.info("Shape of entity embedding: " + str(embed.weight.size()))
        del vecs

        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        all_ent = torch.tensor([f.input_ent for f in train_features],
                               dtype=torch.long)
        all_ent_masks = torch.tensor([f.ent_mask for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_ent, all_ent_masks,
                                   all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        output_loss_file = os.path.join(args.output_dir, "loss")
        model.train()
        # `with` guarantees the loss log is flushed and closed, even on error.
        with open(output_loss_file, 'w') as loss_fout:
            for _ in trange(int(args.num_train_epochs), desc="Epoch"):
                for step, batch in enumerate(
                        tqdm(train_dataloader, desc="Iteration")):
                    # Index 3 (entity ids) stays on the CPU because the
                    # embedding table it indexes lives there.
                    batch = tuple(
                        t.to(device) if i != 3 else t for i, t in enumerate(batch))
                    input_ids, input_mask, segment_ids, input_ent, ent_mask, label_ids = batch
                    input_ent = embed(input_ent + 1).to(device)  # -1 -> 0
                    # Entity vectors must match the model's precision: the
                    # model is only converted with half() under --fp16.
                    if args.fp16:
                        input_ent = input_ent.half()
                    loss = model(input_ids, segment_ids, input_mask,
                                 input_ent, ent_mask, label_ids)
                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        optimizer.backward(loss)
                    else:
                        loss.backward()

                    loss_fout.write("{}\n".format(loss.item()))
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        if args.fp16:
                            # modify learning rate with special warm up BERT uses;
                            # BertAdam (the non-fp16 optimizer) already applies
                            # this schedule itself via warmup/t_total.
                            lr_this_step = args.learning_rate * warmup_linear(
                                global_step / t_total, args.warmup_proportion)
                            for param_group in optimizer.param_groups:
                                param_group['lr'] = lr_this_step
                        optimizer.step()
                        optimizer.zero_grad()
                        global_step += 1
                if is_main_process:
                    model_to_save = model.module if hasattr(model, 'module') else model
                    output_model_file = os.path.join(
                        args.output_dir, "pytorch_model.bin_{}".format(global_step))
                    torch.save(model_to_save.state_dict(), output_model_file)

        # Save a trained model
        if is_main_process:
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
            torch.save(model_to_save.state_dict(), output_model_file)
def main():
    """Command-line entry point for the yes/no (evidence) reader.

    Parses command-line arguments and then, depending on the flags:

    * ``--do_train``: trains the model on ``--train_file``, evaluates on
      ``--predict_file`` after every epoch and saves the state dict with the
      best yes/no accuracy to ``<output_dir>/pytorch_model.bin``.
    * ``--do_predict``: reloads ``<output_dir>/pytorch_model.bin`` and writes
      yes/no predictions for ``--predict_file`` into ``--predict_dir``.
    * ``--do_label``: reloads the same checkpoint and writes sentence-level
      evidence labels for ``--train_file`` into ``--predict_dir``.

    Raises:
        ValueError: on an invalid ``--gradient_accumulation_steps``, when
            neither ``--do_train`` nor ``--do_predict`` is given, when a
            required input file is missing, or when ``--do_train`` targets a
            non-empty ``--output_dir``.
        ImportError: when distributed or fp16 training is requested but
            apex is not installed.
    """
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--vocab_file",
                        default='bert-base-uncased-vocab.txt',
                        type=str,
                        required=True)
    parser.add_argument("--model_file",
                        default='bert-base-uncased.tar.gz',
                        type=str,
                        required=True)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )
    parser.add_argument(
        "--predict_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the predictions will be written.")
    parser.add_argument('--predict_output_file',
                        type=str,
                        default='predictions.json')
    parser.add_argument('--label_output_file',
                        type=str,
                        default='evidence_predictions.json')

    # Other parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json"
    )
    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=2.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    # argparse %-formats help strings, so a literal percent sign must be
    # written as '%%' (a bare '% o' would be parsed as a conversion spec).
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10%% "
        "of training.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        default=False,
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    # NOTE(review): default=True combined with store_true means this flag can
    # never be turned off from the command line.
    parser.add_argument(
        "--do_lower_case",
        default=True,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    # Base setting
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--max_ctx', type=int, default=2)
    parser.add_argument('--task_name', type=str, default='coqa_yesno')
    parser.add_argument('--bert_name', type=str, default='baseline')
    parser.add_argument('--reader_name', type=str, default='coqa')
    # model parameters
    parser.add_argument('--evidence_lambda', type=float, default=0.8)
    parser.add_argument('--negative_lambda', type=float, default=1.0)
    parser.add_argument('--add_entropy', default=False, action='store_true')
    parser.add_argument('--split_num', type=int, default=3)
    parser.add_argument('--split_index', type=int, default=0)
    # Parameters for running labeling model
    parser.add_argument('--do_label', default=False, action='store_true')
    parser.add_argument('--sentence_id_file', type=str, default=None)
    parser.add_argument('--weight_threshold', type=float, default=0.0)
    parser.add_argument('--label_threshold', type=float, default=0.0)
    # negative sample parameters
    parser.add_argument('--do_negative_sampling',
                        default=False,
                        action='store_true')
    parser.add_argument('--read_extra_self',
                        default=False,
                        action='store_true')
    parser.add_argument('--sample_ratio', type=float, default=0.5)
    parser.add_argument('--extra_sen_file', type=str, default=None)
    parser.add_argument('--multi_inputs', default=False, action='store_true')

    args = parser.parse_args()

    logger = setting_logger(args.output_dir)
    logger.info('================== Program start. ========================')

    # model parameters
    model_params = prepare_model_params(args)

    # read parameters
    read_params = prepare_read_params(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # From here on train_batch_size is the per-forward-pass batch size.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_label:
        # Labeling runs over the training set, so it needs train_file too.
        if not args.train_file:
            raise ValueError(
                "If `do_label` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    if args.do_train:
        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory ({}) already exists and is not empty.".format(
                    args.output_dir))
        os.makedirs(args.output_dir, exist_ok=True)

    if args.do_predict:
        os.makedirs(args.predict_dir, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.vocab_file)

    data_reader = initialize_reader(args.reader_name)

    num_train_steps = None
    if args.do_train or args.do_label:
        train_examples = data_reader.read(input_file=args.train_file,
                                          **read_params)

        # Features are cached next to the train file, keyed by every option
        # that affects feature conversion.
        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}_{5}'.format(
            args.bert_model, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length), str(args.max_ctx), str(args.task_name))

        try:
            # NOTE: pickle.load on a locally produced cache file; do not point
            # train_file at untrusted locations.
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except FileNotFoundError:
            train_features = data_reader.convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s",
                            cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)

        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

        # Diagnostic: share of training features with sentence_id == -1.
        # Kept inside this branch because train_features only exists here,
        # and guarded so an empty feature list cannot divide by zero.
        if train_features:
            no_evidence = sum(1 for feature in train_features
                              if feature.sentence_id == -1)
            logger.info(
                f'No evidence ratio: {no_evidence} / {len(train_features)} = {no_evidence * 1.0 / len(train_features)}'
            )

    # Prepare model
    if args.pretrain is not None:
        logger.info('Load pretrained model from {}'.format(args.pretrain))
        # Map onto the selected device so CPU-only runs and non-zero DDP
        # ranks do not require cuda:0.
        model_state_dict = torch.load(args.pretrain, map_location=device)
        model = initialize_model(args.bert_name,
                                 args.model_file,
                                 state_dict=model_state_dict,
                                 **model_params)
    else:
        model = initialize_model(args.bert_name, args.model_file,
                                 **model_params)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # hack to remove pooler, which is not used
    # thus it produce None grad that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    # num_train_steps (and therefore t_total) is None when no training data
    # was loaded; only split it across workers when it is known.
    t_total = num_train_steps
    if args.local_rank != -1 and t_total is not None:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    # Prepare data
    # Evaluation data is read without negative samples.
    # NOTE(review): predict_file is only validated for --do_predict, yet it is
    # read here unconditionally and used for per-epoch evaluation during
    # training -- confirm data_reader.read tolerates None.
    if 'read_state' in read_params:
        read_params['read_state'] = ReadState.NoNegative
    eval_examples = data_reader.read(input_file=args.predict_file,
                                     **read_params)
    eval_features = data_reader.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length,
        is_training=False)

    eval_tensors = data_reader.data_to_tensors(eval_features)
    eval_data = TensorDataset(*eval_tensors)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.predict_batch_size)

    if args.do_train:

        if args.sentence_id_file is not None:
            logger.info('Training with evidence self-labeled data.')
            data_reader.generate_features_sentence_ids(train_features,
                                                       args.sentence_id_file)
        else:
            logger.info('No sentence id file found. Train in traditional way.')

        logger.info("Start training")
        train_loss = AverageMeter()
        best_acc = 0.0
        summary_writer = SummaryWriter(log_dir=args.output_dir)
        global_step = 0
        eval_loss = AverageMeter()

        train_tensors = data_reader.data_to_tensors(train_features)
        train_data = TensorDataset(*train_tensors)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            # Train
            model.train()
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering it-self
                inputs = data_reader.generate_inputs(
                    batch, train_features, model_state=ModelState.Train)
                output_dict = model(**inputs)
                loss = output_dict['loss']
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # FusedAdam is constructed without any lr schedule, so
                        # apply BERT's linear warmup by hand. BertAdam (the
                        # non-fp16 optimizer) is given `warmup` and `t_total`
                        # and schedules its own lr; overwriting its base lr
                        # here as well would apply the warmup twice.
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / t_total, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                train_loss.update(loss.item(), args.train_batch_size)
                summary_writer.add_scalar('train_loss', train_loss.avg,
                                          global_step)

            # Evaluation
            model.eval()
            all_results = []
            logger.info("Start evaluating")
            for eval_step, batch in enumerate(
                    tqdm(eval_dataloader, desc="Evaluating")):
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering it-self
                inputs = data_reader.generate_inputs(
                    batch, eval_features, model_state=ModelState.Evaluate)
                with torch.no_grad():
                    output_dict = model(**inputs)
                    loss, batch_choice_logits = output_dict[
                        'loss'], output_dict['yesno_logits']
                    eval_loss.update(loss.item(), args.predict_batch_size)
                    summary_writer.add_scalar(
                        'eval_loss', eval_loss.avg,
                        epoch * len(eval_dataloader) + eval_step)
                # The last tensor of each batch holds indices into eval_features.
                example_indices = batch[-1]
                for i, example_index in enumerate(example_indices):
                    choice_logits = batch_choice_logits[i].detach().cpu(
                    ).tolist()

                    eval_feature = eval_features[example_index.item()]
                    unique_id = int(eval_feature.unique_id)
                    all_results.append(
                        RawResultChoice(unique_id=unique_id,
                                        choice_logits=choice_logits))

            # Passing None as the output file: only the reader's yes/no
            # metrics are wanted here.
            data_reader.write_predictions(eval_examples,
                                          eval_features,
                                          all_results,
                                          None,
                                          null_score_diff_threshold=0.0)
            yes_metric = data_reader.yesno_cate.f1_measure('yes', 'no')
            no_metric = data_reader.yesno_cate.f1_measure('no', 'yes')
            current_acc = yes_metric['accuracy']
            summary_writer.add_scalar('eval_yes_f1', yes_metric['f1'], epoch)
            summary_writer.add_scalar('eval_yes_recall', yes_metric['recall'],
                                      epoch)
            summary_writer.add_scalar('eval_yes_precision',
                                      yes_metric['precision'], epoch)
            summary_writer.add_scalar('eval_no_f1', no_metric['f1'], epoch)
            summary_writer.add_scalar('eval_no_recall', no_metric['recall'],
                                      epoch)
            summary_writer.add_scalar('eval_no_precision',
                                      no_metric['precision'], epoch)
            summary_writer.add_scalar('eval_yesno_acc', current_acc, epoch)
            torch.cuda.empty_cache()

            # Keep only the checkpoint with the best dev accuracy.
            if current_acc > best_acc:
                best_acc = current_acc
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(args.output_dir,
                                                 "pytorch_model.bin")
                torch.save(model_to_save.state_dict(), output_model_file)
            logger.info('Epoch: %d, Accuracy: %f (Best Accuracy: %f)' %
                        (epoch, current_acc, best_acc))
            data_reader.yesno_cate.reset()

        summary_writer.close()

    # Loading trained model (the best checkpoint written above, or a
    # previously trained one when --do_train was not given).
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    model_state_dict = torch.load(output_model_file, map_location=device)
    model = initialize_model(args.bert_name,
                             args.model_file,
                             state_dict=model_state_dict,
                             **model_params)
    model.to(device)

    # Write Yes/No predictions
    if args.do_predict and (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):

        test_examples = eval_examples
        test_features = eval_features

        test_tensors = data_reader.data_to_tensors(test_features)
        test_data = TensorDataset(*test_tensors)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.predict_batch_size)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(test_examples))
        logger.info("  Num split examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start predicting yes/no on Dev set.")
        for batch in tqdm(test_dataloader, desc="Testing"):
            if n_gpu == 1:
                batch = batch_to_device(
                    batch, device)  # multi-gpu does scattering it-self
            inputs = data_reader.generate_inputs(batch,
                                                 test_features,
                                                 model_state=ModelState.Test)
            with torch.no_grad():
                output_dict = model(**inputs)
                batch_choice_logits = output_dict['yesno_logits']
            example_indices = batch[-1]
            for i, example_index in enumerate(example_indices):
                choice_logits = batch_choice_logits[i].detach().cpu().tolist()

                test_feature = test_features[example_index.item()]
                unique_id = int(test_feature.unique_id)

                all_results.append(
                    RawResultChoice(unique_id=unique_id,
                                    choice_logits=choice_logits))

        # Honour --predict_output_file (its default is 'predictions.json').
        output_prediction_file = os.path.join(args.predict_dir,
                                              args.predict_output_file)
        data_reader.write_predictions(test_examples,
                                      test_features,
                                      all_results,
                                      output_prediction_file,
                                      null_score_diff_threshold=0.0)
        yes_metric = data_reader.yesno_cate.f1_measure('yes', 'no')
        no_metric = data_reader.yesno_cate.f1_measure('no', 'yes')
        logger.info('Yes Metrics: %s' % json.dumps(yes_metric, indent=2))
        logger.info('No Metrics: %s' % json.dumps(no_metric, indent=2))

    # Labeling sentence id.
    if args.do_label and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):

        test_examples = train_examples
        test_features = train_features

        test_tensors = data_reader.data_to_tensors(test_features)
        test_data = TensorDataset(*test_tensors)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.predict_batch_size)

        logger.info("***** Running labeling *****")
        logger.info("  Num orig examples = %d", len(test_examples))
        logger.info("  Num split examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start labeling.")
        for batch in tqdm(test_dataloader, desc="Testing"):
            if n_gpu == 1:
                batch = batch_to_device(batch, device)
            inputs = data_reader.generate_inputs(batch,
                                                 test_features,
                                                 model_state=ModelState.Test)
            with torch.no_grad():
                output_dict = model(**inputs)
                batch_choice_logits = output_dict['yesno_logits']
                batch_max_weight_indexes = output_dict['max_weight_index']
                batch_max_weights = output_dict['max_weight']
            example_indices = batch[-1]
            for i, example_index in enumerate(example_indices):
                choice_logits = batch_choice_logits[i].detach().cpu().tolist()
                max_weight_index = batch_max_weight_indexes[i].detach().cpu(
                ).tolist()
                max_weight = batch_max_weights[i].detach().cpu().tolist()

                test_feature = test_features[example_index.item()]
                unique_id = int(test_feature.unique_id)

                all_results.append(
                    FullResult(unique_id=unique_id,
                               choice_logits=choice_logits,
                               max_weight_index=max_weight_index,
                               max_weight=max_weight))

        output_prediction_file = os.path.join(args.predict_dir,
                                              args.label_output_file)
        data_reader.write_sentence_predictions(
            test_examples,
            test_features,
            all_results,
            output_prediction_file,
            weight_threshold=args.weight_threshold,
            label_threshold=args.label_threshold)
def main():
    """Fine-tune and/or evaluate BERT for multi-label document classification.

    Parses command-line arguments, optionally fine-tunes
    ``BertForMultiLabelSequenceClassification`` on the training split of the
    selected task (only ``nts`` is registered), and -- after every training
    epoch and/or when ``--do_eval`` is given -- scores the development split
    and writes label predictions for the development and test splits.

    Besides what the task processor reads, the following files are read from
    ``--data_dir``: ``ids_development.txt``, ``ids_testing.txt`` and
    ``mlb_<corpus_type>.pkl`` (a pickled object whose ``classes_`` array is
    indexed with the binary prediction columns to obtain label strings).

    Written to ``--output_dir``: ``eval_results.txt`` (appended),
    ``preds_development<epoch>.txt``, ``preds_test<epoch>.txt`` and, after
    training, the model weights, config and vocabulary.

    Raises:
        ValueError: invalid ``--gradient_accumulation_steps``, neither
            ``--do_train`` nor ``--do_eval`` given, non-empty output
            directory when training, unknown task name, or an empty
            development/test split.
        ImportError: ``--fp16`` or distributed training requested without
            NVIDIA apex installed.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    # ``--corpus_type`` has a usable default, so it is not required.
    parser.add_argument("--corpus_type",
                        default="mixed",
                        type=str,
                        help="Corpus type, mixed or categories")
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain.pkl files, "
                             "named: train_data.pkl, dev_data.pkl, "
                             "test_data.pkl and mlb.pkl (e.g. as in "
                             "`exps-data/data`).")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: "
                             "bert-base-german-cased, bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, "
                             "bert-large-cased, bert-base-multilingual-uncased, "
                             "bert-base-multilingual-cased, "
                             "bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model "
                             "predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--use_data",
                        default="orig",
                        type=str,
                        help="Original DE, tokenized DE or tokenized EN.")
    parser.add_argument("--cache_dir",
                        default="",
                        type=str,
                        help="Where do you want to store the pre-trained "
                             "models downloaded from s3")
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after "
                             "WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, "
                             "and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--loss_fct",
                        default="bbce",
                        type=str,
                        help="Loss function to use BCEWithLogitsLoss (`bbce`)")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear "
                             "learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before "
                             "performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead "
                             "of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. "
                             "Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.")
    parser.add_argument('--server_ip', type=str, default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port', type=str, default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/
        # debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "nts": NTSTaskProcessor
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and
                                        not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(format='%(asctime)s - %(levelname)s - %(name)s - '
                               '%(message)s',
                        datefmt='%m/%d/%Y %H:%M:%S',
                        level=logging.INFO if args.local_rank in [-1, 0]
                        else logging.WARN)

    logger.info("device: {} n_gpu: {}, distributed training: {}, "
                "16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, "
                         "should be >= 1".format(
                            args.gradient_accumulation_steps))

    # The per-step batch is the requested batch split across accumulation
    # steps.
    args.train_batch_size = (args.train_batch_size //
                             args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must "
                         "be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) \
            and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not "
                         "empty.".format(args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name](args.data_dir,
                                      args.corpus_type, use_data=args.use_data)
    # NOTE(review): pos_weight is computed but never used below; the losses
    # are plain BCEWithLogitsLoss(). Confirm whether `--loss_fct bbce` was
    # meant to apply it.
    pos_weight = torch.tensor(processor.pos_weight, requires_grad=False,
                              dtype=torch.float, device=device)
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples()
        num_train_optimization_steps = int(
            len(train_examples) /
            args.train_batch_size /
            args.gradient_accumulation_steps
        ) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = (
                num_train_optimization_steps //
                torch.distributed.get_world_size())

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE),
        'distributed_{}'.format(args.local_rank))
    model = BertForMultiLabelSequenceClassification.from_pretrained(
        args.bert_model,
        cache_dir=cache_dir,
        num_labels=num_labels,
        loss_fct=args.loss_fct
    )
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https:/"
                              "/www.github.com/nvidia/apex to use distributed "
                              "and fp16 training.")
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(
            nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(
            nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    if args.do_train:
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError("Please install apex from "
                                  "https://www.github.com/nvidia/apex to use "
                                  "distributed and fp16 training.")

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)
            warmup_linear = WarmupLinearSchedule(
                warmup=args.warmup_proportion,
                t_total=num_train_optimization_steps)

        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps,
                                 schedule='warmup_cosine')

    # Training counters; evaluate() reads them to report train loss / step.
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0

    def _encode_inputs(features):
        """Stack token ids, attention masks and segment ids as long tensors."""
        return (torch.tensor([f.input_ids for f in features],
                             dtype=torch.long),
                torch.tensor([f.input_mask for f in features],
                             dtype=torch.long),
                torch.tensor([f.segment_ids for f in features],
                             dtype=torch.long))

    def _encode_labels(features):
        """Stack label vectors into a float tensor of shape (n, num_labels)."""
        labels = torch.tensor([f.label_ids for f in features],
                              dtype=torch.float)
        return labels.view(-1, num_labels)

    def _encode_doc_ids(features):
        """Stack each feature's ``guid`` as a long tensor of document ids."""
        return torch.tensor([f.guid for f in features], dtype=torch.long)

    def _run_inference(dataset, with_labels):
        """Run the current model over ``dataset`` in order, without gradients.

        ``dataset`` rows are ``(input_ids, input_mask, segment_ids,
        [label_ids,] doc_ids)``; ``label_ids`` is present iff ``with_labels``.

        Returns:
            ``(doc_ids, logits, mean_loss)`` as numpy arrays plus the mean
            per-batch BCE loss (``None`` when ``with_labels`` is False).

        Raises:
            ValueError: if ``dataset`` is empty.
        """
        dataloader = DataLoader(dataset,
                                sampler=SequentialSampler(dataset),
                                batch_size=args.eval_batch_size)
        loss_fct = BCEWithLogitsLoss()
        model.eval()
        logit_chunks, id_chunks = [], []
        loss_total, n_batches = 0.0, 0
        for batch in tqdm(dataloader, desc="Evaluating"):
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids = batch[:3]
            doc_ids = batch[-1]
            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask, labels=None)
                if with_labels:
                    label_ids = batch[3]
                    loss_total += loss_fct(
                        logits.view(-1, num_labels),
                        label_ids.view(-1, num_labels)).mean().item()
            n_batches += 1
            logit_chunks.append(logits.detach().cpu().numpy())
            id_chunks.append(doc_ids.detach().cpu().numpy())

        if n_batches == 0:
            raise ValueError("No examples to run inference on.")
        # One concatenate instead of np.append per batch (which is O(n^2)).
        mean_loss = loss_total / n_batches if with_labels else None
        return np.concatenate(id_chunks), np.concatenate(logit_chunks), mean_loss

    def _write_label_predictions(binary_preds, doc_ids, ordered_ids, filename):
        """Write one ``doc_id<TAB>label|label|...`` line per id in ``ordered_ids``.

        Rows of ``binary_preds`` are aligned with ``doc_ids``; ids listed in
        ``ordered_ids`` that received no prediction get an empty label list.
        """
        # NOTE(review): unpickling can execute arbitrary code; only point
        # --data_dir at trusted files.
        with open(os.path.join(args.data_dir, f"mlb_{args.corpus_type}.pkl"),
                  "rb") as rf:
            mlb = pkl.load(rf)
        id2labels = {doc_id: mlb.classes_[row.astype(bool)].tolist()
                     for doc_id, row in zip(doc_ids, binary_preds)}
        with open(os.path.join(args.output_dir, filename), "w") as wf:
            for doc_id in ordered_ids:
                labels = id2labels.get(doc_id, [])
                wf.write(str(doc_id) + "\t" + "|".join(labels) + "\n")

    def evaluate(epoch=None):
        """Score the development split, log/append metrics, write predictions.

        Metrics come from ``compute_metrics`` on predictions thresholded at
        0.5 after the sigmoid, extended with the current train loss, eval loss
        and global step. Predictions go to ``preds_development<epoch>.txt``.
        """
        eval_examples = processor.get_dev_examples()
        eval_features = convert_examples_to_features(
            eval_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

        all_label_ids = _encode_labels(eval_features)
        eval_data = TensorDataset(*_encode_inputs(eval_features),
                                  all_label_ids,
                                  _encode_doc_ids(eval_features))
        # FIXME: make it flexible to accept path
        all_ids_dev = read_ids(os.path.join(args.data_dir,
                                            "ids_development.txt"))

        ids, logits, eval_loss = _run_inference(eval_data, with_labels=True)
        preds = (sigmoid(logits) > 0.5).astype(int)

        result = compute_metrics(task_name, preds, all_label_ids.numpy())
        # Guard against evaluating before any training step has run.
        result['train_loss'] = (tr_loss / nb_tr_steps
                                if args.do_train and nb_tr_steps else None)
        result['eval_loss'] = eval_loss
        result['global_step'] = global_step

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "a") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
            writer.write('\n')

        _write_label_predictions(preds, ids, all_ids_dev,
                                 f"preds_development{epoch}.txt")

    def predict(epoch=None):
        """Predict labels for the test split into ``preds_test<epoch>.txt``."""
        test_examples = processor.get_test_examples()
        test_features = convert_examples_to_features(
            test_examples, label_list, args.max_seq_length, tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(test_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)

        test_data = TensorDataset(*_encode_inputs(test_features),
                                  _encode_doc_ids(test_features))
        # FIXME: make it flexible to accept path
        all_ids_test = read_ids(os.path.join(args.data_dir, "ids_testing.txt"))

        ids, logits, _ = _run_inference(test_data, with_labels=False)
        preds = (sigmoid(logits) > 0.5).astype(int)
        _write_label_predictions(preds, ids, all_ids_test,
                                 f"preds_test{epoch}.txt")

    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        train_data = TensorDataset(*_encode_inputs(train_features),
                                   _encode_labels(train_features))
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        train_loss_fct = BCEWithLogitsLoss()
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            # evaluate()/predict() at the end of each epoch put the model in
            # eval mode, so training mode (dropout) must be re-enabled here.
            model.train()
            if args.local_rank != -1:
                # Without this every epoch reuses the same shuffle order.
                train_sampler.set_epoch(epoch)
            tr_loss = 0
            nb_tr_steps = 0
            for step, batch in enumerate(tqdm(train_dataloader,
                                              desc="Iteration")):
                input_ids, input_mask, segment_ids, label_ids = tuple(
                    t.to(device) for t in batch)

                logits = model(input_ids, segment_ids, input_mask, labels=None)
                loss = train_loss_fct(logits.view(-1, num_labels),
                                      label_ids.view(-1, num_labels))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # BertAdam applies the warmup schedule itself; with
                        # FusedAdam it is applied by hand. get_lr() takes the
                        # raw step count and divides by t_total internally,
                        # so global_step must not be pre-normalised here.
                        lr_this_step = args.learning_rate * \
                            warmup_linear.get_lr(global_step,
                                                 args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            evaluate(epoch=epoch)
            predict(epoch=epoch)
            # TODO: per-epoch checkpoints (save model_to_save.state_dict(),
            # config and vocabulary under f"{args.output_dir}/{epoch}") are
            # currently disabled.

    if args.do_train and (args.local_rank == -1 or
                          torch.distributed.get_rank() == 0):
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, 'module') else model

        # If we save using the predefined names, we can load using
        # `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned,
        # constructed with the same loss_fct as the model that was trained.
        model = BertForMultiLabelSequenceClassification.from_pretrained(
            args.output_dir,
            num_labels=num_labels,
            loss_fct=args.loss_fct)
        tokenizer = BertTokenizer.from_pretrained(
            args.output_dir,
            do_lower_case=args.do_lower_case)
    else:
        model = BertForMultiLabelSequenceClassification.from_pretrained(
            args.bert_model,
            num_labels=num_labels,
            loss_fct=args.loss_fct)
    model.to(device)

    if args.do_eval and (args.local_rank == -1 or
                         torch.distributed.get_rank() == 0):
        evaluate()
        predict()
Example #11
0
def _build_joint_batch(args, batch, batch_stance, label_list, tokenizer,
                       output_mode):
    """Convert one relatedness batch and one stance batch and stack them.

    ``batch`` is converted with ``to_bert_input_related`` and ``batch_stance``
    with ``to_bert_input_stance``. Each of the four resulting tensors is then
    concatenated along dim 0. The relatedness examples come first, followed
    by the stance examples.

    Returns:
        Tuple ``(input_ids, input_mask, segment_ids, label_ids)``.
    """
    related = to_bert_input_related(args, batch, label_list, tokenizer,
                                    output_mode)
    stance = to_bert_input_stance(args, batch_stance, label_list, tokenizer,
                                  output_mode)
    # Unpacking into four names enforces that both converters return 4-tuples.
    input_ids, input_mask, segment_ids, label_ids = (
        torch.cat((rel, stn), dim=0) for rel, stn in zip(related, stance))
    return input_ids, input_mask, segment_ids, label_ids


def _task_loss(logits, label_ids, output_mode, num_labels):
    """Return the loss of *logits* against *label_ids* for *output_mode*.

    ``"classification"`` uses cross-entropy over ``num_labels`` classes.
    ``"regression"`` uses mean-squared error on the flattened logits.

    Raises:
        ValueError: if *output_mode* is any other value.
    """
    if output_mode == "classification":
        return CrossEntropyLoss()(logits.view(-1, num_labels),
                                  label_ids.view(-1))
    if output_mode == "regression":
        return MSELoss()(logits.view(-1), label_ids.view(-1))
    raise ValueError("Unsupported output_mode: {}".format(output_mode))


def main():
    """Jointly fine-tune BERT on relatedness (IR) and stance data, then test.

    Parses command-line arguments and builds ``Stance_Dataset`` (from
    ``--data_dir``) and ``IR_Dataset`` (from a hard-coded directory). Then,
    depending on the flags:

    * ``--do_train``: trains ``BertForSequenceClassification`` on batches
      that concatenate one relatedness batch with one stance batch. After
      every epoch it saves weights/config to ``--output_dir``, evaluates on
      the paired dev iterators, and writes ``eval_results.txt``.
    * ``--do_test``: reloads the saved weights/config from ``--output_dir``,
      runs the stance test iterator, and saves the softmax probabilities to
      ``preds_stance.npy`` in the current working directory.

    ``--do_eval`` only satisfies the "at least one mode" check. The
    per-epoch evaluation runs as part of ``--do_train``.

    Raises:
        ValueError: on invalid ``--gradient_accumulation_steps``, when no
            mode flag is given, or for an unknown ``--task_name``.
        ImportError: when apex is needed (distributed / fp16) but missing.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=330,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run test on the test set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    output_modes = {
        "cola": "classification",
        "mnli": "classification",
        "mrpc": "classification",
        "sst-2": "classification",
        "sts-b": "regression",
        "qqp": "classification",
        "qnli": "classification",
        "rte": "classification",
        "wnli": "classification",
    }

    # Prepare data: stance data comes from --data_dir, while the relatedness
    # (IR) data location is hard-coded.
    # NOTE(review): the IR path is machine-specific; consider exposing it as
    # a command-line argument.
    args.stance_dir = args.data_dir
    args.ir_dir = "/home/Vachel/github/pytorch-pretrained-BERT_old/data/source"

    label_list = ["0", "1"]
    num_labels = len(label_list)
    stance_dataset = Stance_Dataset(args, "train.tsv", "dev.tsv", "test.tsv")
    ir_dataset = IR_Dataset(args, "train.tsv", "dev.tsv")

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # The CLI batch size is the effective batch size after accumulation.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval and not args.do_test:
        raise ValueError(
            "At least one of `do_train` or `do_eval` or `do_test` must be True."
        )

    # An existing (even non-empty) output directory is reused; checkpoint and
    # result files inside it are overwritten.
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    task_name = args.task_name.lower()
    if task_name not in output_modes:
        raise ValueError("Task not found: %s" % (task_name))
    output_mode = output_modes[task_name]

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    num_train_optimization_steps = None
    if args.do_train:
        num_train_optimization_steps = int(
            len(stance_dataset.train_data) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer. It is only needed for training; building it without
    # training would pass t_total=None to BertAdam.
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)
        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(stance_dataset.train_data))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            model.train()

            tr_loss = 0
            nb_tr_steps = 0

            # Fresh iterators every epoch. Relatedness and stance batches are
            # consumed in lockstep via zip(), so an epoch ends as soon as the
            # shorter of the two iterators is exhausted.
            train_stance_iter, dev_stance_iter, _ = stance_dataset.Iterator()
            train_iter, dev_iter = ir_dataset.Iterator()
            for step, (batch, batch_stance) in enumerate(
                    tqdm(zip(train_iter, train_stance_iter),
                         desc="Iteration")):
                input_ids, input_mask, segment_ids, label_ids = _build_joint_batch(
                    args, batch, batch_stance, label_list, tokenizer,
                    output_mode)
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                logits = model(input_ids, segment_ids, input_mask, labels=None)
                loss = _task_loss(logits, label_ids, output_mode, num_labels)

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # BertAdam applies the warmup schedule itself; the fp16
                        # optimizer does not, so the LR is set by hand here.
                        # NOTE(review): `warmup_linear` is not defined in this
                        # function; it must be a module-level callable taking
                        # (progress, warmup). Confirm it exists in the
                        # installed pytorch_pretrained_bert.
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save a trained model and the associated configuration
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
            torch.save(model_to_save.state_dict(), output_model_file)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
            with open(output_config_file, 'w') as f:
                f.write(model_to_save.config.to_json_string())

            # eval
            logger.info("***** Running evaluation *****")

            model.eval()
            eval_loss = 0
            nb_eval_steps = 0
            batch_logits = []
            all_label_ids = []

            for batch, batch_stance in tqdm(zip(dev_iter, dev_stance_iter),
                                            desc="Evaluating"):
                input_ids, input_mask, segment_ids, label_ids = _build_joint_batch(
                    args, batch, batch_stance, label_list, tokenizer,
                    output_mode)
                # Gold labels are collected before the device move so that
                # .numpy() can be called on them.
                all_label_ids += list(label_ids.numpy())
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    logits = model(input_ids,
                                   segment_ids,
                                   input_mask,
                                   labels=None)

                tmp_eval_loss = _task_loss(logits, label_ids, output_mode,
                                           num_labels)
                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_steps += 1
                batch_logits.append(logits.detach().cpu().numpy())

            if nb_eval_steps == 0:
                # Avoid ZeroDivisionError / empty concatenate on an empty dev set.
                logger.warning(
                    "No evaluation batches were produced; skipping evaluation."
                )
                continue

            eval_loss = eval_loss / nb_eval_steps
            # One concatenation instead of repeated np.append (quadratic).
            preds = np.concatenate(batch_logits, axis=0)

            if output_mode == "classification":
                preds = np.argmax(preds, axis=1)
            elif output_mode == "regression":
                preds = np.squeeze(preds)
            result = compute_metrics(task_name, preds, all_label_ids)

            result['eval_loss'] = eval_loss
            result['global_step'] = global_step
            result['loss'] = tr_loss / nb_tr_steps if nb_tr_steps else None

            output_eval_file = os.path.join(args.output_dir,
                                            "eval_results.txt")
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))

    if args.do_test and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        logger.info("***** Running testing *****")

        # Load a trained model and config that you have fine-tuned
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        # map_location="cpu" lets a checkpoint saved on GPU load on a CPU-only
        # host; the model is moved to `device` right after.
        model.load_state_dict(
            torch.load(output_model_file, map_location="cpu"))
        if args.fp16:
            model.half()
        model.to(device)

        model.eval()

        _, _, test_stance_iter = stance_dataset.Iterator()
        pred_poss = []
        for batch in tqdm(test_stance_iter, desc="Evaluating"):
            test_src_batch, test_tgt_batch = batch

            input_ids, input_mask, segment_ids = to_bert_input_test(
                args, test_src_batch, test_tgt_batch, label_list, tokenizer,
                output_mode)
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask, labels=None)

            # One probability row per example.
            pred_poss.extend(F.softmax(logits, dim=1).detach().cpu().numpy())

        pred_poss = np.array(pred_poss)
        np.save("preds_stance.npy", pred_poss)
        print(pred_poss.shape)
def main():
    logger.info("Running %s" % ' '.join(sys.argv))

    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--scan",
                        default="horizontal",
                        choices=["vertical", "horizontal"],
                        type=str,
                        help="The direction of linearizing table cells.")
    parser.add_argument(
        "--data_dir",
        default="../processed_datasets",
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--output_dir",
        default="outputs",
        type=str,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument(
        "--load_dir",
        type=str,
        help=
        "The output directory where the model checkpoints will be loaded during evaluation"
    )
    parser.add_argument('--load_step',
                        type=int,
                        default=0,
                        help="The checkpoint step to be loaded")
    parser.add_argument("--fact",
                        default="first",
                        choices=["first", "second"],
                        type=str,
                        help="Whether to put fact in front.")
    parser.add_argument(
        "--test_set",
        default="dev",
        choices=["dev", "test", "simple_test", "complex_test", "small_test"],
        help="Which test set is used for evaluation",
        type=str)
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--balance",
                        action='store_true',
                        help="balance between + and - samples for training.")
    ## Other parameters
    parser.add_argument(
        "--bert_model",
        default="bert-base-multilingual-cased",
        type=str,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default="QQP",
                        type=str,
                        help="The name of the task to train.")
    parser.add_argument('--period', type=int, default=500)
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=512,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=6,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=20.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()
    pprint(vars(args))
    sys.stdout.flush()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        "qqp": QqpProcessor,
    }

    output_modes = {
        "qqp": "classification",
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    args.output_dir = "{}_fact-{}_{}".format(args.output_dir, args.fact,
                                             args.scan)
    args.data_dir = os.path.join(args.data_dir,
                                 "tsv_data_{}".format(args.scan))
    logger.info(
        "Datasets are loaded from {}\n Outputs will be saved to {}".format(
            args.data_dir, args.output_dir))
    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)
    writer = SummaryWriter(os.path.join(args.output_dir, 'events'))

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = RobertaTokenizer.from_pretrained("roberta-base")

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    if args.load_dir:
        load_dir = args.load_dir
    else:
        load_dir = args.bert_model

    model = BertForSequenceClassification.from_pretrained(
        load_dir, cache_dir=cache_dir, num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    if args.do_train:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)
            warmup_linear = WarmupLinearSchedule(
                warmup=args.warmup_proportion,
                t_total=num_train_optimization_steps)

        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

    global_step = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer,
                                                      output_mode,
                                                      fact_place=args.fact,
                                                      balance=args.balance)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.float)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            logger.info("Training epoch {} ...".format(epoch))
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # define a new function to compute loss values for both output_modes
                logits = model(input_ids, segment_ids, input_mask, labels=None)

                if output_mode == "classification":
                    loss_fct = CrossEntropyLoss()
                    loss = loss_fct(logits.view(-1, num_labels),
                                    label_ids.view(-1))
                elif output_mode == "regression":
                    loss_fct = MSELoss()
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                writer.add_scalar('train/loss', loss, global_step)
                tr_loss += loss.item()

                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    total_norm = 0.0
                    for n, p in model.named_parameters():
                        if p.grad is not None:
                            param_norm = p.grad.data.norm(2)
                            total_norm += param_norm.item()**2
                    total_norm = total_norm**(1. / 2)
                    preds = torch.argmax(logits, -1) == label_ids
                    acc = torch.sum(preds).float() / preds.size(0)
                    writer.add_scalar('train/gradient_norm', total_norm,
                                      global_step)
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(
                            global_step, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    model.zero_grad()
                    global_step += 1

                if (step + 1) % args.period == 0:
                    # Save a trained model, configuration and tokenizer
                    model_to_save = model.module if hasattr(
                        model,
                        'module') else model  # Only save the model it-self

                    # If we save using the predefined names, we can load using `from_pretrained`

                    output_dir = os.path.join(
                        args.output_dir, 'save_step_{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)

                    output_model_file = os.path.join(output_dir, RobertaModel)
                    output_config_file = os.path.join(output_dir,
                                                      RobertaConfig)

                    torch.save(model_to_save.state_dict(), output_model_file)
                    model_to_save.config.to_json_file(output_config_file)
                    tokenizer.save_vocabulary(output_dir)

                    model.eval()
                    torch.set_grad_enabled(False)  # turn off gradient tracking
                    evaluate(args,
                             model,
                             device,
                             processor,
                             label_list,
                             num_labels,
                             tokenizer,
                             output_mode,
                             tr_loss,
                             global_step,
                             task_name,
                             tbwriter=writer,
                             save_dir=output_dir)
                    model.train()  # turn on train mode
                    torch.set_grad_enabled(True)  # start gradient tracking
                    tr_loss = 0

    # do eval before exit
    if args.do_eval:
        if not args.do_train:
            global_step = 0
            output_dir = None
        save_dir = args.load_dir
        tbwriter = SummaryWriter(os.path.join(save_dir, 'eval/events'))
        load_step = args.load_step
        if args.load_dir is not None:
            load_step = int(
                os.path.split(args.load_dir)[1].replace('save_step_', ''))
            print("load_step = {}".format(load_step))
        model.eval()
        evaluate(args,
                 model,
                 device,
                 processor,
                 label_list,
                 num_labels,
                 tokenizer,
                 output_mode,
                 tr_loss,
                 global_step,
                 task_name,
                 tbwriter=tbwriter,
                 save_dir=save_dir,
                 load_step=load_step)
class Distiller:
    """Train a `student` model to imitate a `teacher` model (knowledge distillation).

    The student is trained on masked-language-modeling batches with a weighted
    sum of up to three losses computed from its output logits:

    * ``alpha_ce``: KL divergence between the temperature-softened student and
      teacher distributions, scaled by ``temperature ** 2``;
    * ``alpha_mlm``: cross-entropy of the student against the MLM labels;
    * ``alpha_mse``: summed MSE between student and teacher logits, divided by
      the number of selected positions.

    Optimization uses AdamW with a linear warmup schedule, optionally apex fp16
    and DistributedDataParallel. Tensorboard logging and checkpoint saving are
    performed by the master process only.
    """

    def __init__(self, params: dict, dataloader: Dataset,
                 token_probs: torch.Tensor, student: nn.Module,
                 teacher: nn.Module):
        """
        Input:
        ------
            params: Training configuration. Although annotated as `dict`, every
                field is read with attribute access (`params.dump_path`,
                `params.n_gpu`, ...), so an attribute-style object is expected.
            dataloader: Data source exposing `get_iterator(seed=...)`, `split()`
                and `__len__` (all used below). NOTE(review): `len()` is divided
                by `params.batch_size`, so it presumably counts sequences — confirm.
            token_probs: `torch.Tensor` indexed by token id in `prepare_batch`;
                per-token weights used to sample MLM targets.
            student: The model being optimized.
            teacher: The reference model; it is only run in eval mode under
                `torch.no_grad()` (see `train` and `step`).
        """
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.dataloader = dataloader
        if self.params.n_gpu > 1:
            # Presumably shards the data between processes — see dataloader.split().
            self.dataloader.split()
        self.get_iterator(seed=params.seed)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_mse = params.alpha_mse
        assert self.alpha_ce >= 0.
        assert self.alpha_mlm >= 0.
        assert self.alpha_mse >= 0.
        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0.

        self.mlm_mask_prop = params.mlm_mask_prop
        assert 0.0 <= self.mlm_mask_prop <= 1.0
        assert params.word_mask + params.word_keep + params.word_rand == 1.0
        # Sampling weights for what happens to a selected MLM target (see
        # `prepare_batch`): index 0 -> mask token, 1 -> keep, 2 -> random token.
        self.pred_probs = torch.FloatTensor(
            [params.word_mask, params.word_keep, params.word_rand])
        self.pred_probs = self.pred_probs.to(
            f'cuda:{params.local_rank}'
        ) if params.n_gpu > 0 else self.pred_probs
        self.token_probs = token_probs.to(
            f'cuda:{params.local_rank}') if params.n_gpu > 0 else token_probs
        if self.fp16:
            self.pred_probs = self.pred_probs.half()
            self.token_probs = self.token_probs.half()

        # Progress counters. `n_iter`, `n_sequences_epoch` and
        # `total_loss_epoch` are reset by `end_epoch`; `n_total_iter` is not.
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_mse = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        # -1 marks positions that carry no MLM target (see `prepare_batch`).
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        # Summed here; `step` divides by the number of rows ("batchmean").
        self.mse_loss_fct = nn.MSELoss(reduction='sum')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = int(
            len(self.dataloader) / params.batch_size) + 1
        num_train_optimization_steps = int(
            self.num_steps_epoch / params.gradient_accumulation_steps *
            params.n_epoch) + 1
        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params.warmup_prop)

        # Biases and LayerNorm weights are excluded from weight decay; frozen
        # parameters (requires_grad=False) are not handed to the optimizer.
        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in student.named_parameters()
                if not any(nd in n for nd in no_decay) and p.requires_grad
            ],
            'weight_decay':
            params.weight_decay
        }, {
            'params': [
                p for n, p in student.named_parameters()
                if any(nd in n for nd in no_decay) and p.requires_grad
            ],
            'weight_decay':
            0.0
        }]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum([
                p.numel() for p in self.student.parameters() if p.requires_grad
            ]))
        logger.info("------ Number of parameters (student): %i" %
                    sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))
        self.scheduler = WarmupLinearSchedule(
            self.optimizer,
            warmup_steps=warmup_steps,
            t_total=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            logger.info(
                f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info(
                    "Using apex.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info(
                    "Using nn.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(
                log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config',
                                      text_string=str(self.params),
                                      global_step=0)

    def get_iterator(self, seed: int = None):
        """
        Initialize the data iterator.
        Each process has its own data iterator (iterating over its own random portion of the dataset).

        Input:
        ------
            seed: `int` - The random seed.
        """
        logger.info('--- Initializing Data Iterator')
        self.data_iterator = self.dataloader.get_iterator(seed=seed)

    def get_batch(self):
        """
        Return the next batch from the data iterator.
        When the iterator is exhausted, a new one is created (without a seed) and its first batch is returned.
        """
        assert hasattr(self, 'data_iterator')
        try:
            x = next(self.data_iterator)
        except StopIteration:
            logger.warning(
                '--- Went through the whole dataset. Creating new data iterator.'
            )
            self.data_iterator = self.dataloader.get_iterator()
            x = next(self.data_iterator)
        return x

    def prepare_batch(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.Tensor(bs, seq_length)` - The token ids for each of the sequences. It is padded.
                lengths: `torch.Tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.Tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.Tensor(bs, seq_length)` - The (boolean) attention mask for the self-attention.
            mlm_labels: `torch.Tensor(bs, seq_length)` - The masked language modeling labels. There is a -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        # True at positions before each sequence's length, False on padding.
        attn_mask = (torch.arange(token_ids.size(1),
                                  dtype=torch.long,
                                  device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        # Sample ceil(mlm_mask_prop * total_length) target positions, without
        # replacement, with probability proportional to `token_probs`.
        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(),
                                    n_tgt,
                                    replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        # Padding positions are never prediction targets.
        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0

        # With fp16, drop some targets so their count is a multiple of 8 (faster with fp16).
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        # For every target draw an action from `pred_probs`:
        # 0 -> mask token, 1 -> keep the real token, 2 -> random vocabulary id.
        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(
            self.params.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(
            self.params.special_tok_ids['mask_token'])
        probs = torch.multinomial(self.pred_probs,
                                  len(_token_ids_real),
                                  replacement=True)
        _token_ids = _token_ids_mask * (
            probs == 0).long() + _token_ids_real * (
                probs == 1).long() + _token_ids_rand * (probs == 2).long()
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[
            ~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        return token_ids, attn_mask, mlm_labels

    def round_batch(self, x: torch.Tensor, lengths: torch.Tensor):
        """
        For float16 only (a no-op otherwise, or when the batch has fewer than 8 sentences).
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.Tensor(bs, seq_length)` - The token ids.
            lengths: `torch.Tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            x:  `torch.Tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.Tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # Batch size: randomly keep a multiple of 8 sentences.
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # Sequence length: right-pad with the pad token up to a multiple of 8.
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            pad_id = self.params.special_tok_ids['pad_token']
            padding_tensor = torch.full((bs2, pad),
                                        pad_id,
                                        dtype=torch.long,
                                        device=x.device)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        Main training loop: `params.n_epoch` epochs of `num_steps_epoch` steps each,
        followed by saving the final checkpoint as `pytorch_model.bin` (master only).
        """
        if self.is_master: logger.info('Starting training')
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(
                    f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')

            # The progress bar is only shown on local rank -1 / 0.
            iter_bar = trange(self.num_steps_epoch,
                              desc="-Iter",
                              disable=self.params.local_rank not in [-1, 0])
            for __ in range(self.num_steps_epoch):
                batch = self.get_batch()
                if self.params.n_gpu > 0:
                    batch = tuple(
                        t.to(f'cuda:{self.params.local_rank}') for t in batch)
                token_ids, attn_mask, mlm_labels = self.prepare_batch(
                    batch=batch)

                self.step(input_ids=token_ids,
                          attention_mask=attn_mask,
                          mlm_labels=mlm_labels)

                iter_bar.update()
                iter_bar.set_postfix({
                    'Last_loss':
                    f'{self.last_loss:.2f}',
                    'Avg_cum_loss':
                    f'{self.total_loss_epoch/self.n_iter:.2f}'
                })
            iter_bar.close()

            if self.is_master:
                logger.info(
                    f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if self.is_master:
            logger.info(f'Save very last checkpoint as `pytorch_model.bin`.')
            self.save_checkpoint(checkpoint_name=f'pytorch_model.bin')
            logger.info('Training is finished')

    def step(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
             mlm_labels: torch.Tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.Tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.Tensor(bs, seq_length)` - The attention mask for self attention.
        mlm_labels: `torch.Tensor(bs, seq_length)` - The masked language modeling labels (-1 = no target).
        """
        s_logits = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask)[0]  # (bs, seq_length, voc_size)
        with torch.no_grad():
            t_logits = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask)[0]  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        # Positions used by the CE/MSE terms: only MLM targets, or every
        # non-padding position.
        if self.params.restrict_ce_to_mask:
            mask = (mlm_labels > -1).unsqueeze(-1).expand_as(
                s_logits)  # (bs, seq_lenth, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(
                s_logits)  # (bs, seq_lenth, voc_size)
        s_logits_slct = torch.masked_select(
            s_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(
            t_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        # temperature**2 keeps gradient magnitudes comparable across temperatures.
        loss_ce = self.ce_loss_fct(
            F.log_softmax(s_logits_slct / self.temperature, dim=-1),
            F.softmax(t_logits_slct / self.temperature,
                      dim=-1)) * (self.temperature)**2
        loss = self.alpha_ce * loss_ce
        if self.alpha_mlm > 0.:
            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)),
                                         mlm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(
                s_logits_slct, t_logits_slct) / s_logits_slct.size(
                    0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.

        Raises:
        -------
            SystemExit: with status 1 if `loss` contains a NaN.
        """
        # Check for NaN (NaN is the only value that is not equal to itself).
        if (loss != loss).data.any():
            logger.error('NaN detected')
            # Exit with a failure status so launchers/schedulers see the run as
            # failed; the interactive-only `exit()` helper would report success
            # and is not guaranteed to exist (it is injected by `site`).
            import sys
            sys.exit(1)

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                # `amp` was imported in the fp16 backward branch above.
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer),
                    self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(),
                                               self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update the per-epoch and global step counters, and every
        `log_interval` / `checkpoint_interval` global steps log to tensorboard
        or save a checkpoint (both are no-ops on non-master processes).
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log parameter/gradient statistics, losses, learning rate and host memory
        usage into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name,
                                        scalar_value=param.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name,
                                        scalar_value=param.data.std(),
                                        global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name,
                                        scalar_value=param.grad.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name,
                                        scalar_value=param.grad.data.std(),
                                        global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch",
                                    scalar_value=self.total_loss_epoch /
                                    self.n_iter,
                                    global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss",
                                    scalar_value=self.last_loss,
                                    global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce",
                                    scalar_value=self.last_loss_ce,
                                    global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm",
                                        scalar_value=self.last_loss_mlm,
                                        global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse",
                                        scalar_value=self.last_loss_mse,
                                        global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr",
                                    scalar_value=self.scheduler.get_lr()[0],
                                    global_step=self.n_total_iter)

        # Host (not GPU) memory in use, in megabytes.
        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()['used'] / 1_000_000,
            global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).
        Save an epoch checkpoint and log the epoch loss (master only), then reset the per-epoch counters.
        """
        logger.info(
            f'{self.n_sequences_epoch} sequences have been trained during this epoch.'
        )

        if self.is_master:
            self.save_checkpoint(
                checkpoint_name=f'model_epoch_{self.epoch}.pth')
            self.tensorboard.add_scalar(tag='epoch/loss',
                                        scalar_value=self.total_loss_epoch /
                                        self.n_iter,
                                        global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the student's config and state dict into `dump_path`. Only by the master process.
        """
        if not self.is_master:
            return
        # Unwrap DistributedDataParallel so the saved keys have no `module.` prefix.
        mdl_to_save = self.student.module if hasattr(
            self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
Example #14
0
def main():
    """Pre-train ``BertForPreTraining`` with knowledge-graph entity embeddings.

    Steps:
      * parse the command line;
      * join a file-based NCCL process group with one process per node
        (``rank=--node_rank``, ``world_size=--nodes_count``);
      * load ``kg_embed/entity2vec.vec`` into an ``nn.Embedding`` built with
        ``from_pretrained``. Row 0 is an all-zero vector (used for CLS and for
        entity id -1 via the ``+ 1`` shift);
      * train over ``indexed_dataset.IndexedDataset(--data_dir)``;
      * save the weights to
        ``<output_dir>/pytorch_model_<global rank><local rank>.bin``.

    During training, one ``"<loss> <original_loss>"`` line per step is
    written to ``<output_dir>/loss.<timestamp>``.

    Raises:
        ValueError: if ``--gradient_accumulation_steps < 1``, if neither
            ``--do_train`` nor ``--do_eval`` is set, or if ``--output_dir``
            already exists and is not empty.
        ImportError: if apex is needed (distributed or ``--fp16``) but is
            not installed.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        default=False,
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=16,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=1.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        default=False,
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--init_fs_path',
                        type=str, default="./",
                        help="File system path for distributed training with init_method using shared file system \n"
                             "./ (default value): current directory\n")
    parser.add_argument('--node_rank',
                        type=int, default=0,
                        help="Rank of current node \n"
                             "0 (default value): rank of current node\n")
    parser.add_argument('--nodes_count',
                        type=int, default=1,
                        help="Number of nodes to determine the world size for distributed training \n"
                             "1 (default value): count of nodes\n")

    args = parser.parse_args()

    logger.info("world_size: {}, node rank: {}, local rank {}".format(args.nodes_count, args.node_rank, args.local_rank))
    logger.info("Process is being blocked until all nodes are ready.")

    # NOTE(review): CUDA and args.local_rank are used unconditionally; --no_cuda
    # is ignored and the default --local_rank of -1 is presumably unsupported
    # by this script. Confirm before running without an explicit local rank.
    torch.cuda.set_device(args.local_rank)
    device = torch.device("cuda", args.local_rank)
    n_gpu = 1
    # Initializes the distributed backend which will take care of synchronizing nodes/GPUs.
    # One process per node: the node rank is used directly as the global rank.
    torch.distributed.init_process_group(backend='nccl',
                                         init_method='file://' + args.init_fs_path,
                                         rank=args.node_rank,
                                         world_size=args.nodes_count)

    logger.info("All nodes ready, unblocking process.\n\n")
    logger.info("Global rank: {}, device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        torch.distributed.get_rank(), device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
            args.gradient_accumulation_steps))

    # Per-forward batch size; the effective batch is restored by accumulation.
    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    # Entity embedding table: row 0 is an all-zero vector (CLS / "no entity"),
    # row k + 1 is line k of entity2vec.vec (tab-separated floats).
    vecs = []
    vecs.append([0] * 100)  # CLS
    with open("kg_embed/entity2vec.vec", 'r') as fin:
        for line in fin:
            vec = line.strip().split('\t')
            vec = [float(x) for x in vec]
            vecs.append(vec)
    embed = torch.FloatTensor(vecs)
    embed = torch.nn.Embedding.from_pretrained(embed)

    logger.info("Shape of entity embedding: " + str(embed.weight.size()))
    del vecs

    train_data = None
    num_train_steps = None
    if args.do_train:
        import indexed_dataset
        from torch.utils.data import RandomSampler, BatchSampler
        import iterators
        train_data = indexed_dataset.IndexedDataset(args.data_dir, fix_lua_indexing=True)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_sampler = BatchSampler(train_sampler, args.train_batch_size, True)

        def collate_fn(x):
            """Turn a list of flat samples into the model's input tensors.

            From the slicing below, each sample is a flat int row of
            consecutive ``max_seq_length`` segments: [0] input ids,
            [1] input mask, [2] segment ids, [3] masked-LM labels,
            [4] entity ids (-1 = no entity), [5] not returned,
            [6:] next-sentence label.
            """
            seq_len = args.max_seq_length
            x = torch.LongTensor([xx for xx in x])

            # NOTE: a view into `x`; the in-place corruption below also
            # rewrites this (non-returned) slice of `x`.
            entity_idx = x[:, 4 * seq_len:5 * seq_len]
            # Candidate set = every distinct entity id in the batch. np.unique
            # sorts ascending, so -1 (if present) is at index 0. The +1 shift
            # maps ids onto rows of `embed` (row 0 is the zero vector).
            uniq_idx = np.unique(entity_idx.numpy())
            ent_candidate = embed(torch.LongTensor(uniq_idx + 1))
            ent_candidate = ent_candidate.repeat([n_gpu, 1])
            # Build entity labels: entity id -> position in the candidate set.
            d = {}
            dd = []
            for i, idx in enumerate(uniq_idx):
                d[idx] = i
                dd.append(idx)
            ent_size = len(uniq_idx) - 1

            def _corrupt_entity(e):
                """Randomly perturb one entity id (-1 is left unchanged).

                p=0.05: replace by dd[randint(1, ent_size)] (index 0 skipped,
                presumably because it holds -1); p=0.15: drop to -1;
                otherwise keep.
                """
                if e == -1:
                    return -1
                rnd = random.uniform(0, 1)
                if rnd < 0.05:
                    return dd[random.randint(1, ent_size)]
                elif rnd < 0.2:
                    return -1
                return e

            ent_labels = entity_idx.clone()
            d[-1] = -1
            ent_labels = ent_labels.apply_(lambda v: d[v])

            entity_idx.apply_(_corrupt_entity)
            ent_emb = embed(entity_idx + 1)
            mask = entity_idx.clone()
            mask.apply_(lambda v: 0 if v == -1 else 1)
            mask[:, 0] = 1

            return (x[:, :seq_len],
                    x[:, seq_len:2 * seq_len],
                    x[:, 2 * seq_len:3 * seq_len],
                    x[:, 3 * seq_len:4 * seq_len],
                    ent_emb,
                    mask,
                    x[:, 6 * seq_len:],
                    ent_candidate,
                    ent_labels)

        train_iterator = iterators.EpochBatchIterator(train_data, collate_fn, train_sampler)
        num_train_steps = int(
            len(train_data) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    model, missing_keys = BertForPreTraining.from_pretrained(args.bert_model,cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))

    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    # Exclude the entity-side parameters of encoder layer 11 (the '2' below is
    # rewritten to '11') from optimization.
    no_linear = ['layer.2.output.dense_ent', 'layer.2.intermediate.dense_1',
                 'bert.encoder.layer.2.intermediate.dense_1_ent', 'layer.2.output.LayerNorm_ent']
    no_linear = [x.replace('2', '11') for x in no_linear]
    param_optimizer = [(n, p) for n, p in param_optimizer if not any(nl in n for nl in no_linear)]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm_ent.bias', 'LayerNorm_ent.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    # Without --do_train there is no step count; -1 means "no schedule"
    # (same convention as the other scripts) instead of crashing on None.
    t_total = num_train_steps if num_train_steps is not None else -1
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.contrib.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False)
        if args.loss_scale == 0:
            # Dynamic loss scaling.
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2")
        else:
            model, optimizer = amp.initialize(model, optimizer, opt_level="O2", loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)
        model.train()
        import datetime
        loss_log_path = os.path.join(args.output_dir, "loss.{}".format(datetime.datetime.now()))
        # Context manager so the loss log is flushed/closed even on failure.
        with open(loss_log_path, 'w') as fout:
            for _ in trange(int(args.num_train_epochs), desc="Epoch"):
                tr_loss = 0
                nb_tr_examples, nb_tr_steps = 0, 0
                for step, batch in enumerate(tqdm(train_iterator.next_epoch_itr(), desc="Iteration")):
                    batch = tuple(t.to(device) for t in batch)

                    (input_ids, input_mask, segment_ids, masked_lm_labels, input_ent, ent_mask,
                     next_sentence_label, ent_candidate, ent_labels) = batch
                    loss, original_loss = model(input_ids, segment_ids, input_mask, masked_lm_labels, input_ent,
                                                ent_mask, next_sentence_label, ent_candidate, ent_labels)

                    if n_gpu > 1:
                        loss = loss.mean()  # mean() to average on multi-gpu.
                        original_loss = original_loss.mean()
                    if args.gradient_accumulation_steps > 1:
                        loss = loss / args.gradient_accumulation_steps

                    if args.fp16:
                        # `amp` was imported (and initialized) above in the
                        # fp16 optimizer setup; no need to re-import per step.
                        with amp.scale_loss(loss, optimizer) as scaled_loss:
                            scaled_loss.backward()
                    else:
                        loss.backward()

                    # Log the un-divided loss so values are comparable across
                    # different accumulation settings.
                    fout.write("{} {}\n".format(loss.item() * args.gradient_accumulation_steps, original_loss.item()))
                    tr_loss += loss.item()
                    nb_tr_examples += input_ids.size(0)
                    nb_tr_steps += 1
                    if (step + 1) % args.gradient_accumulation_steps == 0:
                        # modify learning rate with special warm up BERT uses
                        lr_this_step = args.learning_rate * warmup_linear(global_step / t_total, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                        # Exactly one optimizer step per update, after every
                        # group's lr has been set (previously this ran once
                        # per param group).
                        optimizer.step()
                        optimizer.zero_grad()
                        global_step += 1

    # Save a trained model
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
    output_model_file = os.path.join(args.output_dir, "pytorch_model_" + str(torch.distributed.get_rank()) + str(
        args.local_rank) + ".bin")
    torch.save(model_to_save.state_dict(), output_model_file)
def main():
    parser = argparse.ArgumentParser()

    # Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument("--vocab_file",
                        default='bert-base-uncased-vocab.txt',
                        type=str,
                        required=True)
    parser.add_argument("--model_file",
                        default='bert-base-uncased.tar.gz',
                        type=str,
                        required=True)
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model checkpoints and predictions will be written."
    )
    parser.add_argument(
        "--predict_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the predictions will be written.")

    # Other parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        help="SQuAD json for training. E.g., train-v1.1.json")
    parser.add_argument(
        "--predict_file",
        default=None,
        type=str,
        help="SQuAD json for predictions. E.g., dev-v1.1.json or test-v1.1.json"
    )
    parser.add_argument("--test_file", default=None, type=str)
    parser.add_argument(
        "--max_seq_length",
        default=384,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. Sequences "
        "longer than this will be truncated, and sequences shorter than this will be padded."
    )
    parser.add_argument(
        "--doc_stride",
        default=128,
        type=int,
        help=
        "When splitting up a long document into chunks, how much stride to take between chunks."
    )
    parser.add_argument(
        "--max_query_length",
        default=64,
        type=int,
        help=
        "The maximum number of tokens for the question. Questions longer than this will "
        "be truncated to this length.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_predict",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--predict_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for predictions.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=2.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. E.g., 0.1 = 10% "
        "of training.")
    parser.add_argument(
        "--n_best_size",
        default=20,
        type=int,
        help=
        "The total number of n-best predictions to generate in the nbest_predictions.json "
        "output file.")
    parser.add_argument(
        "--max_answer_length",
        default=30,
        type=int,
        help=
        "The maximum length of an answer that can be generated. This is needed because the start "
        "and end predictions are not conditioned on one another.")
    parser.add_argument(
        "--verbose_logging",
        default=False,
        action='store_true',
        help=
        "If true, all of the warnings related to data processing will be printed. "
        "A number of warnings are expected for a normal SQuAD evaluation.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--view_id',
                        type=int,
                        default=1,
                        help="view id of multi-view co-training(two-view)")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        "--do_lower_case",
        default=True,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    # Base setting
    parser.add_argument('--pretrain', type=str, default=None)
    parser.add_argument('--max_ctx', type=int, default=2)
    parser.add_argument('--task_name', type=str, default='coqa_yesno')
    parser.add_argument('--bert_name', type=str, default='baseline')
    parser.add_argument('--reader_name', type=str, default='coqa')
    parser.add_argument('--per_eval_step', type=int, default=10000000)
    # model parameters
    parser.add_argument('--evidence_lambda', type=float, default=0.8)
    parser.add_argument('--tf_layers', type=int, default=1)
    parser.add_argument('--tf_inter_size', type=int, default=3072)
    # Parameters for running labeling model
    parser.add_argument('--do_label', default=False, action='store_true')
    parser.add_argument('--sentence_id_file', type=str, default=None)
    parser.add_argument('--weight_threshold', type=float, default=0.0)
    parser.add_argument('--only_correct', default=False, action='store_true')
    parser.add_argument('--label_threshold', type=float, default=0.0)
    parser.add_argument('--use_gumbel', default=False, action='store_true')
    parser.add_argument('--sample_steps', type=int, default=10)
    parser.add_argument('--reward_func', type=int, default=0)
    parser.add_argument('--freeze_bert', default=False, action='store_true')
    parser.add_argument('--num_evidence', default=1, type=int)
    parser.add_argument('--power_length', default=1., type=float)
    parser.add_argument('--split_type', default=0, type=int)
    parser.add_argument('--remove_evidence',
                        default=False,
                        action='store_true')
    parser.add_argument('--remove_question',
                        default=False,
                        action='store_true')
    parser.add_argument('--remove_passage', default=False, action='store_true')
    parser.add_argument('--remove_dict', default=None, type=str)
    parser.add_argument('--freeze_predictor', default=None, type=str)
    parser.add_argument('--evidence_search_file', default=None, type=str)

    args = parser.parse_args()

    logger = setting_logger(args.output_dir)
    logger.info('================== Program start. ========================')
    logger.info(f'================== Set seed {args.seed} =================')

    model_params = prepare_model_params(args)
    read_params = prepare_read_params(args)

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_predict and not args.do_label:
        raise ValueError(
            "At least one of `do_train` or `do_predict` or `do_label` must be True."
        )

    if args.do_train:
        if not args.train_file:
            raise ValueError(
                "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict:
        if not args.predict_file:
            raise ValueError(
                "If `do_predict` is True, then `predict_file` must be specified."
            )

    if args.do_train:
        if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
            raise ValueError(
                "Output directory () already exists and is not empty.")
        os.makedirs(args.output_dir, exist_ok=True)

    if args.do_predict or args.do_label:
        os.makedirs(args.predict_dir, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.vocab_file)

    data_reader = initialize_reader(args.reader_name)

    num_train_steps = None
    if args.do_train or args.do_label:
        train_examples = data_reader.read(input_file=args.train_file,
                                          **read_params)

        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}_{4}_{5}'.format(
            args.bert_model, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length), str(args.max_ctx), str(args.task_name))

        try:
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except FileNotFoundError:
            train_features = data_reader.convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length)
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                logger.info("  Saving train features into cached file %s",
                            cached_train_features_file)
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)

        num_train_steps = int(
            len(train_features) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    if args.pretrain is not None:
        logger.info('Load pretrained model from {}'.format(args.pretrain))
        model_state_dict = torch.load(args.pretrain, map_location='cuda:0')
        model = initialize_model(args.bert_name,
                                 args.model_file,
                                 state_dict=model_state_dict,
                                 **model_params)
    else:
        model = initialize_model(args.bert_name, args.model_file,
                                 **model_params)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())

    # Remove frozen parameters
    param_optimizer = [n for n in param_optimizer if n[1].requires_grad]

    # hack to remove pooler, which is not used
    # thus it produce None grad that break apex
    param_optimizer = [n for n in param_optimizer if 'pooler' not in n[0]]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    t_total = num_train_steps if num_train_steps else -1
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
        warmup_linear = WarmupLinearSchedule(warmup=args.warmup_proportion,
                                             t_total=t_total)
        logger.info(
            f"warm up linear: warmup = {warmup_linear.warmup}, t_total = {warmup_linear.t_total}."
        )
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    # Prepare data
    eval_examples = data_reader.read(input_file=args.predict_file,
                                     **read_params)
    eval_features = data_reader.convert_examples_to_features(
        examples=eval_examples,
        tokenizer=tokenizer,
        max_seq_length=args.max_seq_length,
        doc_stride=args.doc_stride,
        max_query_length=args.max_query_length)

    eval_tensors = data_reader.data_to_tensors(eval_features)
    eval_data = TensorDataset(*eval_tensors)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.predict_batch_size)

    if args.do_train:

        if args.do_label:
            logger.info('Training in State Wise.')
            sentence_id_file = args.sentence_id_file
            if sentence_id_file is not None:
                # for file in sentence_id_file_list:
                #     train_features = data_reader.generate_features_sentence_ids(train_features, file)
                train_features = data_reader.generate_features_sentence_ids(
                    train_features, sentence_id_file)
            else:
                # train_features = data_reader.mask_all_sentence_ids(train_features)
                logger.info('No sentence id supervision is found.')
        else:
            logger.info('Training in traditional way.')

        logger.info("***** Running training *****")
        logger.info("  Num orig examples = %d", len(train_examples))
        logger.info("  Num split examples = %d", len(train_features))
        logger.info("  Batch size = %d", args.train_batch_size)
        train_loss = AverageMeter()
        summary_writer = SummaryWriter(log_dir=args.output_dir)
        global_step = 0
        eval_loss = AverageMeter()
        best_metric = 0.0
        eval_epoch = 0
        eval_acc = CategoricalAccuracy()

        train_tensors = data_reader.data_to_tensors(train_features)
        train_data = TensorDataset(*train_tensors)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        for epoch in range(int(args.num_train_epochs)):
            logger.info(f'Running at Epoch {epoch}')
            # Train
            for step, batch in enumerate(
                    tqdm(train_dataloader,
                         desc="Iteration",
                         dynamic_ncols=True)):
                model.train()
                if n_gpu == 1:
                    batch = batch_to_device(
                        batch, device)  # multi-gpu does scattering it-self
                inputs = data_reader.generate_inputs(
                    batch, train_features, model_state=ModelState.Train)
                loss = model(**inputs)['loss']
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used and handles this automatically
                    if args.fp16:
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(
                            global_step)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                        summary_writer.add_scalar('lr', lr_this_step,
                                                  global_step)
                    else:
                        summary_writer.add_scalar('lr',
                                                  optimizer.get_lr()[0],
                                                  global_step)

                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                    train_loss.update(loss.item(), 1)
                    summary_writer.add_scalar('train_loss', train_loss.avg,
                                              global_step)

                if (step + 1) % args.per_eval_step == 0 or step == len(
                        train_dataloader) - 1:
                    # Evaluation
                    model.eval()
                    all_results = []
                    logger.info("Start evaluating")
                    for _, eval_batch in enumerate(
                            tqdm(eval_dataloader,
                                 desc="Evaluating",
                                 dynamic_ncols=True)):
                        if n_gpu == 1:
                            eval_batch = batch_to_device(
                                eval_batch,
                                device)  # multi-gpu does scattering it-self
                        inputs = data_reader.generate_inputs(
                            eval_batch,
                            eval_features,
                            model_state=ModelState.Evaluate)
                        with torch.no_grad():
                            output_dict = model(**inputs)
                            loss, batch_choice_logits = output_dict[
                                'loss'], output_dict['yesno_logits']
                            eval_acc(batch_choice_logits,
                                     inputs["answer_choice"])
                            eval_loss.update(loss.item(), 1)

                        example_indices = eval_batch[-1]
                        for i, example_index in enumerate(example_indices):
                            choice_logits = batch_choice_logits[i].detach(
                            ).cpu().tolist()

                            eval_feature = eval_features[example_index.item()]
                            unique_id = int(eval_feature.unique_id)
                            # print(unique_id)
                            all_results.append(
                                RawResultChoice(unique_id=unique_id,
                                                choice_logits=choice_logits))

                    eval_epoch_loss = eval_loss.avg
                    summary_writer.add_scalar('eval_loss', eval_epoch_loss,
                                              eval_epoch)
                    eval_loss.reset()

                    _, metric, save_metric = data_reader.write_predictions(
                        eval_examples, eval_features, all_results, None)
                    logger.info(f"Eval epoch: {eval_epoch}")
                    for k, v in metric.items():
                        logger.info(f"{k}: {v}")
                        summary_writer.add_scalar(f'eval_{k}', v, eval_epoch)
                    print(f"Eval accuracy: {eval_acc.get_metric(reset=True)}")
                    torch.cuda.empty_cache()

                    if save_metric[1] > best_metric:
                        best_metric = save_metric[1]
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(
                            args.output_dir, "pytorch_model.bin")
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                    logger.info('Eval Epoch: %d, %s: %f (Best %s: %f)' %
                                (eval_epoch, save_metric[0], save_metric[1],
                                 save_metric[0], best_metric))
                    eval_epoch += 1

        summary_writer.close()

    # Loading trained model.
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    model_state_dict = torch.load(output_model_file, map_location='cuda:0')
    model = initialize_model(args.bert_name,
                             args.model_file,
                             state_dict=model_state_dict,
                             **model_params)
    model.to(device)

    # Write Yes/No predictions
    if args.do_predict and (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):

        test_examples = eval_examples
        test_features = eval_features

        test_tensors = data_reader.data_to_tensors(test_features)
        test_data = TensorDataset(*test_tensors)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.predict_batch_size)

        logger.info("***** Running predictions *****")
        logger.info("  Num orig examples = %d", len(test_examples))
        logger.info("  Num split examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start predicting yes/no on Dev set.")
        for batch in tqdm(test_dataloader, desc="Testing", dynamic_ncols=True):
            if n_gpu == 1:
                batch = batch_to_device(
                    batch, device)  # multi-gpu does scattering it-self
            inputs = data_reader.generate_inputs(batch,
                                                 test_features,
                                                 model_state=ModelState.Test)
            with torch.no_grad():
                batch_choice_logits = model(**inputs)['yesno_logits']
            example_indices = batch[-1]
            for i, example_index in enumerate(example_indices):
                choice_logits = batch_choice_logits[i].detach().cpu().tolist()

                test_feature = test_features[example_index.item()]
                unique_id = int(test_feature.unique_id)

                all_results.append(
                    RawResultChoice(unique_id=unique_id,
                                    choice_logits=choice_logits))

        output_prediction_file = os.path.join(args.predict_dir,
                                              'predictions.json')
        _, metric, _ = data_reader.write_predictions(eval_examples,
                                                     eval_features,
                                                     all_results,
                                                     output_prediction_file)
        for k, v in metric.items():
            logger.info(f'{k}: {v}')

    # Labeling sentence id.
    if args.do_label and (args.local_rank == -1
                          or torch.distributed.get_rank() == 0):

        # f = open('debug_log.txt', 'w')

        def softmax(x):
            """Compute softmax values for each sets of scores in x."""
            e_x = np.exp(x - np.max(x))
            return e_x / e_x.sum()

        def topk(sentence_sim):
            """
            :param sentence_sim: numpy
            :return:
            """
            max_length = min(args.num_evidence, len(sentence_sim))
            sorted_scores = np.array(sorted(sentence_sim, reverse=True))
            scores = []
            for idx in range(max_length):
                scores.append(np.log(softmax(sorted_scores[idx:])[0]))
            scores = [np.mean(scores[:(j + 1)]) for j in range(max_length)]
            top_k = int(np.argmax(scores) + 1)
            sorted_scores = sorted(enumerate(sentence_sim),
                                   key=lambda x: x[1],
                                   reverse=True)
            evidence_ids = [x[0] for x in sorted_scores[:top_k]]
            sentence = {
                'sentences': evidence_ids,
                'value': float(np.exp(scores[top_k - 1]))
            }
            # print(f'value = {sentence["value"]}', file=f, flush=True)
            return sentence

        def batch_topk(sentence_sim, sentence_mask):
            sentence_sim = sentence_sim.squeeze(
                1).detach().cpu().numpy() + 1e-15
            sentence_mask = sentence_mask.detach().cpu().numpy()
            sentence_ids = [
                topk(_sim[:int(sum(_mask))])
                for _sim, _mask in zip(sentence_sim, sentence_mask)
            ]
            # print('=' * 20, file=f, flush=True)
            return sentence_ids

        test_examples = train_examples
        test_features = train_features

        test_tensors = data_reader.data_to_tensors(test_features)
        test_data = TensorDataset(*test_tensors)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.predict_batch_size)

        logger.info("***** Running labeling *****")
        logger.info("  Num orig examples = %d", len(test_examples))
        logger.info("  Num split examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start labeling.")
        for batch in tqdm(test_dataloader, desc="Testing"):
            if n_gpu == 1:
                batch = batch_to_device(batch, device)
            inputs = data_reader.generate_inputs(batch,
                                                 test_features,
                                                 model_state=ModelState.Test)
            with torch.no_grad():
                output_dict = model(**inputs)
                batch_choice_logits = output_dict['yesno_logits']
                batch_beam_result = batch_topk(output_dict['sentence_sim'],
                                               output_dict['sentence_mask'])
            example_indices = batch[-1]
            for i, example_index in enumerate(example_indices):
                choice_logits = batch_choice_logits[i].detach().cpu().tolist()
                # max_weight_index = batch_max_weight_indexes[i].detach().cpu().tolist()
                # max_weight = batch_max_weight[i].detach().cpu().tolist()
                evidence = batch_beam_result[i]

                test_feature = test_features[example_index.item()]
                unique_id = int(test_feature.unique_id)

                all_results.append(
                    RawOutput(unique_id=unique_id,
                              model_output={
                                  "choice_logits": choice_logits,
                                  "evidence": evidence
                              }))
                # all_results.append(
                #     WeightResultChoice(unique_id=unique_id, choice_logits=choice_logits, max_weight_index=max_weight_index,
                #                        max_weight=max_weight))

        output_prediction_file = os.path.join(args.predict_dir,
                                              'sentence_id_file.json')
        data_reader.predict_sentence_ids(
            test_examples,
            test_features,
            all_results,
            output_prediction_file,
            weight_threshold=args.weight_threshold,
            only_correct=args.only_correct,
            label_threshold=args.label_threshold)

        test_examples = eval_examples
        test_features = eval_features

        test_tensors = data_reader.data_to_tensors(test_features)
        test_data = TensorDataset(*test_tensors)
        test_sampler = SequentialSampler(test_data)
        test_dataloader = DataLoader(test_data,
                                     sampler=test_sampler,
                                     batch_size=args.predict_batch_size)

        logger.info("***** Running labeling *****")
        logger.info("  Num orig examples = %d", len(test_examples))
        logger.info("  Num split examples = %d", len(test_features))
        logger.info("  Batch size = %d", args.predict_batch_size)

        model.eval()
        all_results = []
        logger.info("Start labeling.")
        for batch in tqdm(test_dataloader, desc="Testing"):
            if n_gpu == 1:
                batch = batch_to_device(batch, device)
            inputs = data_reader.generate_inputs(batch,
                                                 test_features,
                                                 model_state=ModelState.Test)
            with torch.no_grad():
                output_dict = model(**inputs)
                batch_choice_logits = output_dict['yesno_logits']
                batch_beam_result = batch_topk(output_dict['sentence_sim'],
                                               output_dict['sentence_mask'])
            example_indices = batch[-1]
            for i, example_index in enumerate(example_indices):
                choice_logits = batch_choice_logits[i].detach().cpu().tolist()
                # max_weight_index = batch_max_weight_indexes[i].detach().cpu().tolist()
                # max_weight = batch_max_weight[i].detach().cpu().tolist()
                evidence = batch_beam_result[i]

                test_feature = test_features[example_index.item()]
                unique_id = int(test_feature.unique_id)

                all_results.append(
                    RawOutput(unique_id=unique_id,
                              model_output={
                                  "choice_logits": choice_logits,
                                  "evidence": evidence
                              }))
                # all_results.append(
                #     WeightResultChoice(unique_id=unique_id, choice_logits=choice_logits, max_weight_index=max_weight_index,
                #                        max_weight=max_weight))

        output_prediction_file = os.path.join(args.predict_dir,
                                              'dev_sentence_id_file.json')
        data_reader.predict_sentence_ids(
            test_examples,
            test_features,
            all_results,
            output_prediction_file,
            weight_threshold=args.weight_threshold,
            only_correct=args.only_correct,
            label_threshold=args.label_threshold)
Example #16
0
class Distiller:
    def __init__(self, params: dict, dataset: LmSeqsDataset,
                 token_probs: torch.tensor, student: nn.Module,
                 teacher: nn.Module):
        """
        Set up batching, loss functions, optimizer, scheduler, mixed precision,
        distributed wrapping and Tensorboard logging for a distillation run.

        Input:
        ------
            params: training configuration. Although annotated as `dict`, it is
                read via attribute access (`params.dump_path`, `params.n_gpu`,
                ...), so a namespace-like object is required.
            dataset: `LmSeqsDataset` - must expose `lengths` (read when
                `params.group_by_size` is set) and `batch_sequences` (used as the
                DataLoader `collate_fn`).
            token_probs: `torch.tensor` - per-token sampling weights, indexed by
                token id in `prepare_batch_mlm`; only stored when `params.mlm`.
            student: `nn.Module` - model being trained; `student.config` and
                `student.config.vocab_size` are read.
            teacher: `nn.Module` - reference model; cast with `.half()` when
                `params.fp16` is set.

        Raises:
        -------
            ImportError: if `params.fp16` is set and apex cannot be imported.
            AssertionError: if the temperature, MLM mask proportion, masking
                probabilities or gradient accumulation steps are invalid.
        """
        logger.info("Initializing Distiller")
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        # One GPU (or none): plain random order. Several GPUs: shard the
        # dataset across processes.
        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        if params.group_by_size:
            # Derive a group id per sequence from its length and batch within
            # those groups.
            groups = create_lengths_groups(lengths=dataset.lengths,
                                           k=params.max_model_input_size)
            sampler = GroupedBatchSampler(sampler=sampler,
                                          group_ids=groups,
                                          batch_size=params.batch_size)
        else:
            sampler = BatchSampler(sampler=sampler,
                                   batch_size=params.batch_size,
                                   drop_last=False)

        self.dataloader = DataLoader(dataset=dataset,
                                     batch_sampler=sampler,
                                     collate_fn=dataset.batch_sequences)

        self.temperature = params.temperature
        assert self.temperature > 0.0

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_clm = params.alpha_clm
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        self.mlm = params.mlm
        if self.mlm:
            logger.info("Using MLM loss for LM step.")
            self.mlm_mask_prop = params.mlm_mask_prop
            assert 0.0 <= self.mlm_mask_prop <= 1.0
            # The three probabilities must sum to 1. Compare with a tolerance:
            # exact float equality rejects valid splits such as 0.7/0.2/0.1,
            # whose sum evaluates to 0.9999999999999999.
            assert math.isclose(
                params.word_mask + params.word_keep + params.word_rand, 1.0)
            # Index 0/1/2: probability that a selected token is replaced by the
            # mask token / kept unchanged / replaced by a random token
            # (see prepare_batch_mlm).
            self.pred_probs = torch.FloatTensor(
                [params.word_mask, params.word_keep, params.word_rand])
            self.pred_probs = self.pred_probs.to(
                f"cuda:{params.local_rank}"
            ) if params.n_gpu > 0 else self.pred_probs
            self.token_probs = token_probs.to(
                f"cuda:{params.local_rank}"
            ) if params.n_gpu > 0 else token_probs
            if self.fp16:
                self.pred_probs = self.pred_probs.half()
                self.token_probs = self.token_probs.half()
        else:
            logger.info("Using CLM loss for LM step.")

        # Training-progress counters and last-seen loss values.
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_clm = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        self.last_log = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        # -100 matches the label value written for "nothing to predict" by
        # prepare_batch_mlm / prepare_batch_clm.
        self.lm_loss_fct = nn.CrossEntropyLoss(ignore_index=-100)
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        logger.info("--- Initializing model optimizer")
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = len(self.dataloader)
        # Total number of optimizer updates over the whole run (batches per
        # epoch / accumulation steps * epochs, plus one).
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params.gradient_accumulation_steps *
                params.n_epoch) + 1)

        # Biases and LayerNorm weights are excluded from weight decay.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay":
                params.weight_decay,
            },
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay":
                0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum([
                p.numel() for p in self.student.parameters() if p.requires_grad
            ]))
        logger.info("------ Number of parameters (student): %i" %
                    sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            logger.info(
                f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            # Only the student is wrapped for distributed training.
            if self.fp16:
                from apex.parallel import DistributedDataParallel

                logger.info(
                    "Using apex.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel

                logger.info(
                    "Using nn.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True,
                )

        # Only the master process writes Tensorboard logs.
        self.is_master = params.is_master
        if self.is_master:
            logger.info("--- Initializing Tensorboard")
            self.tensorboard = SummaryWriter(
                log_dir=os.path.join(self.dump_path, "log", "train"))
            self.tensorboard.add_text(tag="config/training",
                                      text_string=str(self.params),
                                      global_step=0)
            self.tensorboard.add_text(tag="config/student",
                                      text_string=str(self.student_config),
                                      global_step=0)

    def prepare_batch_mlm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked label for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1),
                                 dtype=torch.long,
                                 device=lengths.device) < lengths[:, None]

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(),
                                    n_tgt,
                                    replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids["pad_token"]] = 0

        # mask a number of words == 0 [8] (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(self.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(
            self.params.special_tok_ids["mask_token"])
        probs = torch.multinomial(self.pred_probs,
                                  len(_token_ids_real),
                                  replacement=True)
        _token_ids = (_token_ids_mask * (probs == 0).long() + _token_ids_real *
                      (probs == 1).long() + _token_ids_rand *
                      (probs == 2).long())
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[
            ~pred_mask] = -100  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, mlm_labels

    def prepare_batch_clm(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the labels for CLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequence. It is padded.
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            clm_labels: `torch.tensor(bs, seq_length)` - The causal language modeling labels. There is a -100 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = torch.arange(token_ids.size(1),
                                 dtype=torch.long,
                                 device=lengths.device) < lengths[:, None]
        clm_labels = token_ids.new(token_ids.size()).copy_(token_ids)
        clm_labels[
            ~attn_mask] = -100  # previously `clm_labels[1-attn_mask] = -1`, cf pytorch 1.2.0 compatibility

        # sanity checks
        assert 0 <= token_ids.min() <= token_ids.max() < self.vocab_size

        return token_ids, attn_mask, clm_labels

    def round_batch(self, x: torch.tensor, lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs, seq_length)` - The lengths of each of the sequence in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs, new_seq_length)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # number of sentences == 0 [8]
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # sequence length == 0 [8]
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            if self.mlm:
                pad_id = self.params.special_tok_ids["pad_token"]
            else:
                pad_id = self.params.special_tok_ids["unk_token"]
            padding_tensor = torch.zeros(bs2,
                                         pad,
                                         dtype=torch.long,
                                         device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.

        Runs `self.params.n_epoch` epochs over `self.dataloader`. Each batch is
        built into MLM or CLM inputs (depending on `self.mlm`) and passed to
        `self.step`. `self.end_epoch()` is called after every epoch. At the end,
        the master process saves a final `pytorch_model.bin` checkpoint.
        """
        if self.is_master:
            logger.info("Starting training")
        self.last_log = time.time()
        # Only the student is put in training mode; the teacher stays in eval.
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(
                    f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
            if self.multi_gpu:
                # Synchronise all processes before starting the epoch.
                torch.distributed.barrier()

            # Only rank -1 (non-distributed) or rank 0 draws a progress bar.
            iter_bar = tqdm(self.dataloader,
                            desc="-Iter",
                            disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                if self.params.n_gpu > 0:
                    batch = tuple(
                        t.to(f"cuda:{self.params.local_rank}") for t in batch)

                if self.mlm:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_mlm(
                        batch=batch)
                else:
                    token_ids, attn_mask, lm_labels = self.prepare_batch_clm(
                        batch=batch)
                self.step(input_ids=token_ids,
                          attention_mask=attn_mask,
                          lm_labels=lm_labels)

                # Iterating over the tqdm wrapper already advances the bar, so
                # no manual iter_bar.update() here (it counted every batch twice).
                # NOTE(review): the average below assumes self.step() increments
                # self.n_iter (not visible in this excerpt).
                iter_bar.set_postfix({
                    "Last_loss":
                    f"{self.last_loss:.2f}",
                    "Avg_cum_loss":
                    f"{self.total_loss_epoch/self.n_iter:.2f}"
                })
            iter_bar.close()

            if self.is_master:
                logger.info(
                    f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
            self.end_epoch()

        if self.is_master:
            logger.info("Save very last checkpoint as `pytorch_model.bin`.")
            self.save_checkpoint(checkpoint_name="pytorch_model.bin")
            logger.info("Training is finished")

    def step(self, input_ids: torch.tensor, attention_mask: torch.tensor,
             lm_labels: torch.tensor):
        """
        Run one distillation iteration: forward the student (with gradients) and the
        teacher (under `torch.no_grad`), combine the weighted losses, and hand the total
        to `optimize` (backward pass and, depending on gradient accumulation, an update).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        lm_labels: `torch.tensor(bs, seq_length)` - The language modeling labels (mlm labels for MLM and clm labels for CLM).

        Loss terms (CE is always computed; the others only when their `alpha_*` > 0):
        - ce: `ce_loss_fct` between student log-probs and teacher probs at temperature T, times T**2
        - mlm / clm: `lm_loss_fct` on the labels (clm compares logits at t with labels at t+1)
        - mse: `mse_loss_fct` on the selected logits, divided by the number of selected positions
        - cos: `cosine_loss_fct` between last hidden states at attended positions, target 1
        """
        # The models receive the attention mask in MLM mode and no mask in CLM mode.
        model_mask = attention_mask if self.mlm else None
        student_outputs = self.student(input_ids=input_ids,
                                       attention_mask=model_mask)
        with torch.no_grad():
            teacher_outputs = self.teacher(input_ids=input_ids,
                                           attention_mask=model_mask)

        s_logits = student_outputs["logits"]  # (bs, seq_length, voc_size)
        s_hidden_states = student_outputs["hidden_states"]
        t_logits = teacher_outputs["logits"]  # (bs, seq_length, voc_size)
        t_hidden_states = teacher_outputs["hidden_states"]
        assert s_logits.size() == t_logits.size()

        def gather_rows(values, keep, width):
            # masked_select returns a flat tensor; reshape it into rows of `width`.
            return torch.masked_select(values, keep).view(-1, width)

        # https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        # https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        voc_size = s_logits.size(-1)
        if self.params.restrict_ce_to_mask:
            position_mask = lm_labels > -1  # positions with a non-negative label
        else:
            position_mask = attention_mask
        logit_mask = position_mask.unsqueeze(-1).expand_as(s_logits)
        s_logits_slct = gather_rows(s_logits, logit_mask,
                                    voc_size)  # (n_selected, voc_size)
        t_logits_slct = gather_rows(t_logits, logit_mask,
                                    voc_size)  # (n_selected, voc_size)
        assert t_logits_slct.size() == s_logits_slct.size()

        temperature = self.temperature
        loss_ce = self.ce_loss_fct(
            nn.functional.log_softmax(s_logits_slct / temperature, dim=-1),
            nn.functional.softmax(t_logits_slct / temperature, dim=-1),
        ) * temperature**2
        loss = self.alpha_ce * loss_ce

        # Optional loss tensors, keyed by the attribute that will store their value.
        extra_losses = {}
        if self.alpha_mlm > 0.0:
            loss_mlm = self.lm_loss_fct(s_logits.view(-1, voc_size),
                                        lm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
            extra_losses["last_loss_mlm"] = loss_mlm
        if self.alpha_clm > 0.0:
            # Logits at position t are scored against the label at t + 1.
            shift_logits = s_logits[..., :-1, :].contiguous()
            shift_labels = lm_labels[..., 1:].contiguous()
            loss_clm = self.lm_loss_fct(
                shift_logits.view(-1, shift_logits.size(-1)),
                shift_labels.view(-1))
            loss += self.alpha_clm * loss_clm
            extra_losses["last_loss_clm"] = loss_clm
        if self.alpha_mse > 0.0:
            # Dividing by the row count reproduces a "batchmean" reduction.
            loss_mse = self.mse_loss_fct(
                s_logits_slct, t_logits_slct) / s_logits_slct.size(0)
            loss += self.alpha_mse * loss_mse
            extra_losses["last_loss_mse"] = loss_mse
        if self.alpha_cos > 0.0:
            s_last = s_hidden_states[-1]  # (bs, seq_length, dim)
            t_last = t_hidden_states[-1]  # (bs, seq_length, dim)
            hidden_mask = attention_mask.unsqueeze(-1).expand_as(s_last)
            assert s_last.size() == t_last.size()
            dim = s_last.size(-1)

            s_hidden_slct = gather_rows(s_last, hidden_mask, dim)  # (n_attended, dim)
            t_hidden_slct = gather_rows(t_last, hidden_mask, dim)  # (n_attended, dim)

            # Target of 1 for every selected position.
            target = s_hidden_slct.new(s_hidden_slct.size(0)).fill_(1)
            loss_cos = self.cosine_loss_fct(s_hidden_slct, t_hidden_slct,
                                            target)
            loss += self.alpha_cos * loss_cos
            extra_losses["last_loss_cos"] = loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        for attr_name, loss_value in extra_losses.items():
            setattr(self, attr_name, loss_value.item())

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also advances the iteration counters through `self.iter()`, which triggers the periodic
        tensorboard logging and checkpointing.

        Input:
        ------
        loss: `torch.tensor` - The total loss computed by `step`.

        Exits the process with status 1 if the loss contains a NaN.
        """
        # Check for NaN. Use `sys.exit` instead of the site-provided `exit()` helper,
        # which is meant for the interactive shell and is missing under `python -S`.
        # Exit with a non-zero status so launchers see the run as failed.
        if torch.isnan(loss).any():
            import sys

            logger.error("NaN detected")
            sys.exit(1)

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        # Only update the parameters once every `gradient_accumulation_steps` calls.
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                # `amp` was imported above: this branch is only reached when fp16 is set.
                nn.utils.clip_grad_norm_(amp.master_params(self.optimizer),
                                         self.params.max_grad_norm)
            else:
                nn.utils.clip_grad_norm_(self.student.parameters(),
                                         self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Advance the per-epoch and global iteration counters by one.

        When the global count is a multiple of `params.log_interval`, logs to
        tensorboard and resets `last_log`; when it is a multiple of
        `params.checkpoint_interval`, saves a checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1
        total = self.n_total_iter

        if not total % self.params.log_interval:
            self.log_tensorboard()
            self.last_log = time.time()
        if not total % self.params.checkpoint_interval:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.

        At global step `self.n_total_iter`, writes:
        - mean/std of every student parameter, and of its gradient when one exists;
        - the cumulative average loss of the epoch, the last total loss, the last CE loss
          and every optional loss component whose `alpha_*` weight is positive;
        - the current learning rate;
        - host memory in use (MB) and the seconds elapsed since `self.last_log`.
        """
        if not self.is_master:
            return

        global_step = self.n_total_iter

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag="parameter_mean/" + param_name,
                                        scalar_value=param.data.mean(),
                                        global_step=global_step)
            self.tensorboard.add_scalar(tag="parameter_std/" + param_name,
                                        scalar_value=param.data.std(),
                                        global_step=global_step)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name,
                                        scalar_value=param.grad.data.mean(),
                                        global_step=global_step)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name,
                                        scalar_value=param.grad.data.std(),
                                        global_step=global_step)

        self.tensorboard.add_scalar(
            tag="losses/cum_avg_loss_epoch",
            scalar_value=self.total_loss_epoch / self.n_iter,
            global_step=global_step,
        )
        self.tensorboard.add_scalar(tag="losses/loss",
                                    scalar_value=self.last_loss,
                                    global_step=global_step)
        self.tensorboard.add_scalar(tag="losses/loss_ce",
                                    scalar_value=self.last_loss_ce,
                                    global_step=global_step)
        # Optional components only have a recorded value when their weight is positive.
        for loss_name in ("mlm", "clm", "mse", "cos"):
            if getattr(self, f"alpha_{loss_name}") > 0.0:
                self.tensorboard.add_scalar(
                    tag=f"losses/loss_{loss_name}",
                    scalar_value=getattr(self, f"last_loss_{loss_name}"),
                    global_step=global_step)

        # `get_lr()` is meant to be called by the scheduler from inside `step()`.
        # Calling it from outside warns in recent PyTorch and can differ from the LR
        # actually in use for chainable schedulers. Prefer `get_last_lr()` and fall back
        # to `get_lr()` on PyTorch versions that lack it.
        if hasattr(self.scheduler, "get_last_lr"):
            current_lr = self.scheduler.get_last_lr()[0]
        else:
            current_lr = self.scheduler.get_lr()[0]
        self.tensorboard.add_scalar(tag="learning_rate/lr",
                                    scalar_value=current_lr,
                                    global_step=global_step)

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
            global_step=global_step,
        )
        self.tensorboard.add_scalar(tag="global/speed",
                                    scalar_value=time.time() - self.last_log,
                                    global_step=global_step)

    def end_epoch(self):
        """
        Close the current epoch (a full pass over the dataset).

        Logs how many sequences were processed. On the master process, also saves a
        `model_epoch_{epoch}.pth` checkpoint and logs the epoch's average loss to
        tensorboard. Finally advances the epoch counter and resets the per-epoch
        accumulators.
        """
        logger.info(
            f"{self.n_sequences_epoch} sequences have been trained during this epoch."
        )

        if self.is_master:
            epoch_idx = self.epoch
            self.save_checkpoint(
                checkpoint_name=f"model_epoch_{epoch_idx}.pth")
            mean_epoch_loss = self.total_loss_epoch / self.n_iter
            self.tensorboard.add_scalar(tag="epoch/loss",
                                        scalar_value=mean_epoch_loss,
                                        global_step=epoch_idx)

        # Move on to the next epoch with fresh per-epoch statistics.
        self.epoch += 1
        self.n_sequences_epoch, self.n_iter, self.total_loss_epoch = 0, 0, 0

    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Save the student's config and weights to `self.dump_path`.

        Only the master process writes anything; other processes return immediately.
        The config goes through `config.save_pretrained(self.dump_path)`, and the state
        dict is written to `os.path.join(self.dump_path, checkpoint_name)`.
        """
        if self.is_master:
            # When the student is wrapped (it exposes `.module`), save the inner model.
            model = getattr(self.student, "module", self.student)
            model.config.save_pretrained(self.dump_path)
            weights = model.state_dict()
            torch.save(weights, os.path.join(self.dump_path, checkpoint_name))
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--model_type",
                        default=None,
                        type=str,
                        required=True,
                        help="Model type selected in the list: " +
                        ", ".join(MODEL_CLASSES.keys()))
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--meta_path",
        default=None,
        type=str,
        required=False,
        help="Path to pre-trained model or shortcut name selected in the list: "
        + ", ".join(ALL_MODELS))
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--config_name",
        default="",
        type=str,
        help="Pretrained config name or path if not the same as model_name")
    parser.add_argument(
        "--tokenizer_name",
        default="",
        type=str,
        help="Pretrained tokenizer name or path if not the same as model_name")
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after tokenization. Sequences longer "
        "than this will be truncated, sequences shorter will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--predict_eval",
                        action='store_true',
                        help="Whether to predict eval set.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--evaluate_during_training",
        action='store_true',
        help="Rul evaluation during training at each logging step.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")

    parser.add_argument("--per_gpu_train_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for training.")
    parser.add_argument("--per_gpu_eval_batch_size",
                        default=8,
                        type=int,
                        help="Batch size per GPU/CPU for evaluation.")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--weight_decay",
                        default=0.0,
                        type=float,
                        help="Weight deay if we apply some.")
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument("--max_grad_norm",
                        default=1.0,
                        type=float,
                        help="Max gradient norm.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--max_steps",
        default=-1,
        type=int,
        help=
        "If > 0: set total number of training steps to perform. Override num_train_epochs."
    )
    parser.add_argument("--eval_steps", default=-1, type=int, help="")
    parser.add_argument("--lstm_hidden_size", default=300, type=int, help="")
    parser.add_argument("--lstm_layers", default=2, type=int, help="")
    parser.add_argument("--lstm_dropout", default=0.5, type=float, help="")

    parser.add_argument("--train_steps", default=-1, type=int, help="")
    parser.add_argument("--report_steps", default=-1, type=int, help="")
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument("--split_num", default=3, type=int, help="text split")
    parser.add_argument('--logging_steps',
                        type=int,
                        default=50,
                        help="Log every X updates steps.")
    parser.add_argument('--save_steps',
                        type=int,
                        default=50,
                        help="Save checkpoint every X updates steps.")
    parser.add_argument(
        "--eval_all_checkpoints",
        action='store_true',
        help=
        "Evaluate all checkpoints starting with the same prefix as model_name ending and ending with step number"
    )
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Avoid using CUDA when available")
    parser.add_argument('--overwrite_output_dir',
                        action='store_true',
                        help="Overwrite the content of the output directory")
    parser.add_argument(
        '--overwrite_cache',
        action='store_true',
        help="Overwrite the cached training and evaluation sets")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")

    parser.add_argument(
        '--fp16',
        action='store_true',
        help=
        "Whether to use 16-bit (mixed) precision (through NVIDIA apex) instead of 32-bit"
    )
    parser.add_argument(
        '--fp16_opt_level',
        type=str,
        default='O1',
        help=
        "For fp16: Apex AMP optimization level selected in ['O0', 'O1', 'O2', and 'O3']."
        "See details at https://nvidia.github.io/apex/amp.html")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="For distributed training: local_rank")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="For distant debugging.")
    parser.add_argument("--freeze",
                        default=0,
                        type=int,
                        required=False,
                        help="freeze bert.")
    parser.add_argument("--not_do_eval_steps",
                        default=0.5,
                        type=float,
                        help="not_do_eval_steps.")
    args = parser.parse_args()

    # Setup CUDA, GPU & distributed training
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        args.n_gpu = torch.cuda.device_count()
    else:  # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        torch.distributed.init_process_group(backend='nccl')
        args.n_gpu = 1
    args.device = device

    # Setup logging
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        args.local_rank, device, args.n_gpu, bool(args.local_rank != -1),
        args.fp16)

    # Set seed
    set_seed(args)

    try:
        os.makedirs(args.output_dir)
    except:
        pass

    tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path,
                                              do_lower_case=args.do_lower_case)

    config = BertConfig.from_pretrained(args.model_name_or_path, num_labels=6)

    # Prepare model
    model = BertForSequenceClassification_last2embedding_cls.from_pretrained(
        args.model_name_or_path, args, config=config)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif args.n_gpu > 1:
        model = torch.nn.DataParallel(model)
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    args.eval_batch_size = args.per_gpu_eval_batch_size * max(1, args.n_gpu)
    if args.do_train:

        # Prepare data loader

        train_examples = read_examples(os.path.join(args.data_dir,
                                                    'train.csv'),
                                       is_training=True)
        train_features = convert_examples_to_features(train_examples,
                                                      tokenizer,
                                                      args.max_seq_length,
                                                      args.split_num, True)
        all_input_ids = torch.tensor(select_field(train_features, 'input_ids'),
                                     dtype=torch.long)
        all_input_mask = torch.tensor(select_field(train_features,
                                                   'input_mask'),
                                      dtype=torch.long)
        all_segment_ids = torch.tensor(select_field(train_features,
                                                    'segment_ids'),
                                       dtype=torch.long)
        all_label = torch.tensor([f.label for f in train_features],
                                 dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size //
                                      args.gradient_accumulation_steps)

        num_train_optimization_steps = args.train_steps

        # Prepare optimizer

        param_optimizer = list(model.named_parameters())

        # hack to remove pooler, which is not used
        # thus it produce None grad that break apex
        param_optimizer = [n for n in param_optimizer]

        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay':
            args.weight_decay
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay':
            0.0
        }]

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=args.train_steps //
                                         args.gradient_accumulation_steps)

        global_step = 0

        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        best_acc = 0
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        bar = tqdm(range(num_train_optimization_steps),
                   total=num_train_optimization_steps)
        train_dataloader = cycle(train_dataloader)

        # 先做一个eval
        for file in ['dev.csv']:
            inference_labels = []
            gold_labels = []
            inference_logits = []
            eval_examples = read_examples(os.path.join(args.data_dir, file),
                                          is_training=True)
            eval_features = convert_examples_to_features(
                eval_examples, tokenizer, args.max_seq_length, args.split_num,
                False)
            all_input_ids = torch.tensor(select_field(eval_features,
                                                      'input_ids'),
                                         dtype=torch.long)
            all_input_mask = torch.tensor(select_field(eval_features,
                                                       'input_mask'),
                                          dtype=torch.long)
            all_segment_ids = torch.tensor(select_field(
                eval_features, 'segment_ids'),
                                           dtype=torch.long)
            all_label = torch.tensor([f.label for f in eval_features],
                                     dtype=torch.long)

            eval_data = TensorDataset(all_input_ids, all_input_mask,
                                      all_segment_ids, all_label)

            logger.info("***** Running evaluation *****")
            logger.info("  Num examples = %d", len(eval_examples))
            logger.info("  Batch size = %d", args.eval_batch_size)

            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data,
                                         sampler=eval_sampler,
                                         batch_size=args.eval_batch_size)

            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    tmp_eval_loss, logits = model(input_ids=input_ids,
                                                  token_type_ids=segment_ids,
                                                  attention_mask=input_mask,
                                                  labels=label_ids)
                    # logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask)

                logits = logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                inference_labels.append(np.argmax(logits, axis=1))
                gold_labels.append(label_ids)
                inference_logits.append(logits)
                eval_loss += tmp_eval_loss.mean().item()
                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1

            gold_labels = np.concatenate(gold_labels, 0)
            inference_logits = np.concatenate(inference_logits, 0)
            model.train()
            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = accuracy(inference_logits, gold_labels)

            result = {
                'eval_loss': eval_loss,
                'eval_F1': eval_accuracy,
                'global_step': global_step
            }

            output_eval_file = os.path.join(args.output_dir,
                                            "eval_results.txt")
            with open(output_eval_file, "a") as writer:
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
                writer.write('*' * 80)
                writer.write('\n')
            if eval_accuracy > best_acc and 'dev' in file:
                print("=" * 80)
                print("Best F1", eval_accuracy)
                print("Saving Model......")
                best_acc = eval_accuracy
                # Save a trained model
                model_to_save = model.module if hasattr(
                    model, 'module') else model  # Only save the model it-self
                output_model_file = os.path.join(args.output_dir,
                                                 "pytorch_model.bin")
                torch.save(model_to_save.state_dict(), output_model_file)
                print("=" * 80)
            else:
                print("=" * 80)

        model.train()

        for step in bar:
            batch = next(train_dataloader)
            batch = tuple(t.to(device) for t in batch)
            input_ids, input_mask, segment_ids, label_ids = batch
            loss, _ = model(input_ids=input_ids,
                            token_type_ids=segment_ids,
                            attention_mask=input_mask,
                            labels=label_ids)
            nb_tr_examples += input_ids.size(0)
            del input_ids, input_mask, segment_ids, label_ids
            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu.
            if args.fp16 and args.loss_scale != 1.0:
                loss = loss * args.loss_scale
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
            tr_loss += loss.item()
            train_loss = round(
                tr_loss * args.gradient_accumulation_steps / (nb_tr_steps + 1),
                4)
            bar.set_description("loss {}".format(train_loss))

            nb_tr_steps += 1

            if args.fp16:
                optimizer.backward(loss)
            else:

                loss.backward()

            if (nb_tr_steps + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    # modify learning rate with special warm up BERT uses
                    # if args.fp16 is False, BertAdam is used that handles this automatically
                    lr_this_step = args.learning_rate * warmup_linear.get_lr(
                        global_step, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1

            if (step + 1) % (args.eval_steps *
                             args.gradient_accumulation_steps) == 0:
                tr_loss = 0
                nb_tr_examples, nb_tr_steps = 0, 0
                logger.info("***** Report result *****")
                logger.info("  %s = %s", 'global_step', str(global_step))
                logger.info("  %s = %s", 'train loss', str(train_loss))

            if args.do_eval and step > num_train_optimization_steps * args.not_do_eval_steps and (
                    step + 1) % (args.eval_steps *
                                 args.gradient_accumulation_steps) == 0:
                for file in ['dev.csv']:
                    inference_labels = []
                    gold_labels = []
                    inference_logits = []
                    eval_examples = read_examples(os.path.join(
                        args.data_dir, file),
                                                  is_training=True)
                    eval_features = convert_examples_to_features(
                        eval_examples, tokenizer, args.max_seq_length,
                        args.split_num, False)
                    all_input_ids = torch.tensor(select_field(
                        eval_features, 'input_ids'),
                                                 dtype=torch.long)
                    all_input_mask = torch.tensor(select_field(
                        eval_features, 'input_mask'),
                                                  dtype=torch.long)
                    all_segment_ids = torch.tensor(select_field(
                        eval_features, 'segment_ids'),
                                                   dtype=torch.long)
                    all_label = torch.tensor([f.label for f in eval_features],
                                             dtype=torch.long)

                    eval_data = TensorDataset(all_input_ids, all_input_mask,
                                              all_segment_ids, all_label)

                    logger.info("***** Running evaluation *****")
                    logger.info("  Num examples = %d", len(eval_examples))
                    logger.info("  Batch size = %d", args.eval_batch_size)

                    # Run prediction for full data
                    eval_sampler = SequentialSampler(eval_data)
                    eval_dataloader = DataLoader(
                        eval_data,
                        sampler=eval_sampler,
                        batch_size=args.eval_batch_size)

                    model.eval()
                    eval_loss, eval_accuracy = 0, 0
                    nb_eval_steps, nb_eval_examples = 0, 0
                    for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                        input_ids = input_ids.to(device)
                        input_mask = input_mask.to(device)
                        segment_ids = segment_ids.to(device)
                        label_ids = label_ids.to(device)

                        with torch.no_grad():
                            tmp_eval_loss, logits = model(
                                input_ids=input_ids,
                                token_type_ids=segment_ids,
                                attention_mask=input_mask,
                                labels=label_ids)
                            # logits = model(input_ids=input_ids, token_type_ids=segment_ids, attention_mask=input_mask)

                        logits = logits.detach().cpu().numpy()
                        label_ids = label_ids.to('cpu').numpy()
                        inference_labels.append(np.argmax(logits, axis=1))
                        gold_labels.append(label_ids)
                        inference_logits.append(logits)
                        eval_loss += tmp_eval_loss.mean().item()
                        nb_eval_examples += input_ids.size(0)
                        nb_eval_steps += 1

                    gold_labels = np.concatenate(gold_labels, 0)
                    inference_logits = np.concatenate(inference_logits, 0)
                    model.train()
                    eval_loss = eval_loss / nb_eval_steps
                    eval_accuracy = accuracy(inference_logits, gold_labels)

                    result = {
                        'eval_loss': eval_loss,
                        'eval_F1': eval_accuracy,
                        'global_step': global_step,
                        'loss': train_loss
                    }

                    output_eval_file = os.path.join(args.output_dir,
                                                    "eval_results.txt")
                    with open(output_eval_file, "a") as writer:
                        for key in sorted(result.keys()):
                            logger.info("  %s = %s", key, str(result[key]))
                            writer.write("%s = %s\n" % (key, str(result[key])))
                        writer.write('*' * 80)
                        writer.write('\n')
                    if eval_accuracy > best_acc and 'dev' in file:
                        print("=" * 80)
                        print("Best F1", eval_accuracy)
                        print("Saving Model......")
                        best_acc = eval_accuracy
                        # Save a trained model
                        model_to_save = model.module if hasattr(
                            model,
                            'module') else model  # Only save the model it-self
                        output_model_file = os.path.join(
                            args.output_dir, "pytorch_model.bin")
                        torch.save(model_to_save.state_dict(),
                                   output_model_file)
                        print("=" * 80)
                    else:
                        print("=" * 80)
    if args.do_test:
        del model
        gc.collect()
        args.do_train = False
        model = BertForSequenceClassification_last2embedding_cls.from_pretrained(
            os.path.join(args.output_dir, "pytorch_model.bin"),
            args,
            config=config)
        if args.fp16:
            model.half()
        model.to(device)
        if args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            model = DDP(model)
        elif args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        for file, flag in [('dev.csv', 'dev'), ('test.csv', 'test')]:
            inference_labels = []
            gold_labels = []
            eval_examples = read_examples(os.path.join(args.data_dir, file),
                                          is_training=False)
            eval_features = convert_examples_to_features(
                eval_examples, tokenizer, args.max_seq_length, args.split_num,
                False)
            all_input_ids = torch.tensor(select_field(eval_features,
                                                      'input_ids'),
                                         dtype=torch.long)
            all_input_mask = torch.tensor(select_field(eval_features,
                                                       'input_mask'),
                                          dtype=torch.long)
            all_segment_ids = torch.tensor(select_field(
                eval_features, 'segment_ids'),
                                           dtype=torch.long)
            all_label = torch.tensor([f.label for f in eval_features],
                                     dtype=torch.long)

            eval_data = TensorDataset(all_input_ids, all_input_mask,
                                      all_segment_ids, all_label)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data,
                                         sampler=eval_sampler,
                                         batch_size=args.eval_batch_size)

            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    logits = model(
                        input_ids=input_ids,
                        token_type_ids=segment_ids,
                        attention_mask=input_mask).detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                inference_labels.append(logits)
                gold_labels.append(label_ids)
            gold_labels = np.concatenate(gold_labels, 0)
            logits = np.concatenate(inference_labels, 0)
            print(flag, accuracy(logits, gold_labels))
            if flag == 'test':
                df = pd.read_csv(os.path.join(args.data_dir, file))
                df['label_0'] = logits[:, 0]
                df['label_1'] = logits[:, 1]
                df['label_2'] = logits[:, 2]
                df['label_3'] = logits[:, 3]
                df['label_4'] = logits[:, 4]
                df['label_5'] = logits[:, 5]
                df[[
                    'id', 'label_0', 'label_1', 'label_2', 'label_3',
                    'label_4', 'label_5'
                ]].to_csv(os.path.join(args.output_dir, "sub.csv"),
                          index=False)
            if flag == 'dev':
                df = pd.read_csv(os.path.join(args.data_dir, file))
                df['label_0'] = logits[:, 0]
                df['label_1'] = logits[:, 1]
                df['label_2'] = logits[:, 2]
                df['label_3'] = logits[:, 3]
                df['label_4'] = logits[:, 4]
                df['label_5'] = logits[:, 5]
                df[[
                    'id', 'label_0', 'label_1', 'label_2', 'label_3',
                    'label_4', 'label_5'
                ]].to_csv(os.path.join(args.output_dir, "sub_dev.csv"),
                          index=False)

    if args.predict_eval:
        del model
        gc.collect()
        args.do_train = False
        model = BertForSequenceClassification_last2embedding_cls.from_pretrained(
            os.path.join(args.output_dir, "pytorch_model.bin"),
            args,
            config=config)
        if args.fp16:
            model.half()
        model.to(device)
        if args.local_rank != -1:
            try:
                from apex.parallel import DistributedDataParallel as DDP
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            model = DDP(model)
        elif args.n_gpu > 1:
            model = torch.nn.DataParallel(model)

        for file, flag in [('dev.csv', 'dev')]:
            inference_labels = []
            gold_labels = []
            eval_examples = read_examples(os.path.join(args.data_dir, file),
                                          is_training=False)
            eval_features = convert_examples_to_features(
                eval_examples, tokenizer, args.max_seq_length, args.split_num,
                False)
            all_input_ids = torch.tensor(select_field(eval_features,
                                                      'input_ids'),
                                         dtype=torch.long)
            all_input_mask = torch.tensor(select_field(eval_features,
                                                       'input_mask'),
                                          dtype=torch.long)
            all_segment_ids = torch.tensor(select_field(
                eval_features, 'segment_ids'),
                                           dtype=torch.long)
            all_label = torch.tensor([f.label for f in eval_features],
                                     dtype=torch.long)

            eval_data = TensorDataset(all_input_ids, all_input_mask,
                                      all_segment_ids, all_label)
            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data,
                                         sampler=eval_sampler,
                                         batch_size=args.eval_batch_size)

            model.eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0
            for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    logits = model(
                        input_ids=input_ids,
                        token_type_ids=segment_ids,
                        attention_mask=input_mask).detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                inference_labels.append(logits)
                gold_labels.append(label_ids)
            gold_labels = np.concatenate(gold_labels, 0)
            logits = np.concatenate(inference_labels, 0)
            print(flag, accuracy(logits, gold_labels))
            if flag == 'dev':
                df = pd.read_csv(os.path.join(args.data_dir, file))
                df['label_0'] = logits[:, 0]
                df['label_1'] = logits[:, 1]
                df['label_2'] = logits[:, 2]
                df['label_3'] = logits[:, 3]
                df['label_4'] = logits[:, 4]
                df['label_5'] = logits[:, 5]
                df[[
                    'id', 'label_0', 'label_1', 'label_2', 'label_3',
                    'label_4', 'label_5'
                ]].to_csv(os.path.join(args.output_dir, "sub_dev.csv"),
                          index=False)
Example #18
0
def main(args=None):
    """Fine-tune ``BertForQuestionAnswering`` for span extraction and/or predict.

    Steps, in order:

    1. Resolve the device (single process, or ``torch.distributed`` with the
       NCCL backend when ``local_rank != -1``) and seed ``random``, ``numpy``
       and ``torch``.
    2. Validate the ``do_train`` / ``do_predict`` flags and their input files,
       and refuse to write into a non-empty ``output_dir``.
    3. If ``do_train``: build the optimizer, build (or load from a pickle cache
       next to ``train_file``) the training features, and train for
       ``num_train_epochs`` epochs.
    4. Always save the current weights to ``output_dir/pytorch_model.bin`` and
       reload them into a fresh, unwrapped model.
    5. If ``do_predict`` (single process, or rank 0 only): run the model over
       ``predict_file`` and write ``predictions.json`` and
       ``nbest_predictions.json`` via ``write_predictions``.

    Args:
        args: Parsed argument namespace. When ``None``, it is obtained from
            ``model_utils.run_redundancy_span_get_local_args()``.

    Raises:
        ValueError: if ``gradient_accumulation_steps < 1``, if neither
            ``do_train`` nor ``do_predict`` is set, if a required input file
            is not given, or if ``output_dir`` exists and is not empty.
        ImportError: if distributed or fp16 mode is requested but apex is not
            installed.
    """
    if args is None:
        args = model_utils.run_redundancy_span_get_local_args()

    # ==== Device / distributed setup ====
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs.
        torch.distributed.init_process_group(backend='nccl')

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))
    # From here on train_batch_size is the per-forward-pass batch size;
    # gradients of `gradient_accumulation_steps` such batches are accumulated
    # before each optimizer step.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # ==== Argument validation ====
    if not args.do_train and not args.do_predict:
        raise ValueError(
            "At least one of `do_train` or `do_predict` must be True.")
    if args.do_train and not args.train_file:
        raise ValueError(
            "If `do_train` is True, then `train_file` must be specified.")
    if args.do_predict and not args.predict_file:
        raise ValueError(
            "If `do_predict` is True, then `predict_file` must be specified.")
    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model)

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = read_many_examples(input_file=args.train_file,
                                            is_training=True)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # ==== Model ====
    # One download cache directory per process rank.
    cache_dir = PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(
        args.local_rank)
    model = BertForQuestionAnswering.from_pretrained(args.bert_model,
                                                     cache_dir=cache_dir)
    print(cache_dir)

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            ) from err
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    global_step = 0
    if args.do_train:
        # ==== Optimizer (only needed for training) ====
        # The pooler is unused here and would be left with None gradients,
        # which break apex, so its parameters are excluded.
        param_optimizer = [(n, p) for n, p in model.named_parameters()
                           if 'pooler' not in n]
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.01
        }, {
            'params': [
                p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }]

        # Total number of optimizer steps, split across distributed workers.
        t_total = num_train_steps
        if args.local_rank != -1:
            t_total = t_total // torch.distributed.get_world_size()
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError as err:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                ) from err
            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)
        else:
            # BertAdam receives warmup/t_total and applies the warmup
            # schedule itself.
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=t_total)

        # ==== Training features (pickle-cached next to train_file) ====
        # Only the last path component of bert_model goes into the cache
        # name, so a local model directory (e.g. "/models/bert/") does not
        # inject path separators into the file name.
        model_name = os.path.basename(os.path.normpath(args.bert_model))
        cached_train_features_file = args.train_file + '_{0}_{1}_{2}_{3}'.format(
            model_name, str(args.max_seq_length), str(args.doc_stride),
            str(args.max_query_length))
        try:
            # NOTE(review): pickle.load is only safe on trusted files; this
            # cache is expected to be the one written just below.
            with open(cached_train_features_file, "rb") as reader:
                train_features = pickle.load(reader)
        except Exception:
            # Missing, truncated or stale cache: rebuild the features.
            train_features = convert_examples_to_features(
                examples=train_examples,
                tokenizer=tokenizer,
                max_seq_length=args.max_seq_length,
                doc_stride=args.doc_stride,
                max_query_length=args.max_query_length,
                is_training=True)
            # Only one process writes the cache.
            if args.local_rank == -1 or torch.distributed.get_rank() == 0:
                with open(cached_train_features_file, "wb") as writer:
                    pickle.dump(train_features, writer)

        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_start_positions = torch.tensor(
            [f.start_position for f in train_features], dtype=torch.long)
        all_end_positions = torch.tensor(
            [f.end_position for f in train_features], dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_start_positions,
                                   all_end_positions)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        # ==== Training loop ====
        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                if n_gpu == 1:
                    # With DataParallel (n_gpu > 1) scattering is done by
                    # the wrapper itself.
                    batch = tuple(t.to(device) for t in batch)
                (input_ids, input_mask, segment_ids, start_positions,
                 end_positions) = batch
                loss = model(input_ids=input_ids,
                             token_type_ids=segment_ids,
                             attention_mask=input_mask,
                             start_positions=start_positions,
                             end_positions=end_positions)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    # A fresh graph is built every step, so it does not need
                    # to be retained after backward.
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # FusedAdam is built without a schedule, so apply
                        # BERT's linear warmup manually. BertAdam (non-fp16)
                        # already schedules its lr; overriding it here would
                        # apply the warmup twice.
                        lr_this_step = args.learning_rate * span_utils.warmup_linear(
                            global_step / t_total, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    # ==== Save, then reload into a fresh unwrapped model ====
    # Saves the fine-tuned weights when training ran, otherwise the
    # pretrained ones. NOTE(review): in distributed mode every rank writes
    # this same file.
    model_to_save = model.module if hasattr(
        model, 'module') else model  # Only save the model it-self
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)

    model_state_dict = torch.load(output_model_file)
    model = BertForQuestionAnswering.from_pretrained(
        args.bert_model, state_dict=model_state_dict)
    model.to(device)

    # ==== Prediction (single process, or rank 0 only) ====
    if args.do_predict and (args.local_rank == -1
                            or torch.distributed.get_rank() == 0):
        eval_examples = read_many_examples(input_file=args.predict_file,
                                           is_training=False)
        eval_features = convert_examples_to_features(
            examples=eval_examples,
            tokenizer=tokenizer,
            max_seq_length=args.max_seq_length,
            doc_stride=args.doc_stride,
            max_query_length=args.max_query_length,
            is_training=False)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        # Row index of each feature, so batch outputs can be mapped back to
        # eval_features.
        all_example_index = torch.arange(all_input_ids.size(0),
                                         dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_example_index)
        # Run prediction for full data, in order.
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.predict_batch_size)

        model.eval()
        all_results = []
        for input_ids, input_mask, segment_ids, example_indices in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)

            with torch.no_grad():
                batch_start_logits, batch_end_logits = model(
                    input_ids=input_ids,
                    token_type_ids=segment_ids,
                    attention_mask=input_mask)

            for i, example_index in enumerate(example_indices):
                start_logits = batch_start_logits[i].detach().cpu().tolist()
                end_logits = batch_end_logits[i].detach().cpu().tolist()
                eval_feature = eval_features[example_index.item()]
                all_results.append(
                    RawResult(unique_id=int(eval_feature.unique_id),
                              start_logits=start_logits,
                              end_logits=end_logits))

        output_prediction_file = os.path.join(args.output_dir,
                                              "predictions.json")
        output_nbest_file = os.path.join(args.output_dir,
                                         "nbest_predictions.json")
        write_predictions(eval_examples, eval_features, all_results,
                          args.n_best_size, args.max_answer_length,
                          args.do_lower_case, output_prediction_file,
                          output_nbest_file, args.verbose_logging)
def main():
    """Command-line entry point: fine-tune / prune / evaluate / analyse BERT.

    Which stages run is selected by the ``--do_train``, ``--do_prune``,
    ``--do_eval`` and ``--do_anal`` flags; every option is registered by the
    ``classifier_args`` helpers below. Trained weights are written to
    ``<output_dir>/pytorch_model.bin`` and are then re-loaded from that file
    for the pruning / evaluation / analysis stages, so those stages need the
    file to exist when ``--do_train`` is not given.

    Raises:
        ValueError: invalid ``gradient_accumulation_steps``, no stage
            selected, non-empty output directory when training without
            ``--overwrite``, mutually exclusive retraining options, or an
            unknown task name.
        ImportError: distributed training requested but apex is missing.
        NotImplementedError: pruning requested together with fp16.
    """
    # Arguments
    parser = classifier_args.get_base_parser()
    classifier_args.training_args(parser)
    classifier_args.fp16_args(parser)
    classifier_args.pruning_args(parser)
    classifier_args.eval_args(parser)
    classifier_args.analysis_args(parser)

    args = parser.parse_args()

    # ==== CHECK ARGS AND SET DEFAULTS ====

    if args.dry_run:
        args = prepare_dry_run(args)

    if args.gradient_accumulation_steps < 1:
        raise ValueError(f"Invalid gradient_accumulation_steps parameter: "
                         f"{args.gradient_accumulation_steps}, should be >= 1")

    # From here on train_batch_size is the batch size of a single forward
    # pass; gradient accumulation restores the requested effective size.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    if not (args.do_train or args.do_eval or args.do_prune or args.do_anal):
        raise ValueError(
            "At least one of `do_train`, `do_eval`, `do_prune` or `do_anal` "
            "must be True.")
    out_dir_exists = os.path.exists(args.output_dir) and \
        os.listdir(args.output_dir)
    if out_dir_exists and args.do_train and not args.overwrite:
        raise ValueError(
            f"Output directory ({args.output_dir}) already exists and is not "
            "empty.")

    if args.n_retrain_steps_after_pruning > 0 and args.retrain_pruned_heads:
        raise ValueError(
            "--n_retrain_steps_after_pruning and --retrain_pruned_heads are "
            "mutually exclusive")

    # ==== SETUP DEVICE ====

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of
        # synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(f"device: {device} n_gpu: {n_gpu}, "
                f"distributed training: {bool(args.local_rank != -1)}, "
                f"16-bits training: {args.fp16}")

    # ==== SETUP EXPERIMENT ====

    def set_seeds(seed, n_gpu):
        """Seed Python, NumPy and torch (and every CUDA device if any)."""
        random.seed(seed)
        np.random.seed(seed)
        torch.manual_seed(seed)
        if n_gpu > 0:
            torch.cuda.manual_seed_all(seed)

    set_seeds(args.seed, n_gpu)

    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    if task_name not in data.processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = data.processors[task_name]()
    num_labels = data.num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    # Per-rank download cache, shared by every pretrained-weights load below
    # so that concurrent ranks do not race on the same cache directory.
    cache_dir = (PYTORCH_PRETRAINED_BERT_CACHE /
                 f"distributed_{args.local_rank}")

    # ==== PREPARE DATA ====

    # Train data
    if args.do_train or args.do_prune:
        # Prepare training data
        if args.dry_run:
            train_examples = processor.get_dummy_train_examples(args.data_dir)
        else:
            train_examples = processor.get_train_examples(args.data_dir)
        train_data = data.prepare_tensor_dataset(
            train_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            verbose=args.verbose,
        )

    # Eval data
    if args.do_eval or (args.do_prune and args.eval_pruned):
        if args.dry_run:
            eval_examples = processor.get_dummy_dev_examples(args.data_dir)
        else:
            eval_examples = processor.get_dev_examples(args.data_dir)
        eval_data = data.prepare_tensor_dataset(
            eval_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            verbose=args.verbose,
        )

    # ==== PREPARE MODEL ====

    def get_model(model_type,
                  toy_classifier=False,
                  dry_run=False,
                  n_heads=1,
                  state_dict=None,
                  cache_dir=None):
        """Build a BertForSequenceClassification with ``num_labels`` outputs.

        * ``dry_run``: randomly initialised model from a dummy config.
        * otherwise: pretrained ``model_type`` weights, overridden by
          ``state_dict`` when one is given (and ``toy_classifier`` is off).
        * ``toy_classifier``: replace the model by a 1-layer, ``n_heads``-head
          BERT that copies the embeddings of the model built above; the
          ``state_dict`` (if any) is then loaded into the toy model.
        """
        if dry_run:
            model = BertForSequenceClassification(BertConfig.dummy_config(
                len(tokenizer.vocab)),
                                                  num_labels=num_labels)
        else:
            model = BertForSequenceClassification.from_pretrained(
                model_type,
                cache_dir=cache_dir,
                num_labels=num_labels,
                state_dict=None if toy_classifier else state_dict,
            )
        if toy_classifier:
            config = BertConfig(len(tokenizer.vocab),
                                hidden_size=768,
                                num_hidden_layers=1,
                                num_attention_heads=n_heads,
                                intermediate_size=3072,
                                hidden_act="gelu",
                                hidden_dropout_prob=0.1,
                                attention_probs_dropout_prob=0.1,
                                max_position_embeddings=512,
                                type_vocab_size=2,
                                initializer_range=0.02)
            toy_model = BertForSequenceClassification(config,
                                                      num_labels=num_labels)
            toy_model.bert.embeddings.load_state_dict(
                model.bert.embeddings.state_dict())
            if state_dict is not None:
                model_to_load = getattr(toy_model, "module", toy_model)
                model_to_load.load_state_dict(state_dict)
            model = toy_model

        return model

    def apply_head_mask(bare_model):
        """Mask (or physically prune) the heads listed in --attention_mask_heads.

        ``bare_model`` must be the unwrapped model: DataParallel / DDP
        wrappers do not expose ``.bert``. The head count used to expand
        reversed descriptors is read from the model's own config instead of
        assuming 12 heads (toy classifiers and large models differ).
        """
        heads_to_mask = pruning.parse_head_pruning_descriptors(
            args.attention_mask_heads,
            reverse_descriptors=args.reverse_head_mask,
            n_heads=bare_model.bert.config.num_attention_heads)
        if args.actually_prune:
            bare_model.bert.prune_heads(heads_to_mask)
        else:
            bare_model.bert.mask_heads(heads_to_mask)

    model = get_model(
        args.bert_model,
        toy_classifier=args.toy_classifier,
        dry_run=args.dry_run,
        n_heads=args.toy_classifier_n_heads,
        cache_dir=cache_dir,
    )
    # Head dropout
    for layer in model.bert.encoder.layer:
        layer.attention.self.dropout.p = args.attn_dropout

    if args.fp16:
        model.half()
    model.to(device)
    # Mask/prune heads BEFORE wrapping: the wrappers hide `.bert`, and DDP
    # must see the final set of parameters when it is constructed.
    apply_head_mask(model)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError as err:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex "
                "to use distributed and fp16 training.") from err

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # ==== PREPARE TRAINING ====

    # Trainable parameters
    if args.do_train or args.do_prune:
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        # Only train the classifier in feature mode
        if args.feature_mode:
            param_optimizer = [(n, p) for n, p in param_optimizer
                               if n.startswith("classifier")]
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.01
        }, {
            'params':
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        }]
    # Prepare optimizer for fine-tuning on task
    if args.do_train:
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        optimizer, lr_schedule = training.prepare_bert_adam(
            optimizer_grouped_parameters,
            args.learning_rate,
            num_train_steps,
            args.warmup_proportion,
            loss_scale=args.loss_scale,
            local_rank=args.local_rank,
            fp16=args.fp16,
        )

    # ==== TRAIN ====
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        # Train
        global_step, tr_loss, nb_tr_steps = training.train(
            train_data,
            model,
            optimizer,
            args.train_batch_size,
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            device=device,
            verbose=True,
            disable_progress_bar=args.no_progress_bars,
            n_gpu=n_gpu,
            global_step=global_step,
            lr_schedule=lr_schedule,
            n_epochs=args.num_train_epochs,
            local_rank=args.local_rank,
            fp16=args.fp16,
        )

    # Save train loss (None when no training step was taken, which also
    # avoids dividing by zero on an empty training set).
    result = {
        "global_step": global_step,
        "loss": (tr_loss / nb_tr_steps
                 if args.do_train and nb_tr_steps > 0 else None),
    }

    # Save a trained model
    # Only save the model it-self
    model_to_save = getattr(model, "module", model)
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    if args.do_train:
        torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned (the re-loaded model is
    # not wrapped in DataParallel/DDP, so `model.bert` is accessible below).
    model_state_dict = torch.load(output_model_file)
    model = get_model(
        args.bert_model,
        toy_classifier=args.toy_classifier,
        dry_run=args.dry_run,
        n_heads=args.toy_classifier_n_heads,
        state_dict=model_state_dict,
        cache_dir=cache_dir,
    )
    model.to(device)

    is_main = args.local_rank == -1 or torch.distributed.get_rank() == 0

    # Re-apply the user-specified head mask to the re-loaded model
    apply_head_mask(model)

    # ==== PRUNE ====
    if args.do_prune and is_main:
        if args.fp16:
            raise NotImplementedError("FP16 is not yet supported for pruning")

        # Determine the number of heads to prune
        prune_sequence = pruning.determine_pruning_sequence(
            args.prune_number,
            args.prune_percent,
            model.bert.config.num_hidden_layers,
            model.bert.config.num_attention_heads,
            args.at_least_x_heads_per_layer,
        )
        # Prepare optimizer for tuning after pruning
        if args.n_retrain_steps_after_pruning > 0:
            retrain_optimizer = SGD(model.parameters(),
                                    lr=args.retrain_learning_rate)
        elif args.retrain_pruned_heads:
            if args.n_retrain_steps_pruned_heads > 0:
                num_retrain_steps = args.n_retrain_steps_pruned_heads
            else:
                num_retrain_steps = int(
                    len(train_examples) / args.train_batch_size /
                    args.gradient_accumulation_steps) * args.num_train_epochs

        to_prune = {}
        for step, n_to_prune in enumerate(prune_sequence):

            # Step 0 always computes importance, so head_importance is set
            # before its first use below.
            if step == 0 or args.exact_pruning:
                # Calculate importance scores for each layer
                head_importance = calculate_head_importance(
                    model,
                    train_data,
                    batch_size=args.train_batch_size,
                    device=device,
                    normalize_scores_by_layer=args.normalize_pruning_by_layer,
                    subset_size=args.compute_head_importance_on_subset,
                    verbose=True,
                    disable_progress_bar=args.no_progress_bars,
                )
                logger.info("Head importance scores")
                for layer_importance in head_importance:
                    layer_scores = layer_importance.cpu().data
                    logger.info("\t".join(f"{x:.5f}" for x in layer_scores))
            # Determine which heads to prune
            to_prune = pruning.what_to_prune(
                head_importance,
                n_to_prune,
                to_prune={} if args.retrain_pruned_heads else to_prune,
                at_least_x_heads_per_layer=args.at_least_x_heads_per_layer)
            # Actually mask the heads
            if args.actually_prune:
                model.bert.prune_heads(to_prune)
            else:
                model.bert.mask_heads(to_prune)
            # Maybe continue training a bit
            if args.n_retrain_steps_after_pruning > 0:
                set_seeds(args.seed + step + 1, n_gpu)
                training.train(
                    train_data,
                    model,
                    retrain_optimizer,
                    args.train_batch_size,
                    n_steps=args.n_retrain_steps_after_pruning,
                    device=device,
                )
            elif args.retrain_pruned_heads:
                set_seeds(args.seed + step + 1, n_gpu)
                # Reload BERT
                base_bert = None
                if args.reinit_from_pretrained:
                    base_bert = BertForSequenceClassification.from_pretrained(  # noqa
                        args.bert_model,
                        cache_dir=cache_dir,
                        num_labels=num_labels).bert
                    base_bert.to(device)
                # Reinit
                model.bert.reset_heads(to_prune, base_bert)
                # Unmask heads
                model.bert.clear_heads_mask()
                if args.only_retrain_val_out:
                    self_att_params = [
                        p for layer in model.bert.encoder.layer
                        for p in layer.attention.self.value.parameters()
                    ]
                else:
                    self_att_params = [
                        p for layer in model.bert.encoder.layer
                        for p in layer.attention.self.parameters()
                    ]
                head_grouped_parameters = {
                    'params':
                    self_att_params + [
                        p for layer in model.bert.encoder.layer
                        for p in layer.attention.output.dense.parameters()
                    ],
                    'weight_decay': 0.01
                }
                retrain_optimizer, lr_schedule = training.prepare_bert_adam(
                    [head_grouped_parameters],
                    args.learning_rate,
                    num_retrain_steps,
                    args.warmup_proportion,
                    loss_scale=args.loss_scale,
                    local_rank=args.local_rank,
                    fp16=args.fp16,
                )
                training.train(
                    train_data,
                    model,
                    retrain_optimizer,
                    args.train_batch_size,
                    gradient_accumulation_steps=args.gradient_accumulation_steps,  # noqa
                    device=device,
                    verbose=True,
                    disable_progress_bar=args.no_progress_bars,
                    n_gpu=n_gpu,
                    global_step=0,
                    lr_schedule=lr_schedule,
                    n_epochs=args.num_train_epochs,
                    local_rank=args.local_rank,
                    fp16=args.fp16,
                    mask_heads_grad=to_prune,
                    n_steps=num_retrain_steps,
                    eval_mode=args.no_dropout_in_retraining,
                )

            # Evaluate
            if args.eval_pruned:
                # Print the pruning descriptor
                logger.info("Evaluating following pruning strategy")
                logger.info(pruning.to_pruning_descriptor(to_prune))
                # Eval accuracy
                metric = processor.scorer.name
                accuracy = evaluate(
                    eval_data,
                    model,
                    args.eval_batch_size,
                    save_attention_probs=args.save_attention_probs,
                    print_head_entropy=True,
                    device=device,
                    verbose=True,
                    disable_progress_bar=args.no_progress_bars,
                    scorer=processor.scorer,
                )[metric]
                logger.info("***** Pruning eval results *****")
                tot_pruned = sum(len(heads) for heads in to_prune.values())
                logger.info(f"{tot_pruned}\t{accuracy}")

    # ==== EVALUATE ====
    if args.do_eval and is_main:
        # `evaluate` receives `result` and is expected to add its metrics to it
        # (they are written out below) — NOTE(review): confirm in evaluate().
        evaluate(
            eval_data,
            model,
            args.eval_batch_size,
            save_attention_probs=args.save_attention_probs,
            print_head_entropy=False,
            device=device,
            result=result,
            disable_progress_bar=args.no_progress_bars,
            scorer=processor.scorer,
        )
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w", encoding="utf-8") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    # ==== ANALYSIS ====
    if args.do_anal:
        if not data.is_nli_task(processor):
            logger.warning(
                f"You are running analysis on the NLI diagnostic set but the "
                f"task ({args.task_name}) is not NLI")
        anal_processor = data.DiagnosticProcessor()
        if args.dry_run:
            anal_examples = anal_processor.get_dummy_dev_examples(
                args.anal_data_dir)
        else:
            anal_examples = anal_processor.get_dev_examples(args.anal_data_dir)
        anal_data = data.prepare_tensor_dataset(
            anal_examples,
            label_list,
            args.max_seq_length,
            tokenizer,
            verbose=args.verbose,
        )
        predictions = predict(
            anal_data,
            model,
            args.eval_batch_size,
            verbose=True,
            disable_progress_bar=args.no_progress_bars,
            device=device,
        )
        report = analyze_nli(anal_examples, predictions, label_list)
        # Print report
        for feature, values in report.items():
            print("=" * 80)
            print(f"Scores breakdown for feature: {feature}")
            for value, accuracy in values.items():
                print(f"{value}\t{accuracy:.5f}")
Example #20
0
class Distiller:
    def __init__(self, params: dict, dataset: CaptionTSVDataset,
                 student: nn.Module, teacher: nn.Module, val_dataset,
                 tokenizer):
        """Prepare data loading, losses, optimizer, LR schedule and logging.

        Args:
            params: run configuration. It is read by attribute access
                (``params.output_dir``, ``params.fp16``, ...) despite the
                ``dict`` annotation — presumably an ``argparse.Namespace``;
                TODO confirm against the caller.
            dataset: training set. Batches are drawn with a ``RandomSampler``
                when ``params.n_gpu <= 1`` and a ``DistributedSampler``
                otherwise, then grouped by ``BatchSampler``.
            student: model being trained; must expose ``.config``.
            teacher: reference model; its parameters are not given to the
                optimizer.
            val_dataset: stored as ``self.val_dataset``.
            tokenizer: stored as ``self.tokenizer``.

        Raises:
            ValueError: if ``params.temperature`` is not > 0 or
                ``params.gradient_accumulation_steps`` is < 1.
            ImportError: if ``params.fp16`` is set but apex is not installed.
        """
        logger.info("Initializing Distiller")
        self.params = params
        self.dump_path = params.output_dir
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.student_config = student.config
        self.vocab_size = student.config.vocab_size

        # Random order on a single process; sharded across ranks otherwise.
        if params.n_gpu <= 1:
            sampler = RandomSampler(dataset)
        else:
            sampler = DistributedSampler(dataset)

        sampler = BatchSampler(sampler=sampler,
                               batch_size=params.batch_size,
                               drop_last=False)

        self.dataloader = DataLoader(dataset=dataset, batch_sampler=sampler)
        self.val_dataset = val_dataset
        self.tokenizer = tokenizer

        self.eval_log = []

        self.temperature = params.temperature
        # Explicit check (not `assert`, which is stripped under `python -O`);
        # `not x > 0` also rejects NaN, exactly like the former assertion.
        if not self.temperature > 0.0:
            raise ValueError(
                f"temperature must be > 0, got {self.temperature}")

        # Weights of the distillation losses combined in `step`.
        self.alpha_ce = params.alpha_ce
        self.alpha_mse = params.alpha_mse
        self.alpha_cos = params.alpha_cos

        # Running counters / last-seen losses.
        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        if self.alpha_mse > 0.0:
            self.last_loss_mse = 0
        if self.alpha_cos > 0.0:
            self.last_loss_cos = 0
        self.last_log = 0

        # Optional losses are only instantiated when their weight is > 0.
        self.ce_loss_fct = nn.KLDivLoss(reduction="batchmean")
        if self.alpha_mse > 0.0:
            self.mse_loss_fct = nn.MSELoss(reduction="sum")
        if self.alpha_cos > 0.0:
            self.cosine_loss_fct = nn.CosineEmbeddingLoss(reduction="mean")

        logger.info("--- Initializing model optimizer")
        if params.gradient_accumulation_steps < 1:
            raise ValueError(
                "gradient_accumulation_steps must be >= 1, got "
                f"{params.gradient_accumulation_steps}")
        self.num_steps_epoch = len(self.dataloader)
        num_train_optimization_steps = (
            int(self.num_steps_epoch / params.gradient_accumulation_steps *
                params.n_epoch) + 1)

        # No weight decay on biases and LayerNorm weights; frozen parameters
        # (requires_grad=False) are left out of the optimizer entirely.
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if not any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": params.weight_decay,
            },
            {
                "params": [
                    p for n, p in student.named_parameters()
                    if any(nd in n for nd in no_decay) and p.requires_grad
                ],
                "weight_decay": 0.0,
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum(
                p.numel() for p in self.student.parameters()
                if p.requires_grad))
        logger.info("------ Number of parameters (student): %i" %
                    sum(p.numel() for p in self.student.parameters()))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))

        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params.warmup_prop)
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer,
            num_warmup_steps=warmup_steps,
            num_training_steps=num_train_optimization_steps)

        # amp.initialize must run before the student is wrapped for DDP below.
        if self.fp16:
            try:
                from apex import amp
            except ImportError as err:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                ) from err
            logger.info(
                f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel

                logger.info(
                    "Using apex.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel

                logger.info(
                    "Using nn.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank,
                    find_unused_parameters=True,
                )

        # NOTE(review): the TensorBoard writer is created on every rank (the
        # former master-only guard was commented out); confirm this is wanted.
        logger.info("--- Initializing Tensorboard")
        self.tensorboard = SummaryWriter(
            log_dir=os.path.join(self.dump_path, "log", "train"))
        self.tensorboard.add_text(tag="config/training",
                                  text_string=str(self.params),
                                  global_step=0)
        self.tensorboard.add_text(tag="config/student",
                                  text_string=str(self.student_config),
                                  global_step=0)

    def train(self):
        """Run the distillation loop for ``params.n_epoch`` epochs.

        Every batch is an ``(img_key, example)`` pair; ``example`` is moved to
        ``cuda:{params.local_rank}`` when ``params.n_gpu > 0`` (kept on the
        CPU otherwise) and fed to :meth:`step`. After the last epoch the
        final weights are saved as ``pytorch_model.bin``.
        """
        logger.info("Starting training")
        self.last_log = time.time()
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            logger.info(
                f"--- Starting epoch {self.epoch}/{self.params.n_epoch-1}")
            if self.multi_gpu:
                torch.distributed.barrier()

            # A DistributedSampler only reshuffles between epochs when told
            # the epoch number; RandomSampler has no set_epoch and is skipped.
            sampler = getattr(self.dataloader.batch_sampler, "sampler", None)
            if hasattr(sampler, "set_epoch"):
                sampler.set_epoch(self.epoch)

            iter_bar = tqdm(self.dataloader,
                            desc="-Iter",
                            disable=self.params.local_rank not in [-1, 0])
            for batch in iter_bar:
                # CaptionTSVDataset.__getitem__ returns (img_key, example),
                # with example = (input_ids, attention_mask, segment_ids,
                # img_feat, masked_pos, masked_ids). Unpack unconditionally so
                # CPU-only runs (n_gpu == 0) also get `example`.
                _img_key, example = batch
                if self.params.n_gpu > 0:
                    example = tuple(
                        t.to(f"cuda:{self.params.local_rank}")
                        for t in example)

                self.step(
                    input_ids=example[0],
                    attention_mask=example[1],
                    token_type_ids=example[2],
                    img_feats=example[3],
                    masked_pos=example[4],
                    masked_ids=example[5],
                )

                # Iterating the tqdm object already advances the bar, so no
                # manual update() here (it would double-count progress).
                iter_bar.set_postfix({
                    "Last_loss":
                    f"{self.last_loss:.2f}",
                    "Avg_cum_loss":
                    f"{self.total_loss_epoch/self.n_iter:.2f}"
                })
            iter_bar.close()

            logger.info(
                f"--- Ending epoch {self.epoch}/{self.params.n_epoch-1}")
            self.end_epoch()

        logger.info("Save very last checkpoint as `pytorch_model.bin`.")
        self.save_checkpoint(checkpoint_name="pytorch_model.bin")
        logger.info("Training is finished")

    def step(self, input_ids: torch.Tensor, attention_mask: torch.Tensor,
             token_type_ids: torch.Tensor, img_feats: torch.Tensor,
             masked_pos: torch.Tensor, masked_ids: torch.Tensor):
        """One distillation step: forward student and teacher, then optimize.

        The student runs with gradients and the teacher under
        ``torch.no_grad()``; both receive the same six inputs unchanged. The
        total loss is handed to :meth:`optimize`, which does the backward pass
        and possibly a parameter update (depending on gradient accumulation).

        Args:
            input_ids: token ids.
            attention_mask: self-attention mask.
            token_type_ids: segment ids.
            img_feats: image features (forwarded as ``img_feats``).
            masked_pos: presumably positions of the masked tokens — confirm.
            masked_ids: presumably ids of the masked tokens — confirm.

        Loss = ``alpha_ce`` * KL(student || teacher) on temperature-softened
        logits, scaled by T**2;
        + ``alpha_mse`` * summed MSE on logits divided by ``size(0)`` (if > 0);
        + ``alpha_cos`` * cosine-embedding loss between the last entries of
        the hidden-state outputs, each flattened into one row (if > 0).
        """
        s_logits, s_hidden_states = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask,
            img_feats=img_feats,
            masked_pos=masked_pos,
            masked_ids=masked_ids,
            token_type_ids=token_type_ids)
        with torch.no_grad():
            t_output = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask,
                img_feats=img_feats,
                masked_pos=masked_pos,
                masked_ids=masked_ids,
                token_type_ids=token_type_ids)
            # The teacher returns one extra leading output that is ignored.
            _, t_logits, t_hidden_states = t_output

        # Per the original author: logits are (num_blanks, voc_size).
        s_logits_slct = s_logits
        t_logits_slct = t_logits
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = (self.ce_loss_fct(
            F.log_softmax(s_logits_slct / self.temperature, dim=-1),
            F.softmax(t_logits_slct / self.temperature, dim=-1),
        ) * (self.temperature)**2)
        loss = self.alpha_ce * loss_ce

        if self.alpha_mse > 0.0:
            # Summed MSE divided by size(0) reproduces "batchmean" reduction.
            loss_mse = self.mse_loss_fct(
                s_logits_slct, t_logits_slct) / s_logits_slct.size(0)
            loss += self.alpha_mse * loss_mse
        if self.alpha_cos > 0.0:
            s_hidden_states = s_hidden_states[-1]
            t_hidden_states = t_hidden_states[-1]

            # Flatten each tensor into a single row: one cosine similarity
            # over the whole batch (no attention-mask selection).
            s_hidden_states_slct = s_hidden_states.reshape(1, -1)
            t_hidden_states_slct = t_hidden_states.reshape(1, -1)

            # CosineEmbeddingLoss needs one target per row, i.e. shape (N,)
            # = (1,); a (1, D) target is rejected by recent PyTorch.
            target = torch.ones(s_hidden_states_slct.size(0),
                                device=s_hidden_states_slct.device)
            loss_cos = self.cosine_loss_fct(s_hidden_states_slct,
                                            t_hidden_states_slct, target)
            loss += self.alpha_cos * loss_cos

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mse > 0.0:
            self.last_loss_mse = loss_mse.item()
        if self.alpha_cos > 0.0:
            self.last_loss_cos = loss_cos.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard (through ``self.iter()``).

        Args:
            loss: loss tensor for the current forward pass.

        Raises:
            SystemExit: with status 1 when ``loss`` contains NaN, so that the
                failure is visible to whatever launched the run (the builtin
                ``exit()`` used previously exited with status 0).
        """
        # Abort on NaN. `exit()` is a site-module helper intended for the
        # interactive shell and exits with status 0; signal failure instead.
        if torch.isnan(loss).any():
            import sys

            logger.error("NaN detected")
            sys.exit(1)

        if self.multi_gpu:
            # Multi-GPU runs may yield one loss per device; reduce to a scalar.
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            # Scale so the accumulated gradient is an average over the window.
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp

            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        # Advance counters (and possibly log / checkpoint / evaluate) before
        # deciding whether this iteration closes an accumulation window.
        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                # Clip the master parameters managed by apex amp.
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer),
                    self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(),
                                               self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.

        Every ``params.log_interval`` total iterations the tensorboard metrics
        are written and ``self.last_log`` is refreshed. Every
        ``params.checkpoint_interval`` total iterations a checkpoint is saved
        and the student is evaluated with ``evaluate(...)``. The evaluation
        result is appended to ``self.eval_log``, which is dumped to
        ``<dump_path>/eval_logs.json``.

        Evaluation is best-effort: any exception is logged with its traceback
        and training continues.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
            self.last_log = time.time()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()
            logger.info("Perform evaluation at step: %d", self.n_total_iter)
            try:
                evaluate_file = evaluate(self.params, self.val_dataset,
                                         self.student, self.tokenizer,
                                         self.dump_path)
                with open(evaluate_file, 'r') as f:
                    res = json.load(f)
                # Keep the best score on the instance so it survives across
                # calls (previously an undefined local, which always raised).
                self.best_score = max(getattr(self, "best_score", 0.0),
                                      res['CIDEr'])
                res['epoch'] = self.epoch
                res['global_step'] = self.n_total_iter
                res['best_CIDEr'] = self.best_score
                self.eval_log.append(res)
                with open(os.path.join(self.dump_path, 'eval_logs.json'),
                          'w') as f:
                    json.dump(self.eval_log, f)
            except Exception:
                # Best-effort evaluation: record the traceback, keep training.
                logger.exception(
                    "An exception was made in the evaluation process. ")

    def log_tensorboard(self):
        """
        Log training metrics into tensorboard at ``global_step=self.n_total_iter``.

        NOTE: the master-process guard is commented out, so every caller logs.

        Written scalars:
          * mean/std of each student parameter and, when present, its gradient;
          * the running epoch-average loss and the most recent loss terms;
          * the current learning rate of the first param group;
          * used system memory (``psutil`` "used" / 1e6) and the seconds
            elapsed since ``self.last_log``.
        """
        # if not self.is_master:
        #     return
        step = self.n_total_iter

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag="parameter_mean/" + param_name,
                                        scalar_value=param.data.mean(),
                                        global_step=step)
            self.tensorboard.add_scalar(tag="parameter_std/" + param_name,
                                        scalar_value=param.data.std(),
                                        global_step=step)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name,
                                        scalar_value=param.grad.data.mean(),
                                        global_step=step)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name,
                                        scalar_value=param.grad.data.std(),
                                        global_step=step)

        self.tensorboard.add_scalar(
            tag="losses/cum_avg_loss_epoch",
            scalar_value=self.total_loss_epoch / self.n_iter,
            global_step=step,
        )
        self.tensorboard.add_scalar(tag="losses/loss",
                                    scalar_value=self.last_loss,
                                    global_step=step)
        self.tensorboard.add_scalar(tag="losses/loss_ce",
                                    scalar_value=self.last_loss_ce,
                                    global_step=step)
        if self.alpha_mse > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_mse",
                                        scalar_value=self.last_loss_mse,
                                        global_step=step)
        if self.alpha_cos > 0.0:
            self.tensorboard.add_scalar(tag="losses/loss_cos",
                                        scalar_value=self.last_loss_cos,
                                        global_step=step)

        # Prefer `get_last_lr()` (PyTorch >= 1.4): calling `get_lr()` outside
        # of `scheduler.step()` is deprecated and emits a warning. Fall back to
        # `get_lr()` for schedulers that do not provide `get_last_lr`.
        get_last_lr = getattr(self.scheduler, "get_last_lr", None)
        lrs = get_last_lr() if get_last_lr is not None else self.scheduler.get_lr()
        self.tensorboard.add_scalar(tag="learning_rate/lr",
                                    scalar_value=lrs[0],
                                    global_step=step)

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()["used"] / 1_000_000,
            global_step=step,
        )
        # Seconds elapsed since the previous log (not a rate, despite the tag).
        self.tensorboard.add_scalar(tag="global/speed",
                                    scalar_value=time.time() - self.last_log,
                                    global_step=step)

    def end_epoch(self):
        """
        Finally arrived at the end of epoch (full pass on dataset).

        Saves a checkpoint named ``model_epoch_<epoch>.pth``, logs the epoch
        average loss (``total_loss_epoch / n_iter``) to tensorboard, then
        advances ``self.epoch`` and resets the per-epoch counters.

        The average loss is only logged when at least one iteration ran, so an
        empty epoch no longer raises ``ZeroDivisionError``.
        """
        logger.info("%s sequences have been trained during this epoch.",
                    self.n_sequences_epoch)

        self.save_checkpoint(checkpoint_name=f"model_epoch_{self.epoch}.pth")
        if self.n_iter > 0:
            self.tensorboard.add_scalar(tag="epoch/loss",
                                        scalar_value=self.total_loss_epoch /
                                        self.n_iter,
                                        global_step=self.epoch)
        else:
            logger.warning(
                "No iterations were run during epoch %s; skipping epoch/loss.",
                self.epoch)

        # Reset per-epoch accumulators for the next pass.
        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = "checkpoint.pth"):
        """
        Write the student's config and weights into ``self.dump_path``.

        The config is written by ``config.save_pretrained(self.dump_path)``.
        The ``state_dict`` is written with ``torch.save`` to
        ``<dump_path>/<checkpoint_name>``. The master-process guard is
        currently commented out, so every caller writes the files.

        Args:
            checkpoint_name: file name for the weights, relative to
                ``self.dump_path``.
        """
        # if not self.is_master:
        #     return
        # Wrapped models (e.g. DataParallel / DDP) expose the real model as `.module`.
        inner = getattr(self.student, "module", self.student)
        inner.config.save_pretrained(self.dump_path)
        weights = inner.state_dict()
        destination = os.path.join(self.dump_path, checkpoint_name)
        torch.save(weights, destination)
def main():
    """Run leave-out explanation of BERT pre-training predictions.

    Steps:
      1. Parse command-line options.
      2. Set up the device and, when ``--local_rank`` is given, the NCCL
         process group; seed the RNGs.
      3. Load ``BertForPreTraining`` from ``--bert_model``.
      4. When ``--do_explain`` is set, call ``build_tree_by_leave_out`` on
         every batch of ``--train_file`` for ``int(--num_train_epochs)``
         passes, with the model in eval mode.

    This function performs no optimizer steps and writes no checkpoint. The
    optimizer is still constructed, so missing apex is reported exactly as
    before.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1, or if
            ``--output_dir`` already exists and is not empty.
        ImportError: if apex is needed (distributed or fp16) but missing.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--train_file",
                        default=None,
                        type=str,
                        required=True,
                        help="The input train corpus.")
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument('--do_explain',
                        action='store_true',
                        help='Whether to run explanation')
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--on_memory",
        action='store_true',
        help="Whether to load train samples into memory or use disk")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # Per-forward-pass batch size.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    num_train_optimization_steps = None
    if args.do_explain:
        print("Loading Train Dataset", args.train_file)
        train_dataset = BERTDataset(args.train_file,
                                    tokenizer,
                                    seq_len=args.max_seq_length,
                                    corpus_lines=None,
                                    on_memory=args.on_memory)
        num_train_optimization_steps = int(
            len(train_dataset) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model (single-process DataParallel is intentionally not used).
    model = BertForPreTraining.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)

    # Prepare optimizer. NOTE: the explanation loop below never uses it.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        # Without --do_explain no step count is known. Pass -1 (BertAdam's
        # default, presumably "no LR schedule" in pytorch_pretrained_bert)
        # rather than None, which fails a numeric `t_total < 0` check on
        # Python 3 -- confirm against the installed BertAdam.
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=(num_train_optimization_steps
                                      if num_train_optimization_steps is not None
                                      else -1))

    if args.do_explain:
        logger.info("***** Running Explaination *****")
        logger.info("  Num examples = %d", len(train_dataset))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)

        if args.local_rank == -1:
            train_sampler = RandomSampler(train_dataset)
        else:
            #TODO: check if this works with current data generator from disk that relies on next(file)
            # (it doesn't return item back by index)
            train_sampler = DistributedSampler(train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        model.eval()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            for batch in tqdm(train_dataloader, desc="Iteration"):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                build_tree_by_leave_out(model, input_ids, segment_ids,
                                        input_mask, lm_label_ids, tokenizer)
Example #22
0
def main():
    """Fine-tune ``BertForPreTraining`` on pregenerated epoch files.

    ``--pregenerated_data`` must be a directory containing
    ``epoch_{i}.json`` / ``epoch_{i}_metrics.json`` pairs. When fewer data
    epochs exist than ``--epochs``, the available ones are reused cyclically.

    After training, the unwrapped model's weights (``WEIGHTS_NAME``), its
    config (``CONFIG_NAME``) and the tokenizer vocabulary are written to
    ``--output_dir``.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1.
        ImportError: if apex is needed (distributed or fp16) but missing.
    """
    parser = ArgumentParser()
    parser.add_argument("--pregenerated_data", type=Path, required=True)
    parser.add_argument("--output_dir", type=Path, required=True)
    parser.add_argument(
        "--bert_model",
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.",
    )
    parser.add_argument("--do_lower_case", action="store_true")
    parser.add_argument(
        "--reduce_memory",
        action="store_true",
        help=
        "Store training data as on-disc memmaps to massively reduce memory usage",
    )

    parser.add_argument("--epochs",
                        type=int,
                        default=3,
                        help="Number of epochs to train for")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument("--no_cuda",
                        action="store_true",
                        help="Whether not to use CUDA when available")
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit float precision instead of 32-bit",
    )
    parser.add_argument(
        "--loss_scale",
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n",
    )
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.",
    )
    parser.add_argument("--learning_rate",
                        default=3e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--seed",
                        type=int,
                        default=42,
                        help="random seed for initialization")
    args = parser.parse_args()

    # NOTE(review): `assert` is skipped under `python -O`; a raise would make
    # this check unconditional.
    assert (
        args.pregenerated_data.is_dir()
    ), "--pregenerated_data should point to the folder of files made by pregenerate_training_data.py!"

    # Collect the example count of each available data epoch. Stop at the
    # first missing epoch/metrics pair; earlier epochs are reused later via
    # the modulo in the total-examples loop below.
    samples_per_epoch = []
    for i in range(args.epochs):
        epoch_file = args.pregenerated_data / f"epoch_{i}.json"
        metrics_file = args.pregenerated_data / f"epoch_{i}_metrics.json"
        if epoch_file.is_file() and metrics_file.is_file():
            metrics = json.loads(metrics_file.read_text())
            samples_per_epoch.append(metrics["num_training_examples"])
        else:
            if i == 0:
                exit("No training data was found!")
            print(
                f"Warning! There are fewer epochs of pregenerated data ({i}) than training epochs ({args.epochs})."
            )
            print(
                "This script will loop over the available data, but training diversity may be negatively impacted."
            )
            num_data_epochs = i
            break
    else:
        # for/else: reached only when the loop did not break, i.e. every
        # requested epoch has pregenerated data.
        num_data_epochs = args.epochs

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend="nccl")
    logging.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # Per-forward-pass batch size; --train_batch_size is the total over one
    # accumulation window.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    # Unlike a hard failure, a non-empty output directory only warns; files
    # written at the end may overwrite existing ones.
    if args.output_dir.is_dir() and list(args.output_dir.iterdir()):
        logging.warning(
            f"Output directory ({args.output_dir}) already exists and is not empty!"
        )
    args.output_dir.mkdir(parents=True, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    total_train_examples = 0
    for i in range(args.epochs):
        # The modulo takes into account the fact that we may loop over limited epochs of data
        total_train_examples += samples_per_epoch[i % len(samples_per_epoch)]

    num_train_optimization_steps = int(total_train_examples /
                                       args.train_batch_size /
                                       args.gradient_accumulation_steps)
    if args.local_rank != -1:
        # Each process sees 1/world_size of the data.
        num_train_optimization_steps = (num_train_optimization_steps //
                                        torch.distributed.get_world_size())

    # Prepare model
    model = BertForPreTraining.from_pretrained(args.bert_model)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )
        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: no weight decay for biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_grouped_parameters = [
        {
            "params": [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            "weight_decay":
            0.01,
        },
        {
            "params":
            [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
            "weight_decay":
            0.0,
        },
    ]

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            bias_correction=False,
            max_grad_norm=1.0,
        )
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
        # NOTE(review): constructed with `warmup=`/`t_total=` and later queried
        # via `.get_lr(step, ...)`; this presumably matches the
        # pytorch_pretrained_bert schedule API (a multiplier per step) --
        # confirm against the imported WarmupLinearSchedule.
        warmup_linear = WarmupLinearSchedule(
            warmup=args.warmup_proportion,
            t_total=num_train_optimization_steps)
    else:
        # NOTE(review): neither torch.optim.AdamW nor pytorch_transformers'
        # AdamW is known to accept `warmup`/`t_total`; this call may raise
        # TypeError, and no LR scheduler is created on this path. Confirm
        # which AdamW is imported.
        optimizer = AdamW(
            optimizer_grouped_parameters,
            lr=args.learning_rate,
            warmup=args.warmup_proportion,
            t_total=num_train_optimization_steps,
        )

    global_step = 0
    logging.info("***** Running training *****")
    logging.info(f"  Num examples = {total_train_examples}")
    logging.info("  Batch size = %d", args.train_batch_size)
    logging.info("  Num steps = %d", num_train_optimization_steps)
    model.train()
    for epoch in range(args.epochs):
        # A fresh dataset per epoch; PregeneratedDataset receives
        # num_data_epochs, presumably to wrap around when data epochs run out.
        epoch_dataset = PregeneratedDataset(
            epoch=epoch,
            training_path=args.pregenerated_data,
            tokenizer=tokenizer,
            num_data_epochs=num_data_epochs,
            reduce_memory=args.reduce_memory,
        )
        if args.local_rank == -1:
            train_sampler = RandomSampler(epoch_dataset)
        else:
            train_sampler = DistributedSampler(epoch_dataset)
        train_dataloader = DataLoader(epoch_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        tr_loss = 0
        nb_tr_examples, nb_tr_steps = 0, 0
        with tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}") as pbar:
            for step, batch in enumerate(train_dataloader):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, lm_label_ids, is_next = batch
                # NOTE(review): assumes the model returns a loss tensor when
                # labels are passed (library-version dependent) -- confirm.
                loss = model(input_ids, segment_ids, input_mask, lm_label_ids,
                             is_next)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                pbar.update(1)
                # Undo the accumulation scaling so the bar shows the mean
                # per-batch loss.
                mean_loss = tr_loss * args.gradient_accumulation_steps / nb_tr_steps
                pbar.set_postfix_str(f"Loss: {mean_loss:.5f}")
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # NOTE(review): the non-fp16 path builds AdamW (not
                        # BertAdam) and applies no warmup here -- confirm.
                        lr_this_step = args.learning_rate * warmup_linear.get_lr(
                            global_step, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group["lr"] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    # Save a trained model
    logging.info("** ** * Saving fine-tuned model ** ** * ")
    model_to_save = (model.module if hasattr(model, "module") else model
                     )  # Only save the model it-self

    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

    torch.save(model_to_save.state_dict(), output_model_file)
    model_to_save.config.to_json_file(output_config_file)
    tokenizer.save_vocabulary(args.output_dir)
Example #23
0
def main():
    """Fine-tune ``BertForTokenClassification`` on the WSD task.

    Parses command-line arguments and then, depending on the flags:

    * ``--do_train``: trains for ``int(--num_train_epochs)`` epochs on the
      examples under ``--train_data_dir``. After every epoch it saves the
      weights, config and vocabulary to ``<output_dir>/<epoch>/``.
    * ``--do_eval`` (only acted on while training): after each epoch it
      evaluates on ``--eval_data_dir``. Per-example predictions go to
      ``<output_dir>/results_<epoch>.txt``, and metrics are appended to
      ``<output_dir>/eval_results.txt``.
    * ``--do_test``: evaluates the final model on ``--eval_data_dir``
      (non-distributed or rank 0 only). Predictions go to
      ``<output_dir>/results.txt``, and metrics are appended to
      ``<output_dir>/eval_results.txt``.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1, if neither
            ``--do_train`` nor ``--do_test`` is given, if a data directory
            required by the selected modes is missing, or if
            ``--output_dir`` exists and is non-empty while training.
        ImportError: if apex is required (``--fp16`` or distributed
            training) but not installed.
    """
    parser = argparse.ArgumentParser()
    ## Required parameters
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        choices=["WSD"],
                        help="The name of the task to train.")
    parser.add_argument(
        "--train_data_dir",
        default=None,
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--eval_data_dir",
        default=None,
        type=str,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument("--label_data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The label data dir. (./wordnet)")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help='''a path or url to a pretrained model archive containing:
                        'bert_config.json' a configuration file for the model
                        'pytorch_model.bin' a PyTorch dump of a BertForPreTraining instance'''
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run test on the test set.")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help=
        "Whether to lower case the input text. True for uncased models, False for cased models."
    )
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumualte before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    args = parser.parse_args()

    # ---- Device / distributed setup ----
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    # ---- Argument validation ----
    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # Per-forward-pass batch size. Gradients from gradient_accumulation_steps
    # such batches are accumulated before each optimizer step.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_test:
        raise ValueError(
            "At least one of `do_train` or `do_test` must be True.")
    # Explicit checks instead of `assert`, which is stripped under `python -O`.
    if args.do_train and args.train_data_dir is None:
        raise ValueError("train_data_dir can not be None")
    # Both the per-epoch evaluation and the final test read dev examples from
    # eval_data_dir, so it is required by either mode.
    if (args.do_eval or args.do_test) and args.eval_data_dir is None:
        raise ValueError("eval_data_dir can not be None")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    # ---- Task / tokenizer ----
    processors = {"WSD": WSD_token_Processor}

    output_modes = {"WSD": "classification"}

    processor = processors[args.task_name]()
    output_mode = output_modes[args.task_name]
    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    # training set
    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.train_data_dir,
                                                      args.label_data_dir)
        # Kept integral: it is a step count (BertAdam's t_total, the fp16
        # warmup denominator, and a "%d" log value).
        num_train_optimization_steps = int(
            int(len(train_examples) / args.train_batch_size /
                args.gradient_accumulation_steps) * args.num_train_epochs)
        if args.local_rank != -1:
            num_train_optimization_steps = (
                num_train_optimization_steps //
                torch.distributed.get_world_size())

    # ---- Model ----
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    model = BertForTokenClassification.from_pretrained(args.bert_model,
                                                       cache_dir=cache_dir,
                                                       num_labels=num_labels)

    if args.fp16:
        model.half()
    model.to(device)

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # ---- Optimizer ----
    # Biases and LayerNorm parameters are excluded from weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    # ---- Helpers shared by training, per-epoch eval and final test ----
    def _features_to_dataset(features):
        """Pack converted features into a TensorDataset.

        Tensor order: input_ids, input_mask, segment_ids, label_ids,
        target_mask.
        """
        all_input_ids = torch.tensor([f.input_ids for f in features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in features],
                                       dtype=torch.long)
        all_target_mask = torch.tensor([f.target_mask for f in features],
                                       dtype=torch.long)
        label_dtype = torch.float if output_mode == "regression" else torch.long
        all_label_ids = torch.tensor([f.label_id for f in features],
                                     dtype=label_dtype)
        return TensorDataset(all_input_ids, all_input_mask, all_segment_ids,
                             all_label_ids, all_target_mask)

    def _load_eval_dataloader():
        """Load the dev examples from eval_data_dir into a sequential DataLoader."""
        eval_examples = processor.get_dev_examples(args.eval_data_dir,
                                                   args.label_data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer, output_mode)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        return DataLoader(_features_to_dataset(eval_features),
                          batch_size=args.eval_batch_size,
                          shuffle=False)

    def _compute_loss(logits, label_ids):
        """Return the task loss for a batch of logits (CE or MSE by output_mode)."""
        if output_mode == "classification":
            return CrossEntropyLoss()(logits.view(-1, num_labels),
                                      label_ids.view(-1))
        return MSELoss()(logits.view(-1), label_ids.view(-1))

    def _save_checkpoint(model_output_dir):
        """Save weights, config and vocab so `from_pretrained` can reload them."""
        # Unwrap DataParallel / DDP so the bare model's state dict is saved.
        model_to_save = model.module if hasattr(model, 'module') else model
        os.makedirs(model_output_dir, exist_ok=True)
        torch.save(model_to_save.state_dict(),
                   os.path.join(model_output_dir, WEIGHTS_NAME))
        model_to_save.config.to_json_file(
            os.path.join(model_output_dir, CONFIG_NAME))
        tokenizer.save_vocabulary(model_output_dir)

    def _evaluate(dataloader, predictions_path, step, train_loss, epoch=None):
        """Evaluate the model on `dataloader`, then write and log the metrics.

        Writes one line per example to `predictions_path`: the argmax class
        followed by the softmax probability of every class. Appends
        eval_loss / eval_accuracy / global_step / loss to
        `<output_dir>/eval_results.txt`, preceded by an `epoch=<epoch>` line
        when `epoch` is given. Leaves the model in eval mode.
        """
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0

        with open(predictions_path, "w") as f:
            for batch in tqdm(dataloader, desc="Evaluating"):
                input_ids, input_mask, segment_ids, label_ids, target_mask = (
                    t.to(device) for t in batch)

                with torch.no_grad():
                    logits = model(input_ids=input_ids,
                                   token_type_ids=segment_ids,
                                   attention_mask=input_mask,
                                   labels=None,
                                   target_mask=target_mask)

                probs = F.softmax(logits, dim=-1).detach().cpu().numpy()
                outputs = np.argmax(probs, axis=1)
                for pred, row in zip(outputs, probs):
                    f.write(str(pred))
                    for prob in row:
                        f.write(" " + str(prob))
                    f.write("\n")

                eval_accuracy += np.sum(outputs == label_ids.cpu().numpy())
                eval_loss += _compute_loss(logits, label_ids).mean().item()
                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1

        result = OrderedDict()
        result['eval_loss'] = eval_loss / nb_eval_steps
        result['eval_accuracy'] = eval_accuracy / nb_eval_examples
        result['global_step'] = step
        result['loss'] = train_loss

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "a+") as writer:
            if epoch is not None:
                writer.write("epoch=%s\n" % str(epoch))
            logger.info("***** Eval results *****")
            for key, value in result.items():
                logger.info("  %s = %s", key, str(value))
                writer.write("%s = %s\n" % (key, str(value)))

    # ---- Data ----
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, output_mode)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        train_data = _features_to_dataset(train_features)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

    if args.do_eval:
        eval_dataloader = _load_eval_dataloader()

    # ---- Training ----
    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0

    if args.do_train:
        for epoch_idx in trange(int(args.num_train_epochs), desc="Epoch"):
            epoch = epoch_idx + 1
            # The evaluation at the end of the previous epoch switched the
            # model to eval mode. Re-enable training mode (dropout etc.).
            model.train()
            tr_loss = 0
            nb_tr_steps = 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                input_ids, input_mask, segment_ids, label_ids, target_mask = (
                    t.to(device) for t in batch)

                logits = model(input_ids=input_ids,
                               token_type_ids=segment_ids,
                               attention_mask=input_mask,
                               labels=None,
                               target_mask=target_mask)
                loss = _compute_loss(logits, label_ids)

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

            # Save under the predefined names so `from_pretrained` can load it.
            _save_checkpoint(os.path.join(args.output_dir, str(epoch)))

            if args.do_eval:
                _evaluate(eval_dataloader,
                          os.path.join(args.output_dir,
                                       "results_" + str(epoch) + ".txt"),
                          global_step,
                          tr_loss / nb_tr_steps,
                          epoch=epoch)

    # ---- Final test (single process or rank 0 only) ----
    if args.do_test and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        test_dataloader = _load_eval_dataloader()
        _evaluate(test_dataloader,
                  os.path.join(args.output_dir, "results.txt"),
                  global_step,
                  tr_loss / nb_tr_steps if args.do_train else None)
Example #24
0
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    # Data Directory
    parser.add_argument("--data_dir1",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--data_dir2",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    # Bert Model
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
                        "bert-base-multilingual-cased, bert-base-chinese.")
    # Name of Task 1
    parser.add_argument("--task1_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the first task to train.")
    # Name of Task 2
    parser.add_argument("--task2_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the second task to train")
    # Output Directory
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    # Max sequence length
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    # Train it?
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    # Run evaluation?
    parser.add_argument("--do_eval",
                        default=0,
                        type=int,
                        help="Whether to run eval on the dev set. 0: Don't eval, 1: Eval task 1, 2: Eval task 2, 3: Eval both")
    # Uncased?
    parser.add_argument("--do_lower_case",
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")

    # Set batch size for the first task
    parser.add_argument("--train_batch_size1",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    # Set batch size for the second task
    parser.add_argument("--train_batch_size2",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    # Batch size for evaluation
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    # Learning Rate for Adam
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    # Training epochs
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    # ??
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    # Select Processor
    processors = {
        "rte": RteProcessor,
        "stsb": StsbProcessor,
        "sst2": Sst2Processor,
        "qnli": QnliProcessor,
        "qqp": QqpProcessor,
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
    }

    # number of labels for each task
    num_labels_task = {
        "rte": 2,
        "stsb": bin + 1,
        "sst2": 2,
        "qnli": 2,
        "qqp": 2,
        "cola": 2,
        "mnli": 3,
        "mrpc": 2,
    }

    # cuda or cpu
    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    # Check for valid args
    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    # Set train batch size
    args.train_batch_size1 = int(args.train_batch_size1 / args.gradient_accumulation_steps)
    args.train_batch_size2 = int(args.train_batch_size2 / args.gradient_accumulation_steps)

    train_batch_size = [args.train_batch_size1, args.train_batch_size2]


    # Seeds
    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and args.do_eval == 0:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    # Set task name
    task_name = [args.task1_name.lower(), args.task2_name.lower()]

    # Check if task is in processors. Will need to add to the dictionary if I plan to add another task
    for i in range(2):
        if task_name[i] not in processors:
            raise ValueError("Task %d not found: %s" % (i + 1, task_name[i]))

    # Run the processor. Will need to check what each processor does
    # Create each processor
    processor = [processors[task_name[0]](), processors[task_name[1]]()]

    # Task label
    num_labels = [num_labels_task[task_name[0]], num_labels_task[task_name[1]]]

    # List of labels
    label_list = [processor[0].get_labels(), processor[1].get_labels()]

    # Call Tokenizer
    tokenizer = BertTokenizer.from_pretrained(args.bert_model, do_lower_case=args.do_lower_case)

    train_examples = []
    num_train_steps = []
    data_dir = [args.data_dir1, args.data_dir2]

    # Train ?
    if args.do_train:
        for i in range(2):
            train_examples.append(processor[i].get_train_examples(data_dir[i]))
            num_train_steps.append(int(len(train_examples[i]) / train_batch_size[i] / args.gradient_accumulation_steps * args.num_train_epochs))

    # Prepare model
    model = BertForSequenceClassification.from_pretrained(args.bert_model,
              cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
              # Do I need num_labels as a parameter?
    ## Need to modify the model in modeling.py... How to?
    mconfig = model.config
    multilayer = [GlueModel(mconfig, num_labels[i]) for i in range(2)]
    for e in multilayer:
        e.cuda()

    ## Additional Optimizers

    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer TODO
    param_optimizer = [list(model.named_parameters()) + list(multilayer[i].named_parameters()) for i in range(2)]
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [[
        {'params': [p for n, p in param_optimizer[i] if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        {'params': [p for n, p in param_optimizer[i] if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ] for i in range(2)]
    t_total = sum(num_train_steps)
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)

    else:
        # Create Optimizer
        ## Todo
        optimizer = [BertAdam(optimizer_grouped_parameters[i],
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total) for i in range(2)]

    global_step = 0
    nb_tr_steps = [0, 0]
    tr_loss = [0, 0]
    train_features = []
    train_data = []
    train_sampler = []
    train_dataloader = []
    if args.do_train:
        logger.info("***** Running training *****")
        for i in range(2):
            train_features.append(convert_examples_to_features(
                train_examples[i], label_list[i], args.max_seq_length, tokenizer))
            logger.info("  Task %d", i + 1)
            logger.info("  Num examples = %d", len(train_examples[i]))
            logger.info("  Batch size = %d", train_batch_size[i])
            logger.info("  Num steps = %d", num_train_steps[i])
            all_input_ids = torch.tensor([f.input_ids for f in train_features[i]], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in train_features[i]], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in train_features[i]], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in train_features[i]], dtype=torch.long)
            train_data.append(TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids))

            if args.local_rank == -1:
                train_sampler.append(RandomSampler(train_data[i]))
            else:
                train_sampler.append(DistributedSampler(train_data[i]))
            train_dataloader.append(list(DataLoader(train_data[i], sampler=train_sampler[i], batch_size=train_batch_size[i])))

        model.train() ## apply for each
        for layer in multilayer:
            layer.train()

        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = [0, 0]
            nb_tr_examples = [0, 0]
            nb_tr_steps = [0, 0]
            step = [0, 0]
            length = [len(e) for e in train_dataloader]
            for g_step, _ in enumerate(tqdm(range(0, sum(length)), desc="Iteration")):
                if not any([step[i] - length[i] for i in range(2)]):
                    break ## loop finished, added just in case
                elif step[0] == length[0]:
                    select = 1
                elif step[1] == length[1]:
                    select = 0
                else:
                    select = random.randint(0, 1)
                ## Batch size ratio is not taken into consideration

                batch = train_dataloader[select][step[select]]
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                pooled_output = model(input_ids, segment_ids, input_mask, label_ids)
                loss = multilayer[select].foward(pooled_output, label_ids)

                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss[select] += loss.item()
                nb_tr_examples[select] += input_ids.size(0)
                nb_tr_steps[select] += 1
                if (g_step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
                    for param_group in optimizer[select].param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer[select].step()
                    optimizer[select].zero_grad()
                    global_step += 1
                step[select] += 1 ## Index check


    # Save a trained model
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
    layers_to_save = [multilayer[i].module if hasattr(multilayer[i], 'module') else model for i in range(2)]
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    output_layer_file = [os.path.join(args.output_dir, "pytorch_model_layer%d.bin" % i) for i in range(2)]
    if args.do_train:
        torch.save(model_to_save.state_dict(), output_model_file)
        for i in range(2):
            torch.save(layers_to_save[i].state_dict(), output_layer_file[i])

    # # Load a trained model that you have fine-tuned
    # model_state_dict = torch.load(output_model_file)
    # model = BertForSequenceClassification.from_pretrained(args.bert_model, state_dict=model_state_dict) # ...? This also needs to be modified
    #
    # # model.load_state_dict(torch.load(path))
    # #     with open(path, 'wb') as f:
    # #         torch.save(model.state_dict(), f)
    # multilayer = [GlueModel.load_state_dict(torch.load(output_layer_file[i])) for i in range(2)]
    # model.to(device)


    eval_flag = args.do_eval

    if eval_flag != 0 and (args.local_rank == -1 or torch.distributed.get_rank() == 0):
        task_cnt = 0
        while eval_flag > 0:
            if not eval_flag & 1:
                eval_flag >>= 1
                task_cnt += 1
                continue
            eval_examples = processor[task_cnt].get_dev_examples(data_dir[task_cnt])
            eval_features = convert_examples_to_features(
                eval_examples, label_list[task_cnt], args.max_seq_length, tokenizer)
            logger.info("***** Running evaluation for Task %d*****", task_cnt + 1)
            logger.info("  Num examples = %d", len(eval_examples))
            logger.info("  Batch size = %d", args.eval_batch_size)
            all_input_ids = torch.tensor([f.input_ids for f in eval_features], dtype=torch.long)
            all_input_mask = torch.tensor([f.input_mask for f in eval_features], dtype=torch.long)
            all_segment_ids = torch.tensor([f.segment_ids for f in eval_features], dtype=torch.long)
            all_label_ids = torch.tensor([f.label_id for f in eval_features], dtype=torch.long)
            eval_data = TensorDataset(all_input_ids, all_input_mask, all_segment_ids, all_label_ids)

            # Run prediction for full data
            eval_sampler = SequentialSampler(eval_data)
            eval_dataloader = DataLoader(eval_data, sampler=eval_sampler, batch_size=args.eval_batch_size)

            model.eval()
            multilayer[task_cnt].eval()
            eval_loss, eval_accuracy = 0, 0
            nb_eval_steps, nb_eval_examples = 0, 0

            for input_ids, input_mask, segment_ids, label_ids in tqdm(eval_dataloader, desc="Evaluating"):
                input_ids = input_ids.to(device)
                input_mask = input_mask.to(device)
                segment_ids = segment_ids.to(device)
                label_ids = label_ids.to(device)

                with torch.no_grad():
                    tmp_pooled_output = model(input_ids, segment_ids, input_mask, label_ids)
                    tmp_eval_loss = multilayer[task_cnt].foward(tmp_pooled_output, label_ids)
                    pooled_output = model(input_ids, segment_ids, input_mask)
                    logits = multilayer[task_cnt].foward(pooled_output)

                logits = logits.detach().cpu().numpy()
                label_ids = label_ids.to('cpu').numpy()
                tmp_eval_accuracy = accuracy(logits, label_ids)

                eval_loss += tmp_eval_loss.mean().item()
                eval_accuracy += tmp_eval_accuracy

                nb_eval_examples += input_ids.size(0)
                nb_eval_steps += 1

            eval_loss = eval_loss / nb_eval_steps
            eval_accuracy = eval_accuracy / nb_eval_examples
            loss = tr_loss[task_cnt]/nb_tr_steps[task_cnt] if args.do_train else None
            result = {'eval_loss': eval_loss,
                      'eval_accuracy': eval_accuracy,
                      'global_step': global_step,
                      'loss': loss}

            output_eval_file = os.path.join(args.output_dir, "eval_results%d.txt" % (task_cnt + 1))
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results for Task %d*****", task_cnt)
                for key in sorted(result.keys()):
                    logger.info("  %s = %s", key, str(result[key]))
                    writer.write("%s = %s\n" % (key, str(result[key])))
            eval_flag >>= 1
            task_cnt += 1
def main():
    """Fine-tune and/or evaluate a 10-way ``BertForMultipleChoice`` model.

    Command-line entry point:

    * parses the arguments defined below;
    * with ``--do_train``, reads ``<data_dir>/train.tsv`` through
      ``read_swag_examples`` and fine-tunes the model;
    * saves the (possibly fine-tuned) weights to
      ``<output_dir>/pytorch_model.bin`` and reloads them from that file;
    * with ``--do_eval`` (main process only when distributed), scores
      ``<data_dir>/eval.tsv`` and writes ``<output_dir>/eval_results.txt``.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1, if neither
            ``--do_train`` nor ``--do_eval`` is set, or if ``output_dir``
            already exists and is not empty.
        ImportError: if distributed or fp16 training is requested and apex
            cannot be imported.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .csv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese."
    )
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        default=False,
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        default=False,
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # --train_batch_size is the effective batch; each forward pass sees 1/N.
    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    def features_to_dataset(features):
        # Stack the per-example feature fields into long tensors. The column
        # order (input ids, mask, segment ids, label) matches the tuple
        # unpacking in the training and evaluation loops below.
        return TensorDataset(
            torch.tensor(select_field(features, 'input_ids'),
                         dtype=torch.long),
            torch.tensor(select_field(features, 'input_mask'),
                         dtype=torch.long),
            torch.tensor(select_field(features, 'segment_ids'),
                         dtype=torch.long),
            torch.tensor([f.label for f in features], dtype=torch.long))

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = read_swag_examples(
            os.path.join(args.data_dir, 'train.tsv'), is_training=True)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    # Prepare model
    model = BertForMultipleChoice.from_pretrained(
        args.bert_model,
        cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
        'distributed_{}'.format(args.local_rank),
        num_choices=10)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    # hack to remove pooler, which is not used
    # thus it produce None grad that break apex
    param_optimizer = [(n, p) for n, p in model.named_parameters()
                       if 'pooler' not in n]

    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]

    # The step count is only known when training. Without --do_train,
    # num_train_steps is None, and `None // world_size` would raise in
    # distributed eval-only runs; -1 is used as the "no schedule" value instead.
    if args.do_train:
        t_total = num_train_steps
        if args.local_rank != -1:
            t_total = t_total // torch.distributed.get_world_size()
    else:
        t_total = -1

    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    global_step = 0
    # Defined up front so the eval summary below also works without
    # --do_train. These hold the values from the last training epoch.
    tr_loss, nb_tr_steps = 0, 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      tokenizer,
                                                      args.max_seq_length,
                                                      True)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)
        train_data = features_to_dataset(train_features)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_steps = 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                # No manual loss scaling here. FP16_Optimizer.backward()
                # applies the configured (dynamic or static) loss scale itself.
                # Scaling again would double-scale, and with the default
                # loss_scale=0 it would zero the loss and every gradient.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                tr_loss += loss.item()
                nb_tr_steps += 1

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # FusedAdam is given no warmup schedule, so BERT's
                        # linear warmup is applied by hand. BertAdam (the
                        # non-fp16 path) is constructed with warmup/t_total
                        # above and schedules the lr itself. Overriding its lr
                        # here as well would apply the schedule twice.
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / t_total, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

    # Save a trained model
    model_to_save = model.module if hasattr(
        model, 'module') else model  # Only save the model it-self
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)

    # Load a trained model that you have fine-tuned. Note that the reloaded
    # model is neither converted to half precision nor wrapped in DDP or
    # DataParallel.
    model_state_dict = torch.load(output_model_file)
    model = BertForMultipleChoice.from_pretrained(args.bert_model,
                                                  state_dict=model_state_dict,
                                                  num_choices=10)
    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        # is_training=True so that the gold labels are read for scoring.
        eval_examples = read_swag_examples(
            os.path.join(args.data_dir, 'eval.tsv'), is_training=True)
        eval_features = convert_examples_to_features(eval_examples, tokenizer,
                                                     args.max_seq_length, True)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        eval_data = features_to_dataset(eval_features)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for input_ids, input_mask, segment_ids, label_ids in eval_dataloader:
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                # With labels the model returns the loss; without, the logits.
                tmp_eval_loss = model(input_ids, segment_ids, input_mask,
                                      label_ids)
                logits = model(input_ids, segment_ids, input_mask)

            logits = logits.detach().cpu().numpy()
            label_ids = label_ids.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(logits, label_ids)

            eval_loss += tmp_eval_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples

        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'global_step': global_step,
            # None when no training steps ran (e.g. eval-only invocation).
            'loss': tr_loss / nb_tr_steps if nb_tr_steps else None
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Example #26
0
def main():
    """Command-line entry point: fine-tune BERT on ``task_a`` and optionally
    evaluate on the dev set and predict on the test set.

    The ``--do_train``, ``--do_eval`` and ``--do_test`` flags select the
    phases to run (at least one is required):

    * train: fine-tunes ``BertForSequenceClassification`` on
      ``processor.get_train_examples()`` and saves the weights
      (``WEIGHTS_NAME``) and config (``CONFIG_NAME``) into ``--output_dir``,
      which must be absent or empty.
    * eval: reloads the checkpoint saved in ``--output_dir``, scores the dev
      examples and writes ``eval_results.txt`` there.
    * test: reloads the checkpoint saved in ``--output_dir``, predicts labels
      for the test examples and writes ``test_result_<task_name>.json`` with
      the subtask A predictions plus randomly generated subtask B entries.

    Raises:
        ValueError: for ``--gradient_accumulation_steps`` < 1, when no phase
            flag is given, when training into a non-empty output directory,
            or for an unknown ``--task_name``.
        ImportError: when distributed or fp16 training is requested without
            NVIDIA apex installed.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default="",
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default="",
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=50,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_test",
                        action='store_true',
                        help="Whether to run eval on the test set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--test_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for test.")
    parser.add_argument("--learning_rate",
                        default=2e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=2.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    args = parser.parse_args()

    processors = {"task_a": taskA_Processor}

    output_modes = {"task_a": "classification"}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # The per-step batch shrinks so that accumulated steps add up to the
    # requested total batch size.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval and not args.do_test:
        raise ValueError(
            "At least one of `do_train` or `do_eval` or `do_test` must be True."
        )

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Normalized task key; use this (not args.task_name) for all dispatch.
    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()
    num_labels = len(label_list)

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples()
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, output_mode)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.long)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        model.train()
        loss_fct = CrossEntropyLoss()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            # Reset per epoch, so the reported training loss is the mean
            # over the last epoch only.
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # Called with labels=None; the loss is computed here so it
                # can depend on output_mode.
                logits = model(input_ids, segment_ids, input_mask, labels=None)

                if output_mode == "classification":
                    loss = loss_fct(logits.view(-1, num_labels),
                                    label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

        # Save a trained model and the associated configuration
        model_to_save = model.module if hasattr(
            model, 'module') else model  # Only save the model it-self
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        torch.save(model_to_save.state_dict(), output_model_file)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        with open(output_config_file, 'w') as f:
            f.write(model_to_save.config.to_json_string())

        # Load a trained model and config that you have fine-tuned.
        # map_location="cpu" lets the checkpoint load on any host; the
        # model is moved to `device` below.
        config = BertConfig(output_config_file)
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(
            torch.load(output_model_file, map_location="cpu"))
    else:
        model = BertForSequenceClassification.from_pretrained(
            args.bert_model, num_labels=num_labels)
    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        # Evaluate the checkpoint saved in output_dir, using the same file
        # names the training phase writes.
        config = BertConfig.from_json_file(
            os.path.join(args.output_dir, CONFIG_NAME))
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(
            torch.load(os.path.join(args.output_dir, WEIGHTS_NAME),
                       map_location="cpu"))

        model.to(device)

        eval_examples = processor.get_dev_examples()
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer, output_mode)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                         dtype=torch.long)

        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)

        model.eval()
        loss_fct = CrossEntropyLoss()
        eval_loss = 0
        nb_eval_steps = 0
        # Collect per-batch logits and concatenate once at the end; growing
        # a single array with np.append on every batch is quadratic.
        logit_batches = []

        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)
            label_ids = label_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask, labels=None)

            # create eval loss and other metric required by the task
            if output_mode == "classification":
                tmp_eval_loss = loss_fct(logits.view(-1, num_labels),
                                         label_ids.view(-1))

            eval_loss += tmp_eval_loss.mean().item()
            nb_eval_steps += 1
            logit_batches.append(logits.detach().cpu().numpy())

        eval_loss = eval_loss / nb_eval_steps
        preds = np.concatenate(logit_batches, axis=0)
        if output_mode == "classification":
            preds = np.argmax(preds, axis=1)

        result = compute_metrics(task_name, preds, all_label_ids.numpy())
        loss = tr_loss / nb_tr_steps if args.do_train else None

        result['eval_loss'] = eval_loss
        result['global_step'] = global_step
        result['loss'] = loss
        result['model_spec'] = args

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))

    if args.do_test and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        config = BertConfig.from_json_file(
            os.path.join(args.output_dir, CONFIG_NAME))
        model = BertForSequenceClassification(config, num_labels=num_labels)
        model.load_state_dict(
            torch.load(os.path.join(args.output_dir, WEIGHTS_NAME),
                       map_location="cpu"))

        model.to(device)

        # task_name was validated against `processors` above, so the
        # processor always belongs to the requested task.
        test_examples = processor.get_test_examples()

        test_features = convert_examples_to_features(test_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer, output_mode)

        logger.info("***** Prediction for Test data *****")
        logger.info("  Num examples = %d", len(test_examples))
        logger.info("  Batch size = %d", args.test_batch_size)

        all_guids = [f.guid for f in test_features]
        all_input_ids = torch.tensor([f.input_ids for f in test_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in test_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in test_features],
                                       dtype=torch.long)

        test_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids)

        # Run prediction for full data (no sampler: DataLoader keeps the
        # dataset order, so predictions line up with all_guids).
        test_dataloader = DataLoader(test_data,
                                     batch_size=args.test_batch_size)

        model.eval()
        logit_batches = []

        for input_ids, input_mask, segment_ids in tqdm(test_dataloader,
                                                       desc="Prediction"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask, labels=None)

            logit_batches.append(logits.detach().cpu().numpy())

        preds = np.concatenate(logit_batches, axis=0)

        if output_mode == "classification":
            preds = np.argmax(preds, axis=1)

        df_preds = pd.DataFrame(preds)

        if task_name == "task_a":
            # Map class indices to the label strings written to the output.
            df_preds = df_preds.replace(
                [0, 1, 2, 3], ["support", "deny", "query", "comment"])
            preds_list = list(df_preds[0])

        result = {"subtaskaenglish": {}, "subtaskbenglish": {}}

        for guid, pred in zip(all_guids, preds_list):
            result["subtaskaenglish"][str(guid)] = pred

        # Subtask B is not modelled here: every source id gets a random
        # placeholder verdict (scores in [0.33, 0.66] map to "unverified"
        # with a fixed confidence of 0.5).
        tmp_1 = get_all_sources_ids('test', 'twitter')
        tmp_2 = get_all_sources_ids('test', 'reddit')
        task_b_ids = tmp_1 + tmp_2
        scores = np.random.rand(len(task_b_ids))
        for source_id, score in zip(task_b_ids, scores):
            if score < 0.33:
                result["subtaskbenglish"][str(source_id)] = ["false", score]
            elif score <= 0.66:
                result["subtaskbenglish"][str(source_id)] = ["unverified", 0.5]
            else:
                result["subtaskbenglish"][str(source_id)] = ["true", score]

        # The file name keeps the user-supplied task name, as before.
        test_result_file = os.path.join(args.output_dir,
                                        'test_result_%s.json' % args.task_name)
        with open(test_result_file, "w") as writer:
            json.dump(result, writer)

        logger.info("***** Prediction for Test data is completed *****")
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=100,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3,
                        type=int,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    # parser.add_argument('--gpuid', type=int, default=-1,help='The gpu id to use')
    args = parser.parse_args()

    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "snli": SnliProcessor,
        "mrpc": MrpcProcessor,
        "sst": SstProcessor,
        "twitter": TwitterProcessor,
    }

    num_labels_task = {
        "cola": 2,
        "mnli": 3,
        "snli": 3,
        "mrpc": 2,
        "sst": 2,
        "twitter": 2,
    }

    if args.local_rank == -1 or args.no_cuda:
        if not args.no_cuda:
            # device = torch.device("cuda",args.gpuid)
            # torch.cuda.set_device(args.gpuid)
            dummy = torch.cuda.FloatTensor(1)
        else:
            device = torch.device("cpu")
        n_gpu = 1
        # n_gpu = torch.cuda.device_count()
    else:
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
        if args.fp16:
            # logger.info("16-bits training currently not supported in distributed training")
            args.fp16 = False  # (see https://github.com/pytorch/pytorch/pull/13496)
    # logger.info("device %s n_gpu %d distributed training %r", device, n_gpu, bool(args.local_rank != -1))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size /
                                args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    # if os.path.exists(args.output_dir) and os.listdir(args.output_dir) and args.do_train:
    #     raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    num_labels = num_labels_task[task_name]
    label_list = processor.get_labels()

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps * args.num_train_epochs)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        # logger.info("***** Running evaluation *****")
        # logger.info("  Num examples = %d", len(eval_examples))
        # logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.train_batch_size)

    # Prepare model
    model = BertForSequenceClassification.from_pretrained(
        args.bert_model,
        cache_dir=PYTORCH_PRETRAINED_BERT_CACHE /
        'distributed_{}'.format(args.local_rank),
        num_labels=num_labels)

    if args.fp16:
        model.half()
    model.cuda()
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    global_step = 0
    nb_tr_steps = 0
    tr_loss = 0
    if args.do_train:
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        # logger.info("***** Running training *****")
        # logger.info("  Num examples = %d", len(train_examples))
        # logger.info("  Batch size = %d", args.train_batch_size)
        # logger.info("  Num steps = %d", num_train_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        best_eval_acc = 0.0
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            # for epoch in range(int(args.num_train_epochs)):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                # for step, batch in enumerate(train_dataloader):
                model.train()
                batch = tuple(t.cuda() for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    lr_this_step = args.learning_rate * warmup_linear(
                        global_step / t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                    # if epoch>0:
                    # logits_all,eval_accuracy=do_evaluation(model,eval_dataloader,args,is_training=False)
                    # print(eval_accuracy)
            # logits_all,eval_accuracy=do_evaluation(model,eval_dataloader,args,is_training=False)
            # if best_eval_acc<eval_accuracy:
            #     best_eval_acc=eval_accuracy
            #     print(eval_accuracy)
            model_save_dir = os.path.join(args.output_dir, f'model{epoch}')
            os.makedirs(model_save_dir, exist_ok=True)
            torch.save(model.state_dict(),
                       os.path.join(model_save_dir, f"pytorch_model.bin"))
        print('Best eval acc:', best_eval_acc)
Example #28
0
def main():
    """Fine-tune and/or evaluate ``BertForTokenClassification`` on an NER task.

    Driven entirely by command-line arguments. After device setup (CPU,
    single/multi-GPU via ``DataParallel``, or distributed via ``--local_rank``):

    * with ``--do_train``: fine-tunes the model on the task's train examples
      and (on a single process / rank 0 only) writes the weights
      (``WEIGHTS_NAME``), the BERT config (``CONFIG_NAME``) and
      ``model_config.json`` into ``--output_dir``;
    * without ``--do_train``: rebuilds the model from those files in
      ``--output_dir``;
    * with ``--do_eval``: predicts tags for the dev examples and writes a
      ``classification_report`` to ``<output_dir>/eval_results.txt``.

    Raises:
        ValueError: invalid ``--gradient_accumulation_steps``; neither
            ``--do_train`` nor ``--do_eval`` given; training into a non-empty
            ``--output_dir``; unknown ``--task_name``.
        ImportError: distributed training or fp16 training requested without
            apex installed.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )

    ## Other parameters
    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {"ner": NerProcessor}

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        # Count GPUs only when CUDA is actually used: with --no_cuda on a GPU
        # host the raw device count would wrap the CPU-resident model in
        # DataParallel (which requires parameters on a GPU).
        n_gpu = torch.cuda.device_count() if device.type == "cuda" else 0
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # Per-forward batch size; gradients of `gradient_accumulation_steps`
    # such batches are summed before each optimizer step.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(
            args.output_dir) and args.do_train:
        raise ValueError(
            "Output directory ({}) already exists and is not empty.".format(
                args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    processor = processors[task_name]()
    label_list = processor.get_labels()
    # Label ids start at 1 (see label_map below); id 0 is left unassigned,
    # hence the +1 for the classifier's output size.
    num_labels = len(label_list) + 1
    label_map = {i: label for i, label in enumerate(label_list, 1)}

    tokenizer = BertTokenizer.from_pretrained(args.bert_model,
                                              do_lower_case=args.do_lower_case)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    model = BertForTokenClassification.from_pretrained(args.bert_model,
                                                       cache_dir=cache_dir,
                                                       num_labels=num_labels)
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    if args.do_train:
        # The optimizer is only needed (and only has a valid t_total) when
        # training; building it for eval-only runs is wasted work.
        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [{
            'params': [
                p for n, p in param_optimizer
                if not any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.01
        }, {
            'params': [
                p for n, p in param_optimizer
                if any(nd in n for nd in no_decay)
            ],
            'weight_decay': 0.0
        }]
        if args.fp16:
            try:
                from apex.optimizers import FP16_Optimizer
                from apex.optimizers import FusedAdam
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
                )

            optimizer = FusedAdam(optimizer_grouped_parameters,
                                  lr=args.learning_rate,
                                  bias_correction=False,
                                  max_grad_norm=1.0)
            if args.loss_scale == 0:
                optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
            else:
                optimizer = FP16_Optimizer(optimizer,
                                           static_loss_scale=args.loss_scale)

        else:
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in train_features],
                                     dtype=torch.long)
        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        global_step = 0
        model.train()
        for epoch in trange(int(args.num_train_epochs), desc="Epoch"):
            if args.local_rank != -1:
                # DistributedSampler derives its shuffle from the epoch number;
                # without this every epoch sees the same ordering.
                train_sampler.set_epoch(epoch)
            tr_loss = 0
            nb_tr_steps = 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch
                loss = model(input_ids, segment_ids, input_mask, label_ids)
                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                tr_loss += loss.item()
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically
                        lr_this_step = args.learning_rate * warmup_linear(
                            global_step / num_train_optimization_steps,
                            args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
            # Mean of the (gradient-accumulation-scaled) per-batch losses.
            logger.info("  Epoch %d loss = %.6f", epoch,
                        tr_loss / max(nb_tr_steps, 1))

        # Save a trained model and the associated configuration. Only one
        # process writes, so distributed ranks do not clobber the same files.
        if args.local_rank == -1 or torch.distributed.get_rank() == 0:
            model_to_save = model.module if hasattr(
                model, 'module') else model  # Only save the model it-self
            output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
            torch.save(model_to_save.state_dict(), output_model_file)
            output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
            with open(output_config_file, 'w') as f:
                f.write(model_to_save.config.to_json_string())
            model_config = {
                "bert_model": args.bert_model,
                "do_lower": args.do_lower_case,
                "max_seq_length": args.max_seq_length,
                "num_labels": num_labels,
                "label_map": label_map
            }
            with open(os.path.join(args.output_dir, "model_config.json"),
                      "w") as f:
                json.dump(model_config, f)
    else:
        # Load a trained model and config that you have fine-tuned
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        config = BertConfig(output_config_file)
        model = BertForTokenClassification(config, num_labels=num_labels)
        model.load_state_dict(torch.load(output_model_file))

    model.to(device)

    if args.do_eval and (args.local_rank == -1
                         or torch.distributed.get_rank() == 0):
        eval_examples = processor.get_dev_examples(args.data_dir)
        eval_features = convert_examples_to_features(eval_examples, label_list,
                                                     args.max_seq_length,
                                                     tokenizer)
        logger.info("***** Running evaluation *****")
        logger.info("  Num examples = %d", len(eval_examples))
        logger.info("  Batch size = %d", args.eval_batch_size)
        all_input_ids = torch.tensor([f.input_ids for f in eval_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in eval_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in eval_features],
                                       dtype=torch.long)
        all_label_ids = torch.tensor([f.label_id for f in eval_features],
                                     dtype=torch.long)
        eval_data = TensorDataset(all_input_ids, all_input_mask,
                                  all_segment_ids, all_label_ids)
        # Run prediction for full data
        eval_sampler = SequentialSampler(eval_data)
        eval_dataloader = DataLoader(eval_data,
                                     sampler=eval_sampler,
                                     batch_size=args.eval_batch_size)
        model.eval()
        y_true = []
        y_pred = []
        for input_ids, input_mask, segment_ids, label_ids in tqdm(
                eval_dataloader, desc="Evaluating"):
            input_ids = input_ids.to(device)
            input_mask = input_mask.to(device)
            segment_ids = segment_ids.to(device)

            with torch.no_grad():
                logits = model(input_ids, segment_ids, input_mask)

            # Predicted label id per token position.
            logits = torch.argmax(F.log_softmax(logits, dim=2), dim=2)
            logits = logits.detach().cpu().numpy()
            # label_ids are only compared on the host, so they never need
            # to visit the device.
            label_ids = label_ids.numpy()
            input_mask = input_mask.to('cpu').numpy()
            for i, mask in enumerate(input_mask):
                temp_1 = []
                temp_2 = []
                for j, m in enumerate(mask):
                    if j == 0:
                        # Position 0 is skipped (presumably the [CLS] token).
                        continue
                    if m:
                        if label_map[label_ids[i][j]] != "X":
                            temp_1.append(label_map[label_ids[i][j]])
                            # NOTE(review): label_map has no key 0, so a
                            # prediction of id 0 raises KeyError here.
                            temp_2.append(label_map[logits[i][j]])
                    else:
                        # First padded position: drop the last collected
                        # token (presumably [SEP]) and stop.
                        # NOTE(review): a sequence filling max_seq_length has
                        # no padded position, so its last token is kept.
                        temp_1.pop()
                        temp_2.pop()
                        break
                y_true.append(temp_1)
                y_pred.append(temp_2)
        report = classification_report(y_true, y_pred, digits=4)
        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            logger.info("\n%s", report)
            writer.write(report)
def main():
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument("--data_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The input data dir. Should contain the .tsv files (or other data files) for the task.")
    parser.add_argument("--bert_model", default=None, type=str, required=True,
                        help="Bert pre-trained model selected in the list: bert-base-uncased, "
                             "bert-large-uncased, bert-base-cased, bert-base-multilingual, bert-base-chinese.")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument("--output_dir",
                        default=None,
                        type=str,
                        required=True,
                        help="The output directory where the model predictions and checkpoints will be written.")

    ## Other parameters
    parser.add_argument("--max_seq_length",
                        default=128,
                        type=int,
                        help="The maximum total input sequence length after WordPiece tokenization. \n"
                             "Sequences longer than this will be truncated, and sequences shorter \n"
                             "than this will be padded.")
    parser.add_argument("--do_train",
                        default=False,
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        default=False,
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument("--do_lower_case",
                        default=False,
                        action='store_true',
                        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=8,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument("--warmup_proportion",
                        default=0.1,
                        type=float,
                        help="Proportion of training to perform linear learning rate warmup for. "
                             "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        default=False,
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before performing a backward/update pass.")
    parser.add_argument('--fp16',
                        default=False,
                        action='store_true',
                        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument('--loss_scale',
                        type=float, default=0,
                        help="Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
                             "0 (default value): dynamic loss scaling.\n"
                             "Positive power of 2: static loss scaling value.\n")

    args = parser.parse_args()

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of sychronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')
    logger.info("device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".format(
        device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError("Invalid gradient_accumulation_steps parameter: {}, should be >= 1".format(
                            args.gradient_accumulation_steps))

    args.train_batch_size = int(args.train_batch_size / args.gradient_accumulation_steps)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError("At least one of `do_train` or `do_eval` must be True.")

    if os.path.exists(args.output_dir) and os.listdir(args.output_dir):
        raise ValueError("Output directory ({}) already exists and is not empty.".format(args.output_dir))
    os.makedirs(args.output_dir, exist_ok=True)

    task_name = args.task_name.lower()
    
    vecs = []
    vecs.append([0]*200) # 扩充CLS的位置,其他所有索引向后+1.
    with open("config_data/kg_embed/entity2vec.vec", 'r') as fin:
    #with open("pretrain_data/config_data/entity2vec.vec", 'r') as fin:
        for line in fin:
            vec = line.strip().split('\t')
            #vec = [float(x) for x in vec if x != ""]
            vec = [float(x) for x in vec]
            vecs.append(vec)
    print("vecs_len=%s" % str(len(vecs)))
    print("vecs_dim=%s" % str(len(vecs[0])))
    ent_embed = torch.FloatTensor(vecs)
    ent_embed = torch.nn.Embedding.from_pretrained(ent_embed)
    #ent_embed = torch.nn.Embedding(5041175, 100)

    logger.info("Shape of entity embedding: "+str(ent_embed.weight.size()))

    vecs = []
    vecs.append([0] * 4096)  # 扩充CLS的位置,其他所有索引向后+1.
    with open("config_data/kg_embed/image2vec.vec", 'r') as fin:
    #with open("pretrain_data/image_vec/image2vec.vec", 'r') as fin:
        for line in fin:
            vec = line.strip().split('\t')
            vec = [float(x) for x in vec]
            vecs.append(vec)
    print("vecs_len=%s" % str(len(vecs)))
    print("vecs_dim=%s" % str(len(vecs[0])))
    img_embed = torch.FloatTensor(vecs)
    img_embed = torch.nn.Embedding.from_pretrained(img_embed)

    logger.info("Shape of image embedding: " + str(img_embed.weight.size()))
    del vecs

    train_data = None
    num_train_steps = None
    if args.do_train:
        # TODO
        import indexed_dataset
        from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler,BatchSampler
        import iterators
        #train_data = indexed_dataset.IndexedCachedDataset(args.data_dir)
        train_data = indexed_dataset.IndexedDataset(args.data_dir, fix_lua_indexing=True)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_sampler = BatchSampler(train_sampler, args.train_batch_size, True)

        def collate_fn(x):
            x = torch.LongTensor([xx for xx in x])

            entity_idx = x[:, 4 * args.max_seq_length:5 * args.max_seq_length]
            print("entity_idx=%s" % entity_idx)
            image_idx = x[:, 6 * args.max_seq_length:7 * args.max_seq_length]
            print("image_idx=%s" % image_idx)
            # Build candidate
            ent_uniq_idx = np.unique(entity_idx.numpy())
            print("ent_uniq_idx=%s" % str(ent_uniq_idx))
            img_uniq_idx = np.unique(image_idx.numpy())
            print("img_uniq_idx=%s" % str(img_uniq_idx))
            ent_candidate = ent_embed(torch.LongTensor(ent_uniq_idx + 1))
            ent_candidate = ent_candidate.repeat([n_gpu, 1])
            img_candidate = img_embed(torch.LongTensor(img_uniq_idx + 1))
            img_candidate = img_candidate.repeat([n_gpu, 1])
            # build entity labels
            ent_idx_dict = {}
            ent_idx_list = []
            for idx, idx_value in enumerate(ent_uniq_idx):
                ent_idx_dict[idx_value] = idx
                ent_idx_list.append(idx_value)
            ent_size = len(ent_uniq_idx)-1
            # build image labels
            img_idx_dict = {}
            img_idx_list = []
            for idx, idx_value in enumerate(img_uniq_idx):
                img_idx_dict[idx_value] = idx
                img_idx_list.append(idx_value)
            img_size = len(img_uniq_idx) - 1

            def ent_map(x):
                if x == -1:
                    return -1
                else:
                    rnd = random.uniform(0, 1)
                    if rnd < 0.05:
                        return ent_idx_list[random.randint(1, ent_size)]
                    elif rnd < 0.2:
                        return -1
                    else:
                        return x

            def img_map(x):
                """Randomly corrupt one image index (denoising objective).

                ``-1`` (no image) is returned unchanged. For any other index:
                with probability 0.05 it is replaced by a random entry of
                ``img_idx_list`` (position 0 is never drawn), with probability
                0.15 it is masked to ``-1``, and otherwise (0.8) it is kept.
                """
                if x == -1:
                    return -1
                rnd = random.uniform(0, 1)
                if rnd < 0.05:
                    # Draw from the *image* pool: bounding by the entity pool
                    # size (as before) could index past img_idx_list or never
                    # reach its tail. Position 0 is skipped (np.unique sorts,
                    # so -1 padding, when present, is first). Keep x when there
                    # is no other candidate, since randint(1, 0) raises.
                    if img_size < 1:
                        return x
                    return img_idx_list[random.randint(1, img_size)]
                if rnd < 0.2:
                    return -1
                return x

            ent_labels = entity_idx.clone()
            ent_idx_dict[-1] = -1
            ent_labels = ent_labels.apply_(lambda x: ent_idx_dict[x])

            entity_idx.apply_(ent_map)
            ent_emb = ent_embed(entity_idx+1)
            ent_mask = entity_idx.clone()
            ent_mask.apply_(lambda x: 0 if x == -1 else 1)
            ent_mask[:,0] = 1

            img_labels = image_idx.clone()
            img_idx_dict[-1] = -1
            img_labels = img_labels.apply_(lambda x: img_idx_dict[x])

            image_idx.apply_(img_map)
            img_emb = img_embed(image_idx + 1)
            img_mask = image_idx.clone()
            img_mask.apply_(lambda x: 0 if x == -1 else 1)
            img_mask[:, 0] = 1

            input_ids = x[:,:args.max_seq_length]
            input_mask = x[:,args.max_seq_length:2*args.max_seq_length]
            segment_ids = x[:,2*args.max_seq_length:3*args.max_seq_length]
            masked_lm_labels = x[:,3*args.max_seq_length:4*args.max_seq_length]
            next_sentence_label = x[:,8*args.max_seq_length:]
            return input_ids, input_mask, segment_ids, masked_lm_labels, ent_emb, ent_mask, img_emb, img_mask, next_sentence_label, ent_candidate, ent_labels, img_candidate, img_labels

        train_iterator = iterators.EpochBatchIterator(train_data, collate_fn, train_sampler)
        num_train_steps = int(
            len(train_data) / args.train_batch_size / args.gradient_accumulation_steps * args.num_train_epochs)

    print ("len(train_data)=%s" % len(train_data))
    # Prepare model
    model, missing_keys = BertForPreTraining.from_pretrained(args.bert_model,
              cache_dir=PYTORCH_PRETRAINED_BERT_CACHE / 'distributed_{}'.format(args.local_rank))
    if args.fp16:
        model.half()
    model.to(device)
    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        model = DDP(model)
    elif n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Prepare optimizer
    param_optimizer = list(model.named_parameters())
    print ("param_optimizer:")
    #for param in model.named_parameters():
    #    print(param[0])

    #no_linear = ['layer.2.output.dense_ent', 'layer.2.intermediate.dense_1', 'bert.encoder.layer.2.intermediate.dense_1_ent', 'layer.2.output.LayerNorm_ent']
    #no_linear = [x.replace('2', '11') for x in no_linear]
    no_linear = ['layer.11.output.dense_entity', 'layer.11.output.LayerNorm_entity', 'layer.11.output.dense_image', 'layer.11.output.LayerNorm_entity']
    param_optimizer = [(n, p) for n, p in param_optimizer if not any(nl in n for nl in no_linear)]
    print ("param_optimizer--no_linear")
    #for param in param_optimizer:
    #    print (param[0])

    #param_optimizer = [(n, p) for n, p in param_optimizer if not any(nl in n for nl in missing_keys)]
    #no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm_ent.bias', 'LayerNorm_ent.weight']
    #no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm_ent.bias', 'LayerNorm_ent.weight']
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight', 'LayerNorm_token.bias', 'LayerNorm_token.weight', 'LayerNorm_entity.bias', 'LayerNorm_entity.weight', 'LayerNorm_image.bias', 'LayerNorm_image.weight']
    optimizer_grouped_parameters = [
        # weight decay to avoid overfitting 
        # source: https://blog.csdn.net/program_developer/article/details/80867468
        # source: https://blog.csdn.net/m0_37531129/article/details/101390592
        {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
        # the decay of bias and normalization.weight has nothing to do with weight decay
        {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
        ]
    # optimizer_grouped_parameters_display is only used to debug
#    optimizer_grouped_parameters_display = [
#        {'params': [(n,p) for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},
#        {'params': [(n,p) for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
#        ]
#    print ("optimizer_grouped_parameters_display-0:")
#    for param in optimizer_grouped_parameters_display[0]['params']:
#        print (param[0])
#
#    print ("optimizer_grouped_parameters_display-1:")
#    for param in optimizer_grouped_parameters_display[1]['params']:
#        print (param[0])

    t_total = num_train_steps
    if args.local_rank != -1:
        t_total = t_total // torch.distributed.get_world_size()
    if args.fp16:
        try:
            #from apex.optimizers import FP16_Optimizer
            from apex.fp16_utils.fp16_optimizer import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training.")

        #optimizer = FusedAdam(optimizer_grouped_parameters,
        #                      lr=args.learning_rate,
        #                      bias_correction=False,
        #                      max_grad_norm=1.0)
        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer, static_loss_scale=args.loss_scale)
        #logger.info(dir(optimizer))
        #op_path = os.path.join(args.bert_model, "pytorch_op.bin")
        #optimizer.load_state_dict(torch.load(op_path))

    else:
        optimizer = BertAdam(optimizer_grouped_parameters,
                             lr=args.learning_rate,
                             warmup=args.warmup_proportion,
                             t_total=t_total)

    global_step = 0
    if args.do_train:
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_data))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_steps)

        model.train()
        import datetime
        fout = open(os.path.join(args.output_dir, "loss.{}".format(datetime.datetime.now())), 'w')
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(tqdm(train_iterator.next_epoch_itr(), desc="Iteration")):
                print ("step=%s" % str(step))
                print ("len(batch)=%s" % str(len(batch)))
                batch = tuple(t.to(device) for t in batch)

                input_ids, input_mask, segment_ids, masked_lm_labels, input_ent, ent_mask, input_img, img_mask, next_sentence_label, ent_candidate, ent_labels, img_candidate, img_labels = batch
                print ("\ninput_ids.size=%s" % str(input_ids.size()))
                print ("input_mask.size=%s" % str(input_mask.size()))
                print ("segment_ids.size=%s" % str(segment_ids.size()))
                print ("masked_lm_labels.size=%s" % str(masked_lm_labels.size()))
                print ("input_ent.size=%s" % str(input_ent.size()))
                print ("ent_mask.size=%s" % str(ent_mask.size()))
                print ("input_img.size=%s" % str(input_img.size()))
                print ("img_mask.size=%s" % str(img_mask.size()))
                print ("next_sentence_label.size=%s" % str(next_sentence_label.size()))
                print ("ent_candidate.size=%s" % str(ent_candidate.size()))
                print ("ent_labels.size=%s" % str(ent_labels.size()))
                print ("img_candidate.size=%s" % str(img_candidate.size()))
                print ("img_labels.size=%s" % str(img_labels.size()))

                if args.fp16:
                    loss, original_loss = model(input_ids, segment_ids, input_mask, masked_lm_labels,
                                                input_ent.half(), ent_mask, input_img.half(), img_mask,
                                                next_sentence_label, ent_candidate.half(), ent_labels,
                                                img_candidate.half(), img_labels)
                else:
                    loss, original_loss = model(input_ids, segment_ids, input_mask, masked_lm_labels,
                                                input_ent, ent_mask, input_img, img_mask,
                                                next_sentence_label, ent_candidate, ent_labels,
                                                img_candidate, img_labels)


                if n_gpu > 1:
                    loss = loss.mean() # mean() to average on multi-gpu.
                    original_loss = original_loss.mean()
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                print("\nloss=%s\n" % str(loss))

                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                fout.write("{} {}\n".format(loss.item()*args.gradient_accumulation_steps, original_loss.item()))
                tr_loss += loss.item()
                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    # modify learning rate with special warm up BERT uses
                    # source: https://blog.csdn.net/m0_37531129/article/details/101390592
                    lr_this_step = args.learning_rate * warmup_linear(global_step/t_total, args.warmup_proportion)
                    for param_group in optimizer.param_groups:
                        param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1
                    #if global_step % 1000 == 0:
                    #    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
                    #    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin_{}".format(global_step))
                    #    torch.save(model_to_save.state_dict(), output_model_file)
        fout.close()

    # Save a trained model
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    torch.save(model_to_save.state_dict(), output_model_file)
# Example #30
# 0
def main():
    """Entry point: train and/or evaluate an XLM sequence classifier.

    Parses command-line arguments and merges them into the module-level
    ``configs`` object via ``combine_args``; from then on ``args`` *is*
    ``configs``. It then sets up the device / distributed backend, seeds the
    RNGs and writes ``args.json`` into ``--output_dir``.

    With ``--do_train`` the model is fine-tuned from ``--bert_model``. After
    every ``--validate_steps``-th optimizer update it is validated:
      - an F1 improvement saves the model;
      - otherwise the learning rate is halved and the early-stop countdown
        is decremented.
    Training stops once the countdown drops below zero. With
    ``--reg_explanations``, an explanation-regularisation loss computed by
    ``SamplingAndOcclusionExplain`` is back-propagated in addition to the task
    loss.

    Finally the model is always run on the test split (``args.test = True``):
    ``validate`` without ``--explain``, otherwise ``explain``.

    Raises:
        ValueError: if ``--gradient_accumulation_steps`` < 1, if neither
            ``--do_train`` nor ``--do_eval`` is set, or for an unknown task.
        ImportError: if distributed or fp16 training is requested but apex
            is not installed.
    """
    parser = argparse.ArgumentParser()

    ## Required parameters
    parser.add_argument(
        "--data_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The input data dir. Should contain the .tsv files (or other data files) for the task."
    )
    parser.add_argument(
        "--bert_model",
        default=None,
        type=str,
        required=True,
        help="Bert pre-trained model selected in the list: bert-base-uncased, "
        "bert-large-uncased, bert-base-cased, bert-large-cased, bert-base-multilingual-uncased, "
        "bert-base-multilingual-cased, bert-base-chinese,"
        "roberta-base, xlm-mlm-ende-1024")
    parser.add_argument("--task_name",
                        default=None,
                        type=str,
                        required=True,
                        help="The name of the task to train.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help=
        "The output directory where the model predictions and checkpoints will be written."
    )
    parser.add_argument("--negative_weight", default=1., type=float)
    parser.add_argument("--neutral_words_file", default='data/identity.csv')

    # if true, use test data instead of val data
    parser.add_argument("--test", action='store_true')

    # Explanation specific arguments below

    # whether run explanation algorithms
    parser.add_argument("--explain",
                        action='store_true',
                        help='if true, explain test set predictions')
    parser.add_argument("--debug", action='store_true')

    # which algorithm to run
    parser.add_argument("--algo", choices=['soc'])

    # the output filename without postfix
    parser.add_argument("--output_filename", default='temp.tmp')

    # see utils/config.py
    parser.add_argument("--use_padding_variant", action='store_true')
    parser.add_argument("--mask_outside_nb", action='store_true')
    parser.add_argument("--nb_range", type=int)
    parser.add_argument("--sample_n", type=int)

    # whether use explanation regularization
    parser.add_argument("--reg_explanations", action='store_true')
    parser.add_argument("--reg_strength", type=float)
    parser.add_argument("--reg_mse", action='store_true')

    # whether discard other neutral words during regularization. default: False
    parser.add_argument("--discard_other_nw",
                        action='store_false',
                        dest='keep_other_nw')

    # whether remove neutral words when loading datasets
    parser.add_argument("--remove_nw", action='store_true')

    # if true, generate hierarchical explanations instead of word level outputs.
    # Only useful when the --explain flag is also added.
    parser.add_argument("--hiex", action='store_true')
    parser.add_argument("--hiex_tree_height", default=5, type=int)

    # whether add the sentence itself to the sample set in SOC
    parser.add_argument("--hiex_add_itself", action='store_true')

    # the directory where the lm is stored
    parser.add_argument("--lm_dir", default='runs/lm')

    # if configured, only generate explanations for instances with given line numbers
    parser.add_argument("--hiex_idxs", default=None)
    # if true, use absolute values of explanations for hierarchical clustering
    parser.add_argument("--hiex_abs", action='store_true')

    # if either of the two is true, only generate explanations for positive / negative instances
    parser.add_argument("--only_positive", action='store_true')
    parser.add_argument("--only_negative", action='store_true')

    # stop after generating x explanation
    parser.add_argument("--stop", default=100000000, type=int)

    # early stopping with decreasing learning rate. 0: direct exit when validation F1 decreases
    parser.add_argument("--early_stop", default=5, type=int)

    # other external arguments originally here in pytorch_transformers

    parser.add_argument(
        "--cache_dir",
        default="",
        type=str,
        help=
        "Where do you want to store the pre-trained models downloaded from s3")
    parser.add_argument(
        "--max_seq_length",
        default=128,
        type=int,
        help=
        "The maximum total input sequence length after WordPiece tokenization. \n"
        "Sequences longer than this will be truncated, and sequences shorter \n"
        "than this will be padded.")
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--do_lower_case",
        action='store_true',
        help="Set this flag if you are using an uncased model.")
    parser.add_argument("--train_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for training.")
    parser.add_argument("--eval_batch_size",
                        default=32,
                        type=int,
                        help="Total batch size for eval.")
    parser.add_argument("--validate_steps",
                        default=200,
                        type=int,
                        help="validate once for how many steps")
    parser.add_argument("--learning_rate",
                        default=5e-5,
                        type=float,
                        help="The initial learning rate for Adam.")
    parser.add_argument("--num_train_epochs",
                        default=3.0,
                        type=float,
                        help="Total number of training epochs to perform.")
    parser.add_argument(
        "--warmup_proportion",
        default=0.1,
        type=float,
        help=
        "Proportion of training to perform linear learning rate warmup for. "
        "E.g., 0.1 = 10%% of training.")
    parser.add_argument("--no_cuda",
                        action='store_true',
                        help="Whether not to use CUDA when available")
    parser.add_argument("--local_rank",
                        type=int,
                        default=-1,
                        help="local_rank for distributed training on gpus")
    parser.add_argument('--seed',
                        type=int,
                        default=42,
                        help="random seed for initialization")
    parser.add_argument(
        '--gradient_accumulation_steps',
        type=int,
        default=1,
        help=
        "Number of updates steps to accumulate before performing a backward/update pass."
    )
    parser.add_argument(
        '--fp16',
        action='store_true',
        help="Whether to use 16-bit float precision instead of 32-bit")
    parser.add_argument(
        '--loss_scale',
        type=float,
        default=0,
        help=
        "Loss scaling to improve fp16 numeric stability. Only used when fp16 set to True.\n"
        "0 (default value): dynamic loss scaling.\n"
        "Positive power of 2: static loss scaling value.\n")
    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    # Merge CLI values into the module-level ``configs`` object; from here on
    # ``args`` and ``configs`` are the same object.
    combine_args(configs, args)
    args = configs

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    processors = {
        'gab': GabProcessor,
        'ws': WSProcessor,
        'nyt': NytProcessor,
        #'multi-class': multiclass_Processor,
        #'multi-label': multilabel_Processor,
    }

    output_modes = {
        'gab': 'classification',
        'ws': 'classification',
        'nyt': 'classification'
    }

    if args.local_rank == -1 or args.no_cuda:
        device = torch.device("cuda" if torch.cuda.is_available()
                              and not args.no_cuda else "cpu")
        n_gpu = torch.cuda.device_count()
    else:
        torch.cuda.set_device(args.local_rank)
        device = torch.device("cuda", args.local_rank)
        n_gpu = 1
        # Initializes the distributed backend which will take care of synchronizing nodes/GPUs
        torch.distributed.init_process_group(backend='nccl')

    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
        datefmt='%m/%d/%Y %H:%M:%S',
        level=logging.INFO if args.local_rank in [-1, 0] else logging.WARN)

    logger.info(
        "device: {} n_gpu: {}, distributed training: {}, 16-bits training: {}".
        format(device, n_gpu, bool(args.local_rank != -1), args.fp16))

    if args.gradient_accumulation_steps < 1:
        raise ValueError(
            "Invalid gradient_accumulation_steps parameter: {}, should be >= 1"
            .format(args.gradient_accumulation_steps))

    # --train_batch_size is the effective batch; each micro-batch is a fraction.
    args.train_batch_size = args.train_batch_size // args.gradient_accumulation_steps

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # save configs (closed even if serialisation fails)
    with open(os.path.join(args.output_dir, 'args.json'), 'w') as f:
        json.dump(args.__dict__, f, indent=4)

    task_name = args.task_name.lower()

    if task_name not in processors:
        raise ValueError("Task not found: %s" % (task_name))

    tokenizer = XLMTokenizer.from_pretrained(args.bert_model,
                                             do_lower_case=args.do_lower_case)
    processor = processors[task_name](configs, tokenizer=tokenizer)
    output_mode = output_modes[task_name]

    label_list = processor.get_labels()
    num_labels = len(label_list)

    train_examples = None
    num_train_optimization_steps = None
    if args.do_train:
        train_examples = processor.get_train_examples(args.data_dir)
        num_train_optimization_steps = int(
            len(train_examples) / args.train_batch_size /
            args.gradient_accumulation_steps) * args.num_train_epochs
        if args.local_rank != -1:
            num_train_optimization_steps = num_train_optimization_steps // torch.distributed.get_world_size(
            )

    # Prepare model: start from the pre-trained weights when training,
    # otherwise load the fine-tuned checkpoint previously saved in output_dir.
    cache_dir = args.cache_dir if args.cache_dir else os.path.join(
        str(PYTORCH_PRETRAINED_BERT_CACHE), 'distributed_{}'.format(
            args.local_rank))
    if args.do_train:
        model = XLMForSequenceClassification.from_pretrained(
            args.bert_model, cache_dir=cache_dir, num_labels=num_labels)
    else:
        model = XLMForSequenceClassification.from_pretrained(
            args.output_dir, num_labels=num_labels)
    model.to(device)

    if args.fp16:
        model.half()

    if args.local_rank != -1:
        try:
            from apex.parallel import DistributedDataParallel as DDP
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        model = DDP(model)

    # Prepare optimizer: biases and LayerNorm parameters are excluded from
    # weight decay.
    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params':
        [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)],
        'weight_decay':
        0.01
    }, {
        'params':
        [p for n, p in param_optimizer if any(nd in n for nd in no_decay)],
        'weight_decay':
        0.0
    }]
    if args.fp16:
        try:
            from apex.optimizers import FP16_Optimizer
            from apex.optimizers import FusedAdam
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use distributed and fp16 training."
            )

        optimizer = FusedAdam(optimizer_grouped_parameters,
                              lr=args.learning_rate,
                              bias_correction=False,
                              max_grad_norm=1.0)
        if args.loss_scale == 0:
            optimizer = FP16_Optimizer(optimizer, dynamic_loss_scale=True)
        else:
            optimizer = FP16_Optimizer(optimizer,
                                       static_loss_scale=args.loss_scale)
        # FusedAdam has no built-in schedule; it is applied manually per step.
        warmup_linear = WarmupLinearSchedule(
            warmup=args.warmup_proportion,
            t_total=num_train_optimization_steps)
    else:
        if args.do_train:
            # BertAdam applies the warmup-linear schedule internally.
            optimizer = BertAdam(optimizer_grouped_parameters,
                                 lr=args.learning_rate,
                                 warmup=args.warmup_proportion,
                                 t_total=num_train_optimization_steps)

    global_step = 0
    nb_tr_steps = 0
    tr_loss, tr_reg_loss = 0, 0
    tr_reg_cnt = 0
    epoch = -1
    val_best_f1 = -1
    val_best_loss = 1e10
    early_stop_countdown = args.early_stop
    # Cumulative factor from validation-driven LR halvings. The fp16 path
    # rewrites param_group['lr'] every step, so without this factor the
    # halvings would be lost there.
    lr_decay = 1.0

    if args.reg_explanations:
        train_lm_dataloader = processor.get_dataloader('train',
                                                       configs.train_batch_size)
        dev_lm_dataloader = processor.get_dataloader('dev',
                                                     configs.train_batch_size)
        explainer = SamplingAndOcclusionExplain(
            model,
            configs,
            tokenizer,
            device=device,
            vocab=tokenizer.vocab,
            train_dataloader=train_lm_dataloader,
            dev_dataloader=dev_lm_dataloader,
            lm_dir=args.lm_dir,
            output_path=os.path.join(configs.output_dir,
                                     configs.output_filename),
        )
    else:
        explainer = None

    if args.do_train:
        epoch = 0
        train_features = convert_examples_to_features(train_examples,
                                                      label_list,
                                                      args.max_seq_length,
                                                      tokenizer, output_mode,
                                                      configs)
        logger.info("***** Running training *****")
        logger.info("  Num examples = %d", len(train_examples))
        logger.info("  Batch size = %d", args.train_batch_size)
        logger.info("  Num steps = %d", num_train_optimization_steps)
        all_input_ids = torch.tensor([f.input_ids for f in train_features],
                                     dtype=torch.long)
        all_input_mask = torch.tensor([f.input_mask for f in train_features],
                                      dtype=torch.long)
        all_segment_ids = torch.tensor([f.segment_ids for f in train_features],
                                       dtype=torch.long)

        if output_mode == "classification":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.long)
        elif output_mode == "regression":
            all_label_ids = torch.tensor([f.label_id for f in train_features],
                                         dtype=torch.float)

        train_data = TensorDataset(all_input_ids, all_input_mask,
                                   all_segment_ids, all_label_ids)
        if args.local_rank == -1:
            train_sampler = RandomSampler(train_data)
        else:
            train_sampler = DistributedSampler(train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)

        # Class 0 is weighted by --negative_weight, class 1 by 1 (two labels
        # assumed here).
        class_weight = torch.FloatTensor([args.negative_weight, 1]).to(device)
        # The loss function does not change between steps; build it once.
        if output_mode == "classification":
            loss_fct = CrossEntropyLoss(class_weight)
        elif output_mode == "regression":
            loss_fct = MSELoss()

        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_examples, nb_tr_steps = 0, 0
            for step, batch in enumerate(
                    tqdm(train_dataloader, desc="Iteration")):
                batch = tuple(t.to(device) for t in batch)
                input_ids, input_mask, segment_ids, label_ids = batch

                # segment_ids are not fed to the model; input_mask is passed as
                # the second positional argument.
                outputs = model(input_ids, input_mask, labels=None)
                logits = outputs[0]

                if output_mode == "classification":
                    loss = loss_fct(logits.view(-1, num_labels),
                                    label_ids.view(-1))
                elif output_mode == "regression":
                    loss = loss_fct(logits.view(-1), label_ids.view(-1))

                if n_gpu > 1:
                    loss = loss.mean()  # mean() to average on multi-gpu.
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps
                tr_loss += loss.item()
                if args.fp16:
                    optimizer.backward(loss)
                else:
                    loss.backward()

                # regularize explanations
                # NOTE: backward performed inside this function to prevent OOM
                if args.reg_explanations:
                    reg_loss, reg_cnt = explainer.compute_explanation_loss(
                        input_ids, input_mask, label_ids, do_backprop=True)
                    tr_reg_loss += reg_loss  # float
                    tr_reg_cnt += reg_cnt

                nb_tr_examples += input_ids.size(0)
                nb_tr_steps += 1
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    if args.fp16:
                        # modify learning rate with special warm up BERT uses
                        # if args.fp16 is False, BertAdam is used that handles this automatically.
                        # WarmupLinearSchedule.get_lr expects the raw step and
                        # divides by t_total itself (pytorch_pretrained_bert
                        # 0.6.x API -- confirm against the installed version),
                        # so global_step is passed undivided.
                        lr_this_step = args.learning_rate * lr_decay * warmup_linear.get_lr(
                            global_step, args.warmup_proportion)
                        for param_group in optimizer.param_groups:
                            param_group['lr'] = lr_this_step
                    optimizer.step()
                    optimizer.zero_grad()
                    global_step += 1

                    # Validate only right after an optimizer update. With
                    # gradient accumulation, a given global_step is then
                    # validated exactly once, never at global_step 0.
                    if global_step % args.validate_steps == 0:
                        val_result = validate(args, model, processor, tokenizer,
                                              output_mode, label_list, device,
                                              num_labels, task_name, tr_loss,
                                              global_step, epoch, explainer)
                        val_acc, val_f1 = val_result['acc'], val_result['f1']
                        if val_f1 > val_best_f1:
                            val_best_f1 = val_f1
                            if args.local_rank == -1 or torch.distributed.get_rank(
                            ) == 0:
                                save_model(args, model, tokenizer, num_labels)
                        else:
                            # halve the learning rate (lr_decay keeps the
                            # halving in effect on the fp16 path too)
                            lr_decay *= 0.5
                            for param_group in optimizer.param_groups:
                                param_group['lr'] *= 0.5
                            early_stop_countdown -= 1
                            logger.info(
                                "Reducing learning rate... Early stop countdown %d"
                                % early_stop_countdown)
                        if early_stop_countdown < 0:
                            break
            if early_stop_countdown < 0:
                break
            epoch += 1

            # training finish ############################

    # Final run on the test split: args.test switches validate / explain to
    # test data (see the --test flag) and is reset afterwards.
    if not args.explain:
        args.test = True
        print('--Test_args.test: %s' % str(args.test))  #Test_args.test: True
        validate(args,
                 model,
                 processor,
                 tokenizer,
                 output_mode,
                 label_list,
                 device,
                 num_labels,
                 task_name,
                 tr_loss,
                 # fixed sentinel instead of the real step -- NOTE(review):
                 # presumably tags the test-run outputs inside validate(); confirm
                 global_step=888,
                 epoch=-1,
                 explainer=explainer)
        args.test = False
    else:
        args.test = True
        print('--Test_args.test: %s' % str(args.test))  # Test_args.test: True
        explain(args, model, processor, tokenizer, output_mode, label_list,
                device)
        args.test = False