Example #1
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

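    # Scale the per-step batch size by the number of GPUs visible to DataParallel.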
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

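    # t_total = total number of optimizer updates, accounting for gradient accumulation.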
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
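    # Biases and LayerNorm weights are conventionally excluded from weight decay.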
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in model.named_parameters()
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': args.weight_decay
        },
        {
            'params': [p for n, p in model.named_parameters()
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(args)  # Added here for reproducibility (even between python 2 and 3)
    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {
                'input_ids': batch[0],
                'attention_mask': batch[1],
                # XLM doesn't use segment_ids
                'token_type_ids': batch[2] if args.model_type in ['bert', 'xlnet'] else None,
                'labels': batch[3]
            }
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always a tuple in pytorch-transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

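            # With fp16, apex amp scales the loss before backward to avoid gradient underflow,
            # and gradient clipping is applied to the fp32 master parameters.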
            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args.local_rank in [-1, 0] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [-1, 0] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #2
def train(args, train_dataset, model, tokenizer):
    """ Train the model. """
    tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)], 'weight_decay': args.weight_decay},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]
    optimizer = AdamW(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args.warmup_steps, t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError('Please install apex from https://www.github.com/nvidia/apex to use fp16 training.')
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Train!
    logger.info('***** Running training *****')
    logger.info('   Num examples = %d', len(train_dataset))
    logger.info('   Num Epochs = %d', args.num_train_epochs)
    logger.info('   Instantaneous batch size per GPU = %d', args.per_gpu_train_batch_size)
    logger.info('   Total train batch size (w. parallel & accumulation) = %d',
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info('   Gradient Accumulation steps = %d', args.gradient_accumulation_steps)
    logger.info('   Total optimization steps = %d', t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs), desc='Epoch')
    set_seed(args)  # Added here for reproducibility

    max_val_acc = 0
    max_val_f1 = 0

    for _ in train_iterator:
        for step, batch in enumerate(train_dataloader):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
            inputs = {'input_ids':              batch[0],
                      'attention_mask':         batch[1],
                      'token_type_ids':         batch[2],
                      'labels':                 batch[3],
                      'ct_clf_input_ids':       batch[4],
                      'ct_clf_attention_mask':  batch[5],
                      'ct_clf_token_type_ids':  batch[6],
                      'categories':             batch[7],
                      'hand_features':          batch[8]}
            outputs = model(**inputs)
            loss, clf_loss = outputs[0][0], outputs[1][0]  # model outputs are always tuple in pytorch_transformers (see doc)

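            # Jointly optimize the main task loss and the auxiliary classifier loss.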
            total_loss = loss + clf_loss
            if args.n_gpu > 1:
                total_loss = total_loss.mean()
            if args.gradient_accumulation_steps > 1:
                total_loss = total_loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(total_loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                total_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)

            tr_loss += total_loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()
                model.zero_grad()
                global_step += 1

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.evaluate_during_training:
                        result = evaluate(args, model, tokenizer)
                        for key, value in result.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                        if result['acc'] > max_val_acc:
                            max_val_acc = result['acc']
                        if result['f1'] > max_val_f1:
                            max_val_f1 = result['f1']
                            output_dir = os.path.join(args.output_dir, 'best_checkpoint')
                            if not os.path.exists(output_dir):
                                os.makedirs(output_dir)
                            model_to_save = model.module if hasattr(model, 'module') else model
                            model_to_save.save_pretrained(output_dir)
                            torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                            logger.info('Saving model checkpoint with f1 {:.4f}'.format(max_val_f1))
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss-logging_loss)/args.logging_steps, global_step)
                    logging_loss = tr_loss

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    tb_writer.close()
    return global_step, tr_loss / global_step
Example #3
def main():

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name',
                        type=str,
                        default='gpt2-medium',
                        help='pretrained model name')
    parser.add_argument("--do_train",
                        action='store_true',
                        default=True,
                        help="Whether to run training.")
    parser.add_argument(
        "--output_dir",
        default='fintuned_gpt',
        type=str,
        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument('--dataset', type=str, default='', required=True)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_train_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=8)
    parser.add_argument('--num_prior', type=int, default=2)
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument("--max_steps",
                        default=-1,
                        type=int,
                        help="If > 0: set total number of training \
                        steps to perform. Override num_train_epochs.")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before\
                        performing a backward/update pass.")
    parser.add_argument('--learning_rate', type=float, default=6.25e-5)
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--lm_coef', type=float, default=0.9)

    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()
    print(args)

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {}, n_gpu {}".format(device, n_gpu))

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Load tokenizer and model
    # This loading function also adds new tokens and embeddings called `special tokens`
    # These new embeddings will be fine-tuned on the RocStories dataset.
    # start_token, delimiter_token, clf_token

    special_tokens = ['<|endoftext|>', '<|endoftext|>', '<|cls|>']
    tokenizer = GPT2Tokenizer.from_pretrained(args.model_name,
                                              unk_token='<|endoftext|>',
                                              bos_token='<|endoftext|>',
                                              eos_token='<|endoftext|>',
                                              cls_token='<|cls|>')
    tokenizer.add_tokens(['<|cls|>'])
    special_tokens_ids = list(
        tokenizer.convert_tokens_to_ids(token) for token in special_tokens)
    model = GPT2DoubleHeadsModel.from_pretrained(args.model_name)
    model.resize_token_embeddings(new_num_tokens=int(len(tokenizer)))

    model.to(device)

    def tokenize_and_encode(obj):
        """ Tokenize and encode a nested object """
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        elif isinstance(obj, int):
            return obj
        return list(tokenize_and_encode(o) for o in obj)

    logger.info("Encoding dataset...")

    train_dataset = load_dataset(tokenizer,
                                 args.dataset,
                                 num_prior=args.num_prior)
    eval_dataset = load_dataset(tokenizer,
                                args.dataset,
                                num_prior=args.num_prior)

    datasets = (train_dataset, eval_dataset)
    encoded_datasets = tokenize_and_encode(datasets)

    # Compute the max input length for the Transformer
    max_length = model.config.n_positions // 2 - 2
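    # Longest example = story + the longer of the two continuations + 3 special tokens.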
    input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3
                       for dataset in encoded_datasets for story, cont1, cont2, _ in dataset)
    input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model

    # Prepare inputs tensors and dataloaders
    tensor_datasets = pre_process_datasets(encoded_datasets, input_length,
                                           max_length, *special_tokens_ids)
    train_tensor_dataset, eval_tensor_dataset = tensor_datasets[0], tensor_datasets[1]

    train_data = TensorDataset(*train_tensor_dataset)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    eval_data = TensorDataset(*eval_tensor_dataset)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # Prepare optimizer

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps //\
            (len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(train_dataloader)\
            // args.gradient_accumulation_steps * args.num_train_epochs

    param_optimizer = list(model.named_parameters())
    no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {
            'params': [p for n, p in param_optimizer
                       if not any(nd in n for nd in no_decay)],
            'weight_decay': 0.01
        },
        {
            'params': [p for n, p in param_optimizer
                       if any(nd in n for nd in no_decay)],
            'weight_decay': 0.0
        },
    ]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
    model.train()
    for i, _ in enumerate(range(int(args.num_train_epochs))):
        print('Starting Epoch: {} of {}'.format(
            str(i + 1), str(int(args.num_train_epochs))))
        tr_loss = 0
        nb_tr_steps = 0
        tqdm_bar = tqdm(train_dataloader, desc="Training")
        for step, batch in enumerate(tqdm_bar):
            batch = tuple(t.to(device) for t in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels = batch
            losses = model(input_ids, mc_token_ids, lm_labels, mc_labels)
            loss = args.lm_coef * losses[0] + losses[1]
            loss.backward()
            optimizer.step()
            scheduler.step()
            optimizer.zero_grad()
            tr_loss += loss.item()
            exp_average_loss = loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
            nb_tr_steps += 1
            tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(
                exp_average_loss,
                scheduler.get_lr()[0])


    # Save a trained model, configuration and tokenizer
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself

    # If we save using the predefined names, we can load using `from_pretrained`
    output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
    output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

    torch.save(model_to_save.state_dict(), output_model_file)
    model_to_save.config.to_json_file(output_config_file)
    tokenizer.save_vocabulary(args.output_dir)

    # Load a trained model and vocabulary that you have fine-tuned
    model = GPT2DoubleHeadsModel.from_pretrained(args.output_dir)
    tokenizer = GPT2Tokenizer.from_pretrained(args.output_dir)
    model.to(device)
Example #4
def main():
    class DictAttr(dict):
        def __getattr__(self, key):
            if key not in self:
                raise AttributeError(key)
            return self[key]

        def __setattr__(self, key, value):
            self[key] = value

        def __delattr__(self, key):
            del self[key]

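    # NOTE: these DictAttr presets are overwritten once parse_args() runs below.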
    args = DictAttr()
    args.model_name = 'openai-gpt'
    args.train_dataset = "data_in/ROCStories/cloze_test_val__spring2016 - cloze_test_ALL_val.csv"
    args.eval_dataset = "data_in/ROCStories/cloze_test_test__spring2016 - cloze_test_ALL_test.csv"
    args.train_batch_size = 8

    parser = argparse.ArgumentParser()
    parser.add_argument('--model_name', type=str, default='openai-gpt',
                        help='pretrained model name')
    parser.add_argument("--do_train",
                        action='store_true',
                        help="Whether to run training.")
    parser.add_argument("--do_eval",
                        action='store_true',
                        help="Whether to run eval on the dev set.")
    parser.add_argument(
        "--output_dir",
        default=None,
        type=str,
        required=True,
        help="The output directory where the model predictions and checkpoints will be written.")
    parser.add_argument('--train_dataset', type=str, default='')
    parser.add_argument('--eval_dataset', type=str, default='')
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--num_train_epochs', type=int, default=3)
    parser.add_argument('--train_batch_size', type=int, default=8)
    parser.add_argument('--eval_batch_size', type=int, default=16)
    parser.add_argument("--adam_epsilon",
                        default=1e-8,
                        type=float,
                        help="Epsilon for Adam optimizer.")
    parser.add_argument('--max_grad_norm', type=int, default=1)
    parser.add_argument("--max_steps",
                        default=-1,
                        type=int,
                        help="If > 0: set total number of training \
                        steps to perform. Override num_train_epochs.")
    parser.add_argument('--gradient_accumulation_steps',
                        type=int,
                        default=1,
                        help="Number of updates steps to accumulate before\
                        performing a backward/update pass.")
    parser.add_argument('--learning_rate', type=float, default=6.25e-5)
    parser.add_argument("--warmup_steps",
                        default=0,
                        type=int,
                        help="Linear warmup over warmup_steps.")
    parser.add_argument('--lr_schedule', type=str, default='warmup_linear')
    parser.add_argument('--weight_decay', type=float, default=0.01)
    parser.add_argument('--lm_coef', type=float, default=0.9)
    parser.add_argument('--n_valid', type=int, default=374)

    parser.add_argument('--server_ip',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    parser.add_argument('--server_port',
                        type=str,
                        default='',
                        help="Can be used for distant debugging.")
    args = parser.parse_args()

    if args.server_ip and args.server_port:
        # Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
        import ptvsd
        print("Waiting for debugger attach")
        ptvsd.enable_attach(address=(args.server_ip, args.server_port),
                            redirect_output=True)
        ptvsd.wait_for_attach()

    random.seed(args.seed)
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    n_gpu = torch.cuda.device_count()
    logger.info("device: {}, n_gpu {}".format(device, n_gpu))

    if not args.do_train and not args.do_eval:
        raise ValueError(
            "At least one of `do_train` or `do_eval` must be True.")

    if not os.path.exists(args.output_dir):
        os.makedirs(args.output_dir)

    # Load tokenizer and model
    # This loading function also adds new tokens and embeddings called `special tokens`
    # These new embeddings will be fine-tuned on the RocStories dataset
    special_tokens = ['_start_', '_delimiter_', '_classify_']
    tokenizer = OpenAIGPTTokenizer.from_pretrained(
        args.model_name, special_tokens=special_tokens)
    special_tokens_ids = list(
        tokenizer.convert_tokens_to_ids(token) for token in special_tokens)
    model = OpenAIGPTDoubleHeadsModel.from_pretrained(
        args.model_name, num_special_tokens=len(special_tokens))
    model.to(device)

    # Load and encode the datasets
    if not args.train_dataset and not args.eval_dataset:
        roc_stories = cached_path(ROCSTORIES_URL)

    def tokenize_and_encode(obj):
        """ Tokenize and encode a nested object """
        if isinstance(obj, str):
            return tokenizer.convert_tokens_to_ids(tokenizer.tokenize(obj))
        elif isinstance(obj, int):
            return obj
        return list(tokenize_and_encode(o) for o in obj)

    logger.info("Encoding dataset...")
    train_dataset = load_rocstories_dataset(args.train_dataset)
    #("Rick grew up in a troubled household. He never found good support in family, and turned to gangs. It wasn't long before Rick got shot in a robbery. The incident caused him to turn a new leaf.", 'He is happy now.', 'He joined a gang.', 0)

    eval_dataset = load_rocstories_dataset(args.eval_dataset)
    datasets = (train_dataset, eval_dataset)
    encoded_datasets = tokenize_and_encode(datasets)

    # Compute the max input length for the Transformer
    max_length = model.config.n_positions // 2 - 2
    input_length = max(len(story[:max_length]) + max(len(cont1[:max_length]), len(cont2[:max_length])) + 3
                       for dataset in encoded_datasets for story, cont1, cont2, _ in dataset)
    input_length = min(input_length, model.config.n_positions)  # Max size of input for the pre-trained model

    # Prepare inputs tensors and dataloaders
    tensor_datasets = pre_process_datasets(encoded_datasets, input_length,
                                           max_length, *special_tokens_ids)
    train_tensor_dataset, eval_tensor_dataset = tensor_datasets[0], tensor_datasets[1]

    train_data = TensorDataset(*train_tensor_dataset)
    train_sampler = RandomSampler(train_data)
    train_dataloader = DataLoader(train_data,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    eval_data = TensorDataset(*eval_tensor_dataset)
    eval_sampler = SequentialSampler(eval_data)
    eval_dataloader = DataLoader(eval_data,
                                 sampler=eval_sampler,
                                 batch_size=args.eval_batch_size)

    # Prepare optimizer
    if args.do_train:
        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps //\
                (len(train_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(train_dataloader)\
                // args.gradient_accumulation_steps * args.num_train_epochs

        param_optimizer = list(model.named_parameters())
        no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in param_optimizer
                           if not any(nd in n for nd in no_decay)],
                'weight_decay': 0.01
            },
            {
                'params': [p for n, p in param_optimizer
                           if any(nd in n for nd in no_decay)],
                'weight_decay': 0.0
            },
        ]
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=args.warmup_steps,
                                         t_total=t_total)

    if args.do_train:
        nb_tr_steps, tr_loss, exp_average_loss = 0, 0, None
        model.train()
        for _ in trange(int(args.num_train_epochs), desc="Epoch"):
            tr_loss = 0
            nb_tr_steps = 0
            tqdm_bar = tqdm(train_dataloader, desc="Training")
            for step, batch in enumerate(tqdm_bar):
                batch = tuple(t.to(device) for t in batch)
                input_ids, mc_token_ids, lm_labels, mc_labels = batch
                losses = model(input_ids, mc_token_ids, lm_labels, mc_labels)
                loss = args.lm_coef * losses[0] + losses[1]
                loss.backward()
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()
                tr_loss += loss.item()
                exp_average_loss = loss.item() if exp_average_loss is None else 0.7 * exp_average_loss + 0.3 * loss.item()
                nb_tr_steps += 1
                tqdm_bar.desc = "Training loss: {:.2e} lr: {:.2e}".format(
                    exp_average_loss,
                    scheduler.get_lr()[0])

    # Save a trained model
    if args.do_train:
        # Save a trained model, configuration and tokenizer
        model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model itself

        # If we save using the predefined names, we can load using `from_pretrained`
        output_model_file = os.path.join(args.output_dir, WEIGHTS_NAME)
        output_config_file = os.path.join(args.output_dir, CONFIG_NAME)

        torch.save(model_to_save.state_dict(), output_model_file)
        model_to_save.config.to_json_file(output_config_file)
        tokenizer.save_vocabulary(args.output_dir)

        # Load a trained model and vocabulary that you have fine-tuned
        model = OpenAIGPTDoubleHeadsModel.from_pretrained(args.output_dir)
        tokenizer = OpenAIGPTTokenizer.from_pretrained(args.output_dir)
        model.to(device)

    if args.do_eval:
        model.eval()
        eval_loss, eval_accuracy = 0, 0
        nb_eval_steps, nb_eval_examples = 0, 0
        for batch in tqdm(eval_dataloader, desc="Evaluating"):
            batch = tuple(t.to(device) for t in batch)
            input_ids, mc_token_ids, lm_labels, mc_labels = batch
            with torch.no_grad():
                _, mc_loss, _, mc_logits = model(input_ids, mc_token_ids,
                                                 lm_labels, mc_labels)

            mc_logits = mc_logits.detach().cpu().numpy()
            mc_labels = mc_labels.to('cpu').numpy()
            tmp_eval_accuracy = accuracy(mc_logits, mc_labels)

            eval_loss += mc_loss.mean().item()
            eval_accuracy += tmp_eval_accuracy

            nb_eval_examples += input_ids.size(0)
            nb_eval_steps += 1

        eval_loss = eval_loss / nb_eval_steps
        eval_accuracy = eval_accuracy / nb_eval_examples
        train_loss = tr_loss / nb_tr_steps if args.do_train else None
        result = {
            'eval_loss': eval_loss,
            'eval_accuracy': eval_accuracy,
            'train_loss': train_loss
        }

        output_eval_file = os.path.join(args.output_dir, "eval_results.txt")
        with open(output_eval_file, "w") as writer:
            logger.info("***** Eval results *****")
            for key in sorted(result.keys()):
                logger.info("  %s = %s", key, str(result[key]))
                writer.write("%s = %s\n" % (key, str(result[key])))
Example #5
class Trainer(object):
    def __init__(self, cfg):
        self.cfg = cfg
        if self.cfg.feature:
            net = FeatureModel(self.cfg)
        else:
            net = BasicModel(self.cfg)
        # print(tuple(self.cfg.adam_betas))
        print(net)

        if self.cfg.cuda:
            net = net.cuda()
            if self.cfg.parallel and torch.cuda.device_count() > 1:
                print("Let's use", torch.cuda.device_count(), "GPUs!")
                net = nn.DataParallel(net)
        self.net = net
        self.start_epoch = 0
        if self.cfg.pretrained is not None:
            self.load_pretrained_net(pretrained=self.cfg.pretrained)

        if self.cfg.task == 'commonsense_qa':
            self.index2label = {0: 'A', 1: 'B', 2: 'C', 3: 'D', 4: 'E'}
        else:
            self.index2label = {0: '1', 1: '2'}

        self.best = 1. / len(self.index2label)

        if self.cfg.task == 'winograde':
            self.cfg.task += '_' + self.cfg.train_size

    def train(self, train_db, val_db):
        # if self.cfg.cuda and self.cfg.parallel:
        #     net = self.net.module
        # else:
        #     net = self.net
        # net = self.net
        # print(net)
        if self.cfg.tensorboard_logdir is not None:
            summary_writer = SummaryWriter(self.cfg.tensorboard_logdir)
        else:
            summary_writer = SummaryWriter(
                osp.join(self.cfg.log_dir, self.cfg.task, 'tensorboard',
                         self.cfg.model_name))

        # log_per_steps = self.cfg.accumulation_steps * self.cfg.log_per_steps

        log_dir = osp.join(self.cfg.log_dir, self.cfg.task,
                           self.cfg.model_name)
        if not osp.exists(log_dir):
            os.makedirs(log_dir)

        code_dir = osp.join(log_dir, 'code')
        if not osp.exists(code_dir):
            os.makedirs(code_dir)

        shutil.copy('./train.py', osp.join(code_dir, 'train.py'))
        shutil.copy('./commonsense_dataset.py',
                    osp.join(code_dir, 'commonsense_dataset.py'))

        logz.configure_output_dir(log_dir)
        logz.save_config(self.cfg)

        train_loader = DataLoader(train_db,
                                  batch_size=self.cfg.batch_size,
                                  shuffle=True,
                                  num_workers=self.cfg.num_workers)

        # self.optimizer = BertAdam(net.parameters(), lr=cfg.lr, warmup=cfg.warmup)
        # self.scheduler = optim.lr_self.scheduler.StepLR(self.optimizer, step_size=3, gamma=0.8)

        num_train_steps = int(
            len(train_loader) / self.cfg.accumulation_steps * self.cfg.epochs)
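        # Warm up the learning rate over a cfg.warmup fraction of the total optimization steps.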
        num_warmup_steps = int(num_train_steps * self.cfg.warmup)

        no_decay = ['bias', 'LayerNorm.weight']
        not_optim = []

        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in self.net.named_parameters()
                           if not any(nd in n for nd in no_decay)
                           and not any(nd in n for nd in not_optim)],
                'weight_decay': self.cfg.weight_decay
            },
            {
                'params': [p for n, p in self.net.named_parameters()
                           if any(nd in n for nd in no_decay)
                           and not any(nd in n for nd in not_optim)],
                'weight_decay': 0.0
            },
        ]

        if self.cfg.fix_emb:
            for p in self.net.embedding.embeddings.parameters():
                p.requires_grad = False

        if self.cfg.ft_last_layer:
            for p in self.net.embedding.embeddings.parameters():
                p.requires_grad = False
            for i in range(10):
                for p in self.net.embedding.encoder.layer[i].parameters():
                    p.requires_grad = False

        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=self.cfg.lr,
                               eps=self.cfg.adam_eps,
                               betas=eval(self.cfg.adam_betas))
        # self.optimizer = AdamW(self.net.parameters(), lr=self.cfg.lr, eps=1e-8)

        self.scheduler = WarmupLinearSchedule(self.optimizer,
                                              warmup_steps=num_warmup_steps,
                                              t_total=num_train_steps)
        loss_func = nn.CrossEntropyLoss()

        if self.cfg.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            self.net, self.optimizer = amp.initialize(
                self.net, self.optimizer, opt_level=self.cfg.fp16_opt_level)
        # self.scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=num_warmup_steps, num_training_steps=num_train_steps)

        # self.optimizer.set_self.scheduler(self.scheduler)

        torch.cuda.synchronize()
        self.start = time.time()
        self.net.zero_grad()
        self.batch_loss, self.batch_acc = [], []
        self.global_step = 0
        for epoch in range(self.start_epoch, self.cfg.epochs):

            print('Training...')
            torch.cuda.empty_cache()
            self.batch_loss, self.batch_acc = [], []
            for cnt, batch in tqdm(enumerate(train_loader)):
                self.net.train()

                input_ids, input_mask, segment_ids, features, fea_mask, labels, input_indexs = self.batch_data(
                    batch)
                batch_input = (input_ids, input_mask, segment_ids, features,
                               fea_mask)
                # self.net.zero_grad()
                logits, probabilities, preds = self.net(
                    input_ids, input_mask, segment_ids, features, fea_mask)
                loss = loss_func(logits, labels).mean()
                # print(probabilities)

                # one_hot_labels = nn.functional.one_hot(labels, num_classes = Number_class[self.cfg.task.lower()]).float()
                # per_example_loss = -torch.sum(one_hot_labels * log_probs, dim=-1)
                # loss = torch.mean(per_example_loss)

                if self.cfg.accumulation_steps > 1:
                    loss = loss / self.cfg.accumulation_steps

                if self.cfg.fp16:
                    with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                        scaled_loss.backward()
                    if self.cfg.max_grad_norm > 0.0:
                        torch.nn.utils.clip_grad_norm_(
                            amp.master_params(self.optimizer),
                            self.cfg.max_grad_norm)
                else:
                    loss.backward()
                    if self.cfg.max_grad_norm > 0.0:
                        nn.utils.clip_grad_norm_(self.net.parameters(),
                                                 self.cfg.max_grad_norm)

                acc, _, _, _ = self.evaluate(preds, labels, input_indexs)

                self.batch_loss.append(loss.cpu().data.item() / len(input_ids))
                self.batch_acc.append(acc)

                if self.global_step == 0 and cnt == 0:
                    _ = self.update_log(summary_writer, epoch, val_db)

                if ((cnt + 1) % self.cfg.accumulation_steps) == 0:
                    # print(nn.utils.clip_grad_norm_(self.net.parameters(), max_norm=1e5))
                    self.optimizer.step()
                    self.scheduler.step()
                    self.net.zero_grad()
                    self.global_step += 1

                    if self.global_step % self.cfg.log_per_steps == 0:
                        val_acc = self.update_log(summary_writer, epoch,
                                                  val_db)
                        self.batch_loss, self.batch_acc = [], []

                        if self.cfg.save_ckpt:
                            if epoch >= (self.cfg.epochs / 4):
                                if self.best < val_acc:
                                    print('Saving checkpoint...')
                                    self.save_checkpoint(epoch, acc=val_acc)
                                    self.best = val_acc

            ##################################################################
            ## Checkpoint
            ##################################################################
            if len(self.batch_loss) > 0:
                val_acc = self.update_log(summary_writer, epoch, val_db)
                self.best = max(self.best, val_acc)
                self.batch_loss, self.batch_acc = [], []

            # val_wrong_qa = []
            # for q, a in zip(val_wrong, val_wrong_answer):
            #     val_wrong_qa.append([val_db.index2qid[q], trainer.index2label[a]])
            # epoch_wrong = {epoch: val_wrong_qa}
            if self.cfg.save_ckpt:
                if epoch >= (self.cfg.epochs / 4):
                    print('Saving checkpoint...')
                    self.save_checkpoint(epoch, True, acc=val_acc)
            torch.cuda.empty_cache()

        summary_writer.close()

    def update_log(self, summary_writer, epoch, val_db, inds=None):
        # print('Epoch %03d, iter %07d:'%(epoch, cnt))
        # print('loss: %05f, acc: %05f'%(np.mean(self.batch_loss), np.mean(self.batch_acc)))
        # # print(self.scheduler.get_lr()[0])
        # print('-------------------------')
        summary_writer.add_scalar('train_loss', np.mean(self.batch_loss),
                                  self.global_step)
        summary_writer.add_scalar('train_acc', np.mean(self.batch_acc),
                                  self.global_step)

        val_loss, val_acc, val_wrong, val_wrong_answer, eqs_ = self.validate(
            val_db)
        summary_writer.add_scalar('val_loss', np.mean(val_loss),
                                  self.global_step)
        summary_writer.add_scalar('val_acc', val_acc, self.global_step)
        summary_writer.add_scalar('lr',
                                  self.scheduler.get_lr()[0], self.global_step)

        # update optim self.scheduler
        torch.cuda.synchronize()
        logz.log_tabular("Time", time.time() - self.start)
        logz.log_tabular("Iteration", epoch)
        logz.log_tabular("TrainAverageLoss", np.mean(self.batch_loss))
        logz.log_tabular("TrainAverageAccu", np.mean(self.batch_acc))
        logz.log_tabular("ValAverageLoss", np.mean(val_loss))
        logz.log_tabular("ValAverageAccu", val_acc)

        if inds is not None:
            val_cnt = len(eqs_)
            eqs = [eqs_[i] for i in inds]
            eq0 = np.array(eqs[:int(val_cnt / 2)])
            eq1 = np.array(eqs[int(val_cnt / 2):])
            logz.log_tabular("ValAverageAccu0", eq0.sum() / len(eq0))
            logz.log_tabular("ValAverageAccu1", eq1.sum() / len(eq1))

        logz.dump_tabular()

        return val_acc

    def validate(self, val_db):
        ##################################################################
        ## Validation
        ##################################################################
        # if self.cfg.cuda and self.cfg.parallel:
        #     net = self.net.module
        # else:
        #     net = self.net
        # net = self.net
        print('Validation...')
        torch.cuda.empty_cache()
        self.net.eval()

        loss_func = nn.CrossEntropyLoss()

        val_loader = DataLoader(val_db,
                                batch_size=self.cfg.batch_size,
                                shuffle=False,
                                num_workers=self.cfg.num_workers)

        val_loss, preds_, labels_, input_indexs_ = [], [], [], []
        for _, batch in tqdm(enumerate(val_loader)):
            input_ids, input_mask, segment_ids, features, fea_mask, labels, input_indexs = self.batch_data(
                batch)
            batch_input = (input_ids, input_mask, segment_ids, features,
                           fea_mask)

            with torch.no_grad():
                logits, probabilities, preds = self.net(
                    input_ids, input_mask, segment_ids, features, fea_mask)

                preds_.extend(preds)
                labels_.extend(labels)
                input_indexs_.extend(input_indexs)

                # if gate is not None:
                #     active_gate = torch.BoolTensor([g[0] >= 0.1 or g[1] >= 0.1 for g in gate])
                #     active_index = list(np.array(input_indexs[active_gate].cpu().data))
                #     val_activate_index += active_index
                loss = loss_func(logits, labels).mean()

                # acc, wrong_indexs, wrong_answer, eq = self.evaluate(preds, labels, input_indexs)
                val_loss.append(loss.cpu().data.item() / len(input_ids))
                # val_acc.append(acc)
                # val_wrong += wrong_indexs
                # val_wrong_answer += wrong_answer
                # eqs.extend(eq)

        val_acc, val_wrong, val_wrong_answer, eqs = self.evaluate(
            torch.Tensor(preds_), torch.Tensor(labels_),
            torch.Tensor(input_indexs_))
        # print(val_acc)

        return val_loss, val_acc, val_wrong, val_wrong_answer, eqs

    def test(self, test_db):
        ##################################################################
        ## Validation
        # ##################################################################
        # if self.cfg.cuda and self.cfg.parallel:
        #     net = self.net.module
        # else:
        #     net = self.net
        # net = self.net
        print('Validation...')
        torch.cuda.empty_cache()
        self.net.eval()

        test_loader = DataLoader(test_db,
                                 batch_size=self.cfg.batch_size,
                                 shuffle=False,
                                 num_workers=self.cfg.num_workers)

        answer = []

        for _, batch in tqdm(enumerate(test_loader)):
            input_ids, input_mask, segment_ids, features, fea_mask, labels, input_indexs = self.batch_data(
                batch)
            batch_input = (input_ids, input_mask, segment_ids, features,
                           fea_mask)

            with torch.no_grad():
                logits, probabilities, preds = self.net(
                    input_ids, input_mask, segment_ids, features, fea_mask)
                input_indexs = list(np.array(input_indexs.cpu().data))
                preds = list(np.array(preds.cpu().data))
                for ind, pred in zip(input_indexs, preds):
                    answer.append(
                        (test_db.index2qid[ind], self.index2label[pred]))

        return answer

    def load_pretrained_net(self, pretrained):
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net
        # self.begin_epoch = int(pretrained_name.split('-')[1]) + 1
        print('loading ckpt from ', pretrained)

        assert osp.exists(pretrained)
        if self.cfg.cuda:
            checkpoint = torch.load(pretrained)
        else:
            checkpoint = torch.load(pretrained,
                                    map_location=lambda storage, loc: storage)

        net.load_state_dict(checkpoint['net'])

    def save_checkpoint(self, epoch, force=False, acc=None):
        # wrong_index_path = osp.join(self.cfg.log_dir, self.cfg.task, self.cfg.model_name, "wrong_index.jsonl")
        # with jsonlines.open(wrong_index_path, 'a+') as writer:
        #     writer.write(epoch_wrong)

        print(" [*] Saving checkpoints...")
        if self.cfg.cuda and self.cfg.parallel:
            net = self.net.module
        else:
            net = self.net

        checkpoint_dir = osp.join(self.cfg.log_dir, self.cfg.task,
                                  self.cfg.model_name)
        if not osp.exists(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        tail = ''
        if acc is not None:
            acc = str(round(acc, 5))[2:]
            tail += '-'
            tail += acc
        if force:
            tail += '-'
            tail += 'end'
        model_name = "ckpt-%03d%s.pkl" % (epoch, tail)

        print('saving ckpt to ', checkpoint_dir)
        if self.cfg.fp16:
            state = {
                'net': net.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'epoch': epoch,
                'amp': amp.state_dict()
            }
        else:
            state = {
                'net': net.state_dict(),
                'optimizer': self.optimizer.state_dict(),
                'epoch': epoch
            }
        # torch.save(net.state_dict(), osp.join(checkpoint_dir, model_name))
        torch.save(state, osp.join(checkpoint_dir, model_name))

    def batch_data(self, entry):
        features, fea_mask = None, None
        input_ids = entry['token_ids'].long()
        segment_ids = entry['segment_ids'].long()
        input_mask = entry['mask'].long()
        labels = entry['label_ids'].long()
        input_indexs = entry['index'].long()

        if self.cfg.feature:
            features = entry['feature'].float()
            fea_mask = entry['fea_mask'].long()

        # print(input_ids[0])
        # exit()

        if self.cfg.cuda:
            input_ids = input_ids.cuda()
            input_mask = input_mask.cuda()
            segment_ids = segment_ids.cuda()
            labels = labels.cuda()
            input_indexs = input_indexs.cuda()
            if self.cfg.feature:
                features = features.cuda()
                fea_mask = fea_mask.cuda()

        return input_ids, input_mask, segment_ids, features, fea_mask, labels, input_indexs

    def evaluate(self, pred, labels, input_indexs):
        eq = torch.eq(pred, labels)
        # print(labels.shape)
        wrong_indexs = list(np.array(input_indexs[~eq].cpu().data))
        wrong_answer = list(np.array(pred[~eq].cpu().data))
        correct = eq.sum().cpu().data.item()
        acc = correct / len(labels)

        return acc, wrong_indexs, wrong_answer, np.array(eq.cpu())
class Distiller:
    def __init__(self, params: dict, dataloader: Dataset,
                 token_probs: torch.Tensor, student: nn.Module,
                 teacher: nn.Module):
        logger.info('Initializing Distiller')
        self.params = params
        self.dump_path = params.dump_path
        self.multi_gpu = params.multi_gpu
        self.fp16 = params.fp16

        self.student = student
        self.teacher = teacher

        self.dataloader = dataloader
        if self.params.n_gpu > 1:
            self.dataloader.split()
        self.get_iterator(seed=params.seed)

        self.temperature = params.temperature
        assert self.temperature > 0.

        self.alpha_ce = params.alpha_ce
        self.alpha_mlm = params.alpha_mlm
        self.alpha_mse = params.alpha_mse
        assert self.alpha_ce >= 0.
        assert self.alpha_mlm >= 0.
        assert self.alpha_mse >= 0.
        assert self.alpha_ce + self.alpha_mlm + self.alpha_mse > 0.

        self.mlm_mask_prop = params.mlm_mask_prop
        assert 0.0 <= self.mlm_mask_prop <= 1.0
        assert params.word_mask + params.word_keep + params.word_rand == 1.0
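        # Probabilities of replacing a selected token with [MASK], keeping it, or substituting a random token.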
        self.pred_probs = torch.FloatTensor(
            [params.word_mask, params.word_keep, params.word_rand])
        if params.n_gpu > 0:
            self.pred_probs = self.pred_probs.to(f'cuda:{params.local_rank}')
            self.token_probs = token_probs.to(f'cuda:{params.local_rank}')
        else:
            self.token_probs = token_probs
        if self.fp16:
            self.pred_probs = self.pred_probs.half()
            self.token_probs = self.token_probs.half()

        self.epoch = 0
        self.n_iter = 0
        self.n_total_iter = 0
        self.n_sequences_epoch = 0
        self.total_loss_epoch = 0
        self.last_loss = 0
        self.last_loss_ce = 0
        self.last_loss_mlm = 0
        self.last_loss_mse = 0

        self.ce_loss_fct = nn.KLDivLoss(reduction='batchmean')
        self.mlm_loss_fct = nn.CrossEntropyLoss(ignore_index=-1)
        self.mse_loss_fct = nn.MSELoss(reduction='sum')

        logger.info('--- Initializing model optimizer')
        assert params.gradient_accumulation_steps >= 1
        self.num_steps_epoch = int(
            len(self.dataloader) / params.batch_size) + 1
        num_train_optimization_steps = int(
            self.num_steps_epoch / params.gradient_accumulation_steps *
            params.n_epoch) + 1
        warmup_steps = math.ceil(num_train_optimization_steps *
                                 params.warmup_prop)

        no_decay = ['bias', 'LayerNorm.weight']
        optimizer_grouped_parameters = [
            {
                'params': [p for n, p in student.named_parameters()
                           if not any(nd in n for nd in no_decay) and p.requires_grad],
                'weight_decay': params.weight_decay
            },
            {
                'params': [p for n, p in student.named_parameters()
                           if any(nd in n for nd in no_decay) and p.requires_grad],
                'weight_decay': 0.0
            },
        ]
        logger.info(
            "------ Number of trainable parameters (student): %i" % sum([
                p.numel() for p in self.student.parameters() if p.requires_grad
            ]))
        logger.info("------ Number of parameters (student): %i" %
                    sum([p.numel() for p in self.student.parameters()]))
        self.optimizer = AdamW(optimizer_grouped_parameters,
                               lr=params.learning_rate,
                               eps=params.adam_epsilon,
                               betas=(0.9, 0.98))
        self.scheduler = WarmupLinearSchedule(
            self.optimizer,
            warmup_steps=warmup_steps,
            t_total=num_train_optimization_steps)

        if self.fp16:
            try:
                from apex import amp
            except ImportError:
                raise ImportError(
                    "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
                )
            logger.info(
                f"Using fp16 training: {self.params.fp16_opt_level} level")
            self.student, self.optimizer = amp.initialize(
                self.student,
                self.optimizer,
                opt_level=self.params.fp16_opt_level)
            self.teacher = self.teacher.half()

        if self.multi_gpu:
            if self.fp16:
                from apex.parallel import DistributedDataParallel
                logger.info(
                    "Using apex.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(self.student)
            else:
                from torch.nn.parallel import DistributedDataParallel
                logger.info(
                    "Using nn.parallel.DistributedDataParallel for distributed training."
                )
                self.student = DistributedDataParallel(
                    self.student,
                    device_ids=[params.local_rank],
                    output_device=params.local_rank)

        self.is_master = params.is_master
        if self.is_master:
            logger.info('--- Initializing Tensorboard')
            self.tensorboard = SummaryWriter(
                log_dir=os.path.join(self.dump_path, 'log', 'train'))
            self.tensorboard.add_text(tag='config',
                                      text_string=str(self.params),
                                      global_step=0)

    def get_iterator(self, seed: int = None):
        """
        Initialize the data iterator.
        Each process has its own data iterator (iterating on its own random portion of the dataset).

        Input:
        ------
            seed: `int` - The random seed.
        """
        logger.info('--- Initializing Data Iterator')
        self.data_iterator = self.dataloader.get_iterator(seed=seed)

    def get_batch(self):
        """
        Call the data iterator to output a new batch.
        If the data iterator went through the whole dataset, create a new iterator.
        """
        assert hasattr(self, 'data_iterator')
        try:
            x = next(self.data_iterator)
        except StopIteration:
            logger.warning(
                '--- Went through the whole dataset. Creating new data iterator.'
            )
            self.data_iterator = self.dataloader.get_iterator()
            x = next(self.data_iterator)
        return x

    def prepare_batch(self, batch):
        """
        Prepare the batch: from the token_ids and the lengths, compute the attention mask and the masked labels for MLM.

        Input:
        ------
            batch: `Tuple`
                token_ids: `torch.tensor(bs, seq_length)` - The token ids for each of the sequences (padded).
                lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            token_ids: `torch.tensor(bs, seq_length)` - The token ids after the modifications for MLM.
            attn_mask: `torch.tensor(bs, seq_length)` - The attention mask for the self-attention.
            mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels: -1 where there is nothing to predict.
        """
        token_ids, lengths = batch
        token_ids, lengths = self.round_batch(x=token_ids, lengths=lengths)
        assert token_ids.size(0) == lengths.size(0)

        attn_mask = (torch.arange(token_ids.size(1),
                                  dtype=torch.long,
                                  device=lengths.device) < lengths[:, None])

        bs, max_seq_len = token_ids.size()
        mlm_labels = token_ids.new(token_ids.size()).copy_(token_ids)

        x_prob = self.token_probs[token_ids.flatten()]
        n_tgt = math.ceil(self.mlm_mask_prop * lengths.sum().item())
        tgt_ids = torch.multinomial(x_prob / x_prob.sum(),
                                    n_tgt,
                                    replacement=False)
        pred_mask = torch.zeros(
            bs * max_seq_len, dtype=torch.bool, device=token_ids.device
        )  # previously `dtype=torch.uint8`, cf pytorch 1.2.0 compatibility
        pred_mask[tgt_ids] = 1
        pred_mask = pred_mask.view(bs, max_seq_len)

        pred_mask[token_ids == self.params.special_tok_ids['pad_token']] = 0

        # mask a number of words that is a multiple of 8 (faster with fp16)
        if self.fp16:
            n1 = pred_mask.sum().item()
            if n1 > 8:
                pred_mask = pred_mask.view(-1)
                n2 = max(n1 % 8, 8 * (n1 // 8))
                if n2 != n1:
                    pred_mask[torch.nonzero(pred_mask).view(-1)[:n1 - n2]] = 0
                pred_mask = pred_mask.view(bs, max_seq_len)
                assert pred_mask.sum().item() % 8 == 0, pred_mask.sum().item()

        _token_ids_real = token_ids[pred_mask]
        _token_ids_rand = _token_ids_real.clone().random_(
            self.params.vocab_size)
        _token_ids_mask = _token_ids_real.clone().fill_(
            self.params.special_tok_ids['mask_token'])
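        # For each selected position, sample whether to use the mask token, keep the
        # original token, or insert a random token. `pred_probs` is assumed to hold
        # these three probabilities (e.g. the BERT-style 0.8 / 0.1 / 0.1 split):
        # probs == 0 -> mask, probs == 1 -> keep, probs == 2 -> random.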
        probs = torch.multinomial(self.pred_probs,
                                  len(_token_ids_real),
                                  replacement=True)
        _token_ids = _token_ids_mask * (
            probs == 0).long() + _token_ids_real * (
                probs == 1).long() + _token_ids_rand * (probs == 2).long()
        token_ids = token_ids.masked_scatter(pred_mask, _token_ids)

        mlm_labels[
            ~pred_mask] = -1  # previously `mlm_labels[1-pred_mask] = -1`, cf pytorch 1.2.0 compatibility

        return token_ids, attn_mask, mlm_labels

    def round_batch(self, x: torch.tensor, lengths: torch.tensor):
        """
        For float16 only.
        Sub-sample sentences in a batch, and add padding, so that each dimension is a multiple of 8.

        Input:
        ------
            x: `torch.tensor(bs, seq_length)` - The token ids.
            lengths: `torch.tensor(bs)` - The lengths of each of the sequences in the batch.

        Output:
        -------
            x:  `torch.tensor(new_bs, new_seq_length)` - The updated token ids.
            lengths: `torch.tensor(new_bs)` - The updated lengths.
        """
        if not self.fp16 or len(lengths) < 8:
            return x, lengths

        # make the number of sentences a multiple of 8
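        # (e.g. a hypothetical batch of 13 sentences is randomly sub-sampled to 8)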
        bs1 = len(lengths)
        bs2 = 8 * (bs1 // 8)
        assert bs2 > 0 and bs2 % 8 == 0
        if bs1 != bs2:
            idx = torch.randperm(bs1)[:bs2]
            lengths = lengths[idx]
            slen = lengths.max().item()
            x = x[idx, :slen]
        else:
            idx = None

        # pad the sequence length up to a multiple of 8
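        # (e.g. a hypothetical max length of 30 is padded with the pad token up to 32)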
        ml1 = x.size(1)
        if ml1 % 8 != 0:
            pad = 8 - (ml1 % 8)
            ml2 = ml1 + pad
            pad_id = self.params.special_tok_ids['pad_token']
            padding_tensor = torch.zeros(bs2,
                                         pad,
                                         dtype=torch.long,
                                         device=x.device).fill_(pad_id)
            x = torch.cat([x, padding_tensor], 1)
            assert x.size() == (bs2, ml2)

        assert x.size(0) % 8 == 0
        assert x.size(1) % 8 == 0
        return x, lengths

    def train(self):
        """
        The real training loop.
        """
        if self.is_master:
            logger.info('Starting training')
        self.student.train()
        self.teacher.eval()

        for _ in range(self.params.n_epoch):
            if self.is_master:
                logger.info(
                    f'--- Starting epoch {self.epoch}/{self.params.n_epoch-1}')

            iter_bar = trange(self.num_steps_epoch,
                              desc="-Iter",
                              disable=self.params.local_rank not in [-1, 0])
            for __ in range(self.num_steps_epoch):
                batch = self.get_batch()
                if self.params.n_gpu > 0:
                    batch = tuple(
                        t.to(f'cuda:{self.params.local_rank}') for t in batch)
                token_ids, attn_mask, mlm_labels = self.prepare_batch(
                    batch=batch)

                self.step(input_ids=token_ids,
                          attention_mask=attn_mask,
                          mlm_labels=mlm_labels)

                iter_bar.update()
                iter_bar.set_postfix({
                    'Last_loss':
                    f'{self.last_loss:.2f}',
                    'Avg_cum_loss':
                    f'{self.total_loss_epoch/self.n_iter:.2f}'
                })
            iter_bar.close()

            if self.is_master:
                logger.info(
                    f'--- Ending epoch {self.epoch}/{self.params.n_epoch-1}')
            self.end_epoch()

        if self.is_master:
            logger.info('Saving the very last checkpoint as `pytorch_model.bin`.')
            self.save_checkpoint(checkpoint_name='pytorch_model.bin')
            logger.info('Training is finished')

    def step(self, input_ids: torch.tensor, attention_mask: torch.tensor,
             mlm_labels: torch.tensor):
        """
        One optimization step: forward of student AND teacher, backward on the loss (for gradient accumulation),
        and possibly a parameter update (depending on the gradient accumulation).

        Input:
        ------
        input_ids: `torch.tensor(bs, seq_length)` - The token ids.
        attention_mask: `torch.tensor(bs, seq_length)` - The attention mask for self attention.
        mlm_labels: `torch.tensor(bs, seq_length)` - The masked language modeling labels.
        """
        s_logits = self.student(
            input_ids=input_ids,
            attention_mask=attention_mask)[0]  # (bs, seq_length, voc_size)
        with torch.no_grad():
            t_logits = self.teacher(
                input_ids=input_ids,
                attention_mask=attention_mask)[0]  # (bs, seq_length, voc_size)
        assert s_logits.size() == t_logits.size()

        #https://github.com/peterliht/knowledge-distillation-pytorch/blob/master/model/net.py#L100
        #https://github.com/peterliht/knowledge-distillation-pytorch/issues/2
        if self.params.restrict_ce_to_mask:
            mask = (mlm_labels > -1).unsqueeze(-1).expand_as(
                s_logits)  # (bs, seq_length, voc_size)
        else:
            mask = attention_mask.unsqueeze(-1).expand_as(
                s_logits)  # (bs, seq_length, voc_size)
        s_logits_slct = torch.masked_select(
            s_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        t_logits_slct = torch.masked_select(
            t_logits,
            mask)  # (bs * seq_length * voc_size) modulo the 1s in mask
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(
            -1))  # (bs * seq_length, voc_size) modulo the 1s in mask
        assert t_logits_slct.size() == s_logits_slct.size()

        loss_ce = self.ce_loss_fct(
            F.log_softmax(s_logits_slct / self.temperature, dim=-1),
            F.softmax(t_logits_slct / self.temperature,
                      dim=-1)) * (self.temperature)**2
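        # The T**2 factor compensates for the 1/T**2 scaling of the soft-target
        # gradients, keeping the distillation term on a scale comparable to the
        # other losses (as in Hinton et al.'s knowledge distillation).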
        loss = self.alpha_ce * loss_ce
        if self.alpha_mlm > 0.:
            loss_mlm = self.mlm_loss_fct(s_logits.view(-1, s_logits.size(-1)),
                                         mlm_labels.view(-1))
            loss += self.alpha_mlm * loss_mlm
        if self.alpha_mse > 0.:
            loss_mse = self.mse_loss_fct(
                s_logits_slct, t_logits_slct) / s_logits_slct.size(
                    0)  # Reproducing batchmean reduction
            loss += self.alpha_mse * loss_mse

        self.total_loss_epoch += loss.item()
        self.last_loss = loss.item()
        self.last_loss_ce = loss_ce.item()
        if self.alpha_mlm > 0.:
            self.last_loss_mlm = loss_mlm.item()
        if self.alpha_mse > 0.:
            self.last_loss_mse = loss_mse.item()

        self.optimize(loss)

        self.n_sequences_epoch += input_ids.size(0)

    def optimize(self, loss):
        """
        Normalization on the loss (gradient accumulation or distributed training), followed by
        backward pass on the loss, possibly followed by a parameter update (depending on the gradient accumulation).
        Also update the metrics for tensorboard.
        """
        # Check for NaN
        if torch.isnan(loss).any():
            logger.error('NaN detected')
            exit()

        if self.multi_gpu:
            loss = loss.mean()
        if self.params.gradient_accumulation_steps > 1:
            loss = loss / self.params.gradient_accumulation_steps

        if self.fp16:
            from apex import amp
            with amp.scale_loss(loss, self.optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            loss.backward()

        self.iter()
        if self.n_iter % self.params.gradient_accumulation_steps == 0:
            if self.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(self.optimizer),
                    self.params.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(self.student.parameters(),
                                               self.params.max_grad_norm)
            self.optimizer.step()
            self.optimizer.zero_grad()
            self.scheduler.step()

    def iter(self):
        """
        Update global counts, write to tensorboard and save checkpoint.
        """
        self.n_iter += 1
        self.n_total_iter += 1

        if self.n_total_iter % self.params.log_interval == 0:
            self.log_tensorboard()
        if self.n_total_iter % self.params.checkpoint_interval == 0:
            self.save_checkpoint()

    def log_tensorboard(self):
        """
        Log into tensorboard. Only by the master process.
        """
        if not self.is_master:
            return

        for param_name, param in self.student.named_parameters():
            self.tensorboard.add_scalar(tag='parameter_mean/' + param_name,
                                        scalar_value=param.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag='parameter_std/' + param_name,
                                        scalar_value=param.data.std(),
                                        global_step=self.n_total_iter)
            if param.grad is None:
                continue
            self.tensorboard.add_scalar(tag="grad_mean/" + param_name,
                                        scalar_value=param.grad.data.mean(),
                                        global_step=self.n_total_iter)
            self.tensorboard.add_scalar(tag="grad_std/" + param_name,
                                        scalar_value=param.grad.data.std(),
                                        global_step=self.n_total_iter)

        self.tensorboard.add_scalar(tag="losses/cum_avg_loss_epoch",
                                    scalar_value=self.total_loss_epoch /
                                    self.n_iter,
                                    global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss",
                                    scalar_value=self.last_loss,
                                    global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="losses/loss_ce",
                                    scalar_value=self.last_loss_ce,
                                    global_step=self.n_total_iter)
        if self.alpha_mlm > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mlm",
                                        scalar_value=self.last_loss_mlm,
                                        global_step=self.n_total_iter)
        if self.alpha_mse > 0.:
            self.tensorboard.add_scalar(tag="losses/loss_mse",
                                        scalar_value=self.last_loss_mse,
                                        global_step=self.n_total_iter)
        self.tensorboard.add_scalar(tag="learning_rate/lr",
                                    scalar_value=self.scheduler.get_lr()[0],
                                    global_step=self.n_total_iter)

        self.tensorboard.add_scalar(
            tag="global/memory_usage",
            scalar_value=psutil.virtual_memory()._asdict()['used'] / 1_000_000,
            global_step=self.n_total_iter)

    def end_epoch(self):
        """
        Finally arrived at the end of an epoch (a full pass over the dataset).
        Do some tensorboard logging and checkpoint saving.
        """
        logger.info(
            f'{self.n_sequences_epoch} sequences have been trained during this epoch.'
        )

        if self.is_master:
            self.save_checkpoint(
                checkpoint_name=f'model_epoch_{self.epoch}.pth')
            self.tensorboard.add_scalar(tag='epoch/loss',
                                        scalar_value=self.total_loss_epoch /
                                        self.n_iter,
                                        global_step=self.epoch)

        self.epoch += 1
        self.n_sequences_epoch = 0
        self.n_iter = 0
        self.total_loss_epoch = 0

    def save_checkpoint(self, checkpoint_name: str = 'checkpoint.pth'):
        """
        Save the current state. Only by the master process.
        """
        if not self.is_master:
            return
        mdl_to_save = self.student.module if hasattr(
            self.student, 'module') else self.student
        mdl_to_save.config.save_pretrained(self.dump_path)
        state_dict = mdl_to_save.state_dict()
        torch.save(state_dict, os.path.join(self.dump_path, checkpoint_name))
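
A minimal, self-contained sketch (not part of the example above) of the core objective computed in `step`: the KL divergence between temperature-softened student and teacher distributions, rescaled by T**2. The temperature value and tensor shapes below are placeholders.

import torch
import torch.nn as nn
import torch.nn.functional as F

temperature = 2.0                        # placeholder value
ce_loss_fct = nn.KLDivLoss(reduction='batchmean')

# hypothetical logits over a 30522-token vocabulary for 8 selected positions
s_logits_slct = torch.randn(8, 30522)    # student
t_logits_slct = torch.randn(8, 30522)    # teacher

loss_ce = ce_loss_fct(
    F.log_softmax(s_logits_slct / temperature, dim=-1),
    F.softmax(t_logits_slct / temperature, dim=-1)) * temperature ** 2
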
Пример #7
0
def train(train_dataset, model, tokenizer):
    tb_writer = SummaryWriter()

    train_sampler = RandomSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args['train_batch_size'])

    t_total = len(train_dataloader) // args['gradient_accumulation_steps'] * args['num_train_epochs']

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [
        {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
         'weight_decay': args['weight_decay']},
        {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}
    ]

    warmup_steps = math.ceil(t_total * args['warmup_ratio'])
    args['warmup_steps'] = warmup_steps if args['warmup_steps'] == 0 else args['warmup_steps']

    optimizer = AdamW(optimizer_grouped_parameters, lr=args['learning_rate'], eps=args['adam_epsilon'])
    # optimizer = AdamW(list(child.parameters()), lr=args['learning_rate'], eps=args['adam_epsilon'])
    scheduler = WarmupLinearSchedule(optimizer, warmup_steps=args['warmup_steps'], t_total=t_total)

    if args['fp16']:
        try:
            from apex import amp
        except ImportError:
            raise ImportError("Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args['fp16_opt_level'])

    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args['num_train_epochs'])
    logger.info("  Total train batch size  = %d", args['train_batch_size'])
    logger.info("  Gradient Accumulation steps = %d", args['gradient_accumulation_steps'])
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args['num_train_epochs']), desc="Epoch")

    for _ in train_iterator:
        epoch_iterator = tqdm_notebook(train_dataloader, desc="Iteration")
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(device) for t in batch)
            inputs = {'input_ids': batch[0],
                      'attention_mask': batch[1],
                      'token_type_ids': batch[2] if args['model_type'] in ['bert', 'xlnet'] else None,
                      # XLM doesn't use segment_ids
                      'labels': batch[3]}
            outputs = model(**inputs)
            loss = outputs[0]  # model outputs are always tuple in pytorch-transformers (see doc)
            print("\r%f" % loss, end='')

            if args['gradient_accumulation_steps'] > 1:
                loss = loss / args['gradient_accumulation_steps']

            if args['fp16']:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args['max_grad_norm'])

            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), args['max_grad_norm'])

            tr_loss += loss.item()
            if (step + 1) % args['gradient_accumulation_steps'] == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if args['logging_steps'] > 0 and global_step % args['logging_steps'] == 0:
                    # Log metrics
                    # Only evaluate when single GPU otherwise metrics may not average well
                    if args['evaluate_during_training']:
                        results, _ = evaluate(model, tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args['logging_steps'], global_step)
                    logging_loss = tr_loss

                if args['save_steps'] > 0 and global_step % args['save_steps'] == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args['output_dir'], 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    # Take care of distributed/parallel training
                    model_to_save = model.module if hasattr(model, 'module') else model
                    model_to_save.save_pretrained(output_dir)
                    logger.info("Saving model checkpoint to %s", output_dir)

    return global_step, tr_loss / global_step
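
This variant reads its hyper-parameters from a module-level `args` dict (and a global `device`) rather than an argparse namespace. A minimal sketch of the keys the function touches, with placeholder values only:

args = {
    'train_batch_size': 8,
    'gradient_accumulation_steps': 1,
    'num_train_epochs': 3,
    'weight_decay': 0.0,
    'learning_rate': 2e-5,
    'adam_epsilon': 1e-8,
    'warmup_ratio': 0.06,
    'warmup_steps': 0,                  # 0 -> derived from warmup_ratio above
    'max_grad_norm': 1.0,
    'fp16': False,
    'fp16_opt_level': 'O1',
    'model_type': 'bert',
    'evaluate_during_training': False,
    'logging_steps': 50,
    'save_steps': 1000,
    'output_dir': './outputs',
}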
Пример #8
0
def train(args, train_dataloader, model, encoder_tokenizer, decoder_tokenizer,
          table_name):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    # train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model, device_ids=range(args.n_gpu)).to(
            args.device)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", train_dataloader.num_examples)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0

    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])

    n_iter = int(args.num_train_epochs) * len(train_dataloader)

    tmp_list = []
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            tokenized_text0, tokenized_text1, tokenized_text_lengths = batch
            inputs, labels = tokenized_text1.to(
                args.device), tokenized_text1.to(args.device)

            model.train()

            outputs = model(inputs,
                            labels=labels,
                            label_ignore=decoder_tokenizer.pad_token_id)
            loss = outputs[0].mean(
            )  # model outputs are always tuple in pytorch-transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean()

            if args.use_philly:
                print("PROGRESS: {}%".format(
                    round(
                        100 * (step + epoch * len(epoch_iterator)) /
                        (int(args.num_train_epochs) * len(epoch_iterator)),
                        4)))
                print("EVALERR: {}%".format(loss))

            epoch_iterator.set_description((
                f'iter: {step +  epoch*len(epoch_iterator) }; loss: {loss.item():.3f}; '
            ))

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()

                scheduler.step()  # Update learning rate schedule

                model.zero_grad()

                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, encoder_tokenizer,
                                           decoder_tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:

                    # Save decoder model checkpoint
                    output_decoder_dir = os.path.join(
                        args.output_dir,
                        'checkpoint-decoder-{}'.format(global_step))

                    if not os.path.exists(output_decoder_dir):
                        os.makedirs(output_decoder_dir)

                    model_decoder_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    if args.use_philly:
                        save_solid = False
                        while not save_solid:
                            try:
                                model_decoder_to_save.save_pretrained(
                                    output_decoder_dir)
                                torch.save(
                                    args,
                                    os.path.join(output_decoder_dir,
                                                 'training_args.bin'))
                                logger.info("Saving model checkpoint to %s",
                                            output_decoder_dir)
                                save_solid = True
                            except Exception:
                                # saving to philly storage can fail transiently; keep retrying
                                pass
                    else:
                        model_decoder_to_save.save_pretrained(
                            output_decoder_dir)
                        torch.save(
                            args,
                            os.path.join(output_decoder_dir,
                                         'training_args.bin'))
                        logger.info("Saving model checkpoint to %s",
                                    output_decoder_dir)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
Пример #9
0
def train(args, train_dataloader, model_vae, decoder_tokenizer, table_name):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    # train_sampler = RandomSampler(train_dataset) if args.local_rank == -1 else DistributedSampler(train_dataset)
    # train_dataloader = DataLoader(train_dataset, sampler=train_sampler, batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)

    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model_vae.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model_vae.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]

    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model_vae, optimizer = amp.initialize(model_vae,
                                              optimizer,
                                              opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model_vae = torch.nn.DataParallel(model_vae,
                                          device_ids=range(args.n_gpu)).to(
                                              args.device)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model_vae = torch.nn.parallel.DistributedDataParallel(
            model_vae,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", train_dataloader.num_examples)
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model_vae.zero_grad()

    # model_vae = model_vae.module if hasattr(model_vae, 'module') else model_vae  # Take care of distributed/parallel training

    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])

    n_iter = int(args.num_train_epochs) * len(train_dataloader)
    beta_t_list = frange_cycle_zero_linear(n_iter,
                                           start=0.0,
                                           stop=args.beta,
                                           n_cycle=1,
                                           ratio_increase=args.ratio_increase,
                                           ratio_zero=args.ratio_zero)
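    # beta_t_list holds one KL weight per iteration: presumably zero for the first
    # `ratio_zero` fraction of the cycle, then a linear ramp up to args.beta over
    # `ratio_increase`, then constant (cyclical KL annealing with a single cycle here).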

    tmp_list = []
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for epoch in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):

            tokenized_text1, tokenized_text_lengths = batch
            # tokenized_text0 = tokenized_text0.to(args.device)
            # tokenized_text1 = tokenized_text1.to(args.device)
            # prepare input-output data for reconstruction

            labels = tokenized_text1

            tokenized_text1 = tokenized_text1.to(args.device)
            # inputs = inputs.to(args.device)
            labels = labels.to(args.device)

            model_vae.train()

            beta_t = beta_t_list[step + epoch * len(epoch_iterator)]
            model_vae.module.args.beta = beta_t

            if beta_t == 0.0:
                model_vae.module.args.fb_mode = 0
            else:
                model_vae.module.args.fb_mode = 1

            if args.use_deterministic_connect:
                model_vae.module.args.fb_mode = 2

            loss_rec, loss = model_vae(labels)

            # Chunyuan: loss_rec size is [4], while latent_z size is [12]
            if args.n_gpu > 1:
                loss_rec = loss_rec.mean(
                )  # mean() to average on multi-gpu parallel training
                # loss_kl = loss_kl.mean()
                loss = loss.mean()

            if args.use_philly:
                print("PROGRESS: {}%".format(
                    round(
                        100 * (step + epoch * len(epoch_iterator)) /
                        (int(args.num_train_epochs) * len(epoch_iterator)),
                        4)))
                print("EVALERR: {}%".format(loss_rec))

            epoch_iterator.set_description((
                f'iter: {step +  epoch*len(epoch_iterator) }; loss: {loss.item():.3f}; '
                f'loss_rec: {loss_rec.item():.3f};'
                f'beta: {model_vae.module.args.beta:.3f}'))

            # if global_step % 5 == 0:
            #     row = {
            #             'PartitionKey': 'MILU_Rule_Rule_Template',
            #             'RowKey': str(datetime.now()),
            #             'ExpName' : args.ExpName,
            #             'iter': str( step +  epoch*len(epoch_iterator) ),
            #             'loss': str( loss.item()),
            #             'loss_rec': str(loss_rec.item()),
            #             'loss_kl': str(loss_kl.item()),
            #             'beta': str(model_vae.args.beta)
            #         }
            #     # pdb.set_trace()
            #     ts.insert_entity(table_name, row)

            # pdb.set_trace()

            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model_vae.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()

                scheduler.step()  # Update learning rate schedule

                model_vae.zero_grad()

                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model_vae, decoder_tokenizer)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)
                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

                if args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    save_checkpoint(model_vae, optimizer, global_step, args)

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step, optimizer
Пример #10
0
    def fit(
        self,
        token_ids,
        input_mask,
        labels,
        val_token_ids,
        val_input_mask,
        val_labels,
        token_type_ids=None,
        val_token_type_ids=None,
        verbose=True,
        logging_steps=0,
        save_steps=0,
        val_steps=0,
    ):
        """Fine-tunes the XLNet classifier using the given training data.

        Args:
            token_ids (list): List of training token id lists.
            input_mask (list): List of input mask lists.
            labels (list): List of training labels.
            token_type_ids (list, optional): List of lists. Each sublist
                contains segment ids indicating if the token belongs to
                the first sentence(0) or second sentence(1). Only needed
                for two-sentence tasks.
            verbose (bool, optional): If True, shows the training progress and
                loss values. Defaults to True.
        """

        device = get_device("cpu" if self.num_gpus == 0
                            or not torch.cuda.is_available() else "gpu")
        self.model = move_to_device(self.model, device, self.num_gpus)

        token_ids_tensor = torch.tensor(token_ids, dtype=torch.long)
        input_mask_tensor = torch.tensor(input_mask, dtype=torch.long)
        labels_tensor = torch.tensor(labels, dtype=torch.long)

        val_token_ids_tensor = torch.tensor(val_token_ids, dtype=torch.long)
        val_input_mask_tensor = torch.tensor(val_input_mask, dtype=torch.long)
        val_labels_tensor = torch.tensor(val_labels, dtype=torch.long)

        if token_type_ids:
            token_type_ids_tensor = torch.tensor(token_type_ids,
                                                 dtype=torch.long)
            val_token_type_ids_tensor = torch.tensor(val_token_type_ids,
                                                     dtype=torch.long)

            train_dataset = TensorDataset(token_ids_tensor, input_mask_tensor,
                                          token_type_ids_tensor, labels_tensor)

            val_dataset = TensorDataset(
                val_token_ids_tensor,
                val_input_mask_tensor,
                val_token_type_ids_tensor,
                val_labels_tensor,
            )

        else:

            train_dataset = TensorDataset(token_ids_tensor, input_mask_tensor,
                                          labels_tensor)

            val_dataset = TensorDataset(val_token_ids_tensor,
                                        val_input_mask_tensor,
                                        val_labels_tensor)

        # define optimizer and model parameters
        param_optimizer = list(self.model.named_parameters())
        no_decay = ["bias", "LayerNorm.weight"]
        optimizer_grouped_parameters = [
            {
                "params": [
                    p for n, p in param_optimizer
                    if not any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                self.weight_decay,
            },
            {
                "params": [
                    p for n, p in param_optimizer
                    if any(nd in n for nd in no_decay)
                ],
                "weight_decay":
                0.0,
            },
        ]

        val_sampler = RandomSampler(val_dataset)

        val_dataloader = DataLoader(val_dataset,
                                    sampler=val_sampler,
                                    batch_size=self.batch_size)

        num_examples = len(token_ids)
        num_batches = int(np.ceil(num_examples / self.batch_size))
        num_train_optimization_steps = num_batches * self.num_epochs

        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=self.lr,
                          eps=self.adam_eps)
        scheduler = WarmupLinearSchedule(optimizer,
                                         warmup_steps=self.warmup_steps,
                                         t_total=num_train_optimization_steps)

        global_step = 0
        self.model.train()
        optimizer.zero_grad()
        for epoch in range(self.num_epochs):

            train_sampler = RandomSampler(train_dataset)

            train_dataloader = DataLoader(train_dataset,
                                          sampler=train_sampler,
                                          batch_size=self.batch_size)

            tr_loss = 0.0
            logging_loss = 0.0
            val_loss = 0.0

            for i, batch in enumerate(tqdm(train_dataloader,
                                           desc="Iteration")):
                if token_type_ids:
                    x_batch, mask_batch, token_type_ids_batch, y_batch = tuple(
                        t.to(device) for t in batch)
                else:
                    token_type_ids_batch = None
                    x_batch, mask_batch, y_batch = tuple(
                        t.to(device) for t in batch)

                outputs = self.model(
                    input_ids=x_batch,
                    token_type_ids=token_type_ids_batch,
                    attention_mask=mask_batch,
                    labels=y_batch,
                )

                loss = outputs[
                    0]  # model outputs are always tuple in pytorch-transformers
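                # `loss` may be a vector of per-GPU losses if move_to_device wrapped the
                # model in DataParallel, hence the .sum() before backward and logging.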

                loss.sum().backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(),
                                               self.max_grad_norm)

                tr_loss += loss.sum().item()
                optimizer.step()
                # Update learning rate schedule
                scheduler.step()
                optimizer.zero_grad()
                global_step += 1
                # logging of learning rate and loss
                if logging_steps > 0 and global_step % logging_steps == 0:
                    mlflow.log_metric("learning rate",
                                      scheduler.get_lr()[0],
                                      step=global_step)
                    mlflow.log_metric(
                        "training loss",
                        (tr_loss - logging_loss) /
                        (logging_steps * self.batch_size),
                        step=global_step,
                    )
                    logging_loss = tr_loss
                # model checkpointing
                if save_steps > 0 and global_step % save_steps == 0:
                    checkpoint_dir = os.path.join(os.getcwd(), "checkpoints")
                    if not os.path.isdir(checkpoint_dir):
                        os.makedirs(checkpoint_dir)
                    checkpoint_path = checkpoint_dir + "/" + str(
                        global_step) + ".pth"
                    torch.save(self.model.state_dict(), checkpoint_path)
                    mlflow.log_artifact(checkpoint_path)
                # model validation
                if val_steps > 0 and global_step % val_steps == 0:
                    # run model on validation set
                    self.model.eval()
                    val_loss = 0.0
                    for j, val_batch in enumerate(val_dataloader):
                        if token_type_ids:
                            (val_x_batch, val_mask_batch, val_token_type_ids_batch,
                             val_y_batch) = tuple(t.to(device) for t in val_batch)
                        else:
                            val_token_type_ids_batch = None
                            val_x_batch, val_mask_batch, val_y_batch = tuple(
                                t.to(device) for t in val_batch)
                        val_outputs = self.model(
                            input_ids=val_x_batch,
                            token_type_ids=val_token_type_ids_batch,
                            attention_mask=val_mask_batch,
                            labels=val_y_batch,
                        )
                        vloss = val_outputs[0]
                        val_loss += vloss.sum().item()
                    mlflow.log_metric("validation loss",
                                      val_loss / len(val_dataset),
                                      step=global_step)
                    self.model.train()

                if verbose:
                    if i % ((num_batches // 10) + 1) == 0:
                        if val_loss > 0:
                            print(
                                "epoch:{}/{}; batch:{}->{}/{}; average training loss:{:.6f};\
                                 average val loss:{:.6f}".format(
                                    epoch + 1,
                                    self.num_epochs,
                                    i + 1,
                                    min(i + 1 + num_batches // 10,
                                        num_batches),
                                    num_batches,
                                    tr_loss / (i + 1),
                                    val_loss / (j + 1),
                                ))
                        else:
                            print(
                                "epoch:{}/{}; batch:{}->{}/{}; average train loss:{:.6f}"
                                .format(
                                    epoch + 1,
                                    self.num_epochs,
                                    i + 1,
                                    min(i + 1 + num_batches // 10,
                                        num_batches),
                                    num_batches,
                                    tr_loss / (i + 1),
                                ))
        checkpoint_dir = os.path.join(os.getcwd(), "checkpoints")
        if not os.path.isdir(checkpoint_dir):
            os.makedirs(checkpoint_dir)
        checkpoint_path = checkpoint_dir + "/" + "final" + ".pth"
        torch.save(self.model.state_dict(), checkpoint_path)
        mlflow.log_artifact(checkpoint_path)
        # empty cache
        del [x_batch, y_batch, mask_batch, token_type_ids_batch]
        if val_steps > 0:
            del [
                val_x_batch, val_y_batch, val_mask_batch,
                val_token_type_ids_batch
            ]
        torch.cuda.empty_cache()
Пример #11
0
def train(args, train_dataset, model, tokenizer):
    """ Train the model """
    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)

    data_loaders = []
    if args.pacing_function != "":
        # values_file = args.curriculum_file.split("_3")[0] + '_values_3'
        values_file = args.curriculum_file
        logger.info("Using curriculum scoring values from file " + values_file)
        if 'random' in values_file:
            logger.info("Randomizing values for random scoring function.")
            instances_scores = random.sample(range(len(train_dataset)),
                                             len(train_dataset))
        else:
            instances_scores = read_scores_file(values_file)

        # some value files do not repeat the score for each candidate document.
        if len(instances_scores) != len(train_dataset):
            candidates_per_q = len(train_dataset) / len(instances_scores)
            filled_instances_scores = []
            for v in instances_scores:
                for i in range(int(candidates_per_q)):
                    filled_instances_scores.append(v)
            instances_scores = filled_instances_scores

        assert len(instances_scores) == len(train_dataset)

        c = [v for v in zip(instances_scores, train_dataset)]
        c = sorted(c, key=lambda x: x[0], reverse=args.invert_cl_values)
        ordered_train_dataset = [v[1] for v in c]
        c0 = 0.33
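        # initial competence: start training on the 33% of instances with the most
        # favorable curriculum scores; the pacing function is expected to grow this
        # fraction over the course of training.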

        train_data = ordered_train_dataset[0:int(c0 *
                                                 len(ordered_train_dataset))]
        train_sampler = RandomSampler(
            train_data) if args.local_rank == -1 else DistributedSampler(
                train_data)
        train_dataloader = DataLoader(train_data,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        data_loaders.append(
            ('pacing_function_' + args.pacing_function, train_dataloader))

        s_for_count_only = RandomSampler(
            ordered_train_dataset
        ) if args.local_rank == -1 else DistributedSampler(
            ordered_train_dataset)
        t_for_count_only = DataLoader(ordered_train_dataset,
                                      sampler=s_for_count_only,
                                      batch_size=args.train_batch_size)
        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (
                len(t_for_count_only) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(
                t_for_count_only
            ) // args.gradient_accumulation_steps * args.num_train_epochs

    elif args.curriculum_file != "":
        logger.info("Using curriculum from file " + args.curriculum_file)
        logger.info("Additive sets : " + str(args.use_additive_cl))
        cl_m = read_curriculum_file(args.curriculum_file)
        all_idxs = []
        for phase in range(len(cl_m.keys())):
            idx = cl_m[phase]
            all_idxs = all_idxs + idx

            if args.use_additive_cl:
                idx = all_idxs
            logger.info("Phase " + str(phase) + " has " + str(len(idx)) +
                        " instances.")

            train_data = [train_dataset[i] for i in idx]
            train_sampler = RandomSampler(
                train_data) if args.local_rank == -1 else DistributedSampler(
                    train_data)
            train_dataloader = DataLoader(train_data,
                                          sampler=train_sampler,
                                          batch_size=args.train_batch_size)
            data_loaders.append(('phase_' + str(phase), train_dataloader))
        if args.use_additive_cl:
            t_total = len(
                data_loaders[-1][1]
            ) // args.gradient_accumulation_steps * args.num_train_epochs
        else:
            t_total = sum([
                len(loader) for _, loader in data_loaders
            ]) // args.gradient_accumulation_steps * args.num_train_epochs
    else:
        train_sampler = RandomSampler(
            train_dataset) if args.local_rank == -1 else DistributedSampler(
                train_dataset)
        train_dataloader = DataLoader(train_dataset,
                                      sampler=train_sampler,
                                      batch_size=args.train_batch_size)
        data_loaders.append(('all_random_batches', train_dataloader))

        if args.max_steps > 0:
            t_total = args.max_steps
            args.num_train_epochs = args.max_steps // (
                len(train_dataloader) // args.gradient_accumulation_steps) + 1
        else:
            t_total = len(
                train_dataloader
            ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    # optimizer = torch.optim.Adam(optimizer_grouped_parameters, lr=args.learning_rate, eps=args.adam_epsilon, amsgrad=True)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)
    logger.info("  percentage by epoch = %f", args.percentage_data_by_epoch)
    logger.info("  data_loaders = %s", data_loaders)
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    best_model = None
    best_map = 0.0
    model.zero_grad()
    epochs = args.num_train_epochs
    if len(data_loaders) > 1:
        assert epochs % len(data_loaders) == 0, \
            "num_train_epochs must be divisible by the number of data loaders"
        epochs = epochs // len(data_loaders)

    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    for loader_name, train_dataloader in data_loaders:
        for epoch_i in range(int(epochs)):
            logger.info("Starting epoch " + str(epoch_i + 1))
            logger.info("Training with " + loader_name)

            step = 0
            while True:
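                # Re-create the iterator every step: the pacing block further
                # down may swap train_dataloader for one built over a larger
                # data fraction, and RandomSampler reshuffles on every iter()
                # call, so each step draws one fresh random batch from the data
                # currently available to the curriculum.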
                current_data_iter = iter(train_dataloader)
                batch = next(current_data_iter)

                model.train()
                batch = tuple(t.to(args.device) for t in batch)
                inputs = {
                    'input_ids':
                    batch[0],
                    'attention_mask':
                    batch[1],
                    'token_type_ids':
                    batch[2] if args.model_type in ['bert', 'xlnet'] else
                    None,  # XLM doesn't use segment_ids
                    'labels':
                    batch[3]
                }
                outputs = model(**inputs)
                loss = outputs[
                    0]  # model outputs are always a tuple in pytorch-transformers (see doc)

                if args.n_gpu > 1:
                    loss = loss.mean(
                    )  # mean() to average on multi-gpu parallel training
                if args.gradient_accumulation_steps > 1:
                    loss = loss / args.gradient_accumulation_steps

                if args.fp16:
                    with amp.scale_loss(loss, optimizer) as scaled_loss:
                        scaled_loss.backward()
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    loss.backward()
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                tr_loss += loss.item()
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    optimizer.step()
                    scheduler.step()  # Update learning rate schedule (after the optimizer step)
                    model.zero_grad()
                    global_step += 1

                    if args.local_rank in [
                            -1, 0
                    ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                        # Log metrics
                        if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                            logger.info("Iter = " + str(global_step))
                            logger.info("lr = " + str(scheduler.get_lr()[0]))
                            logger.info("loss = " +
                                        str((tr_loss - logging_loss) /
                                            args.logging_steps))
                            if args.pacing_function != "":
                                logger.info("Current data iter size: " +
                                            str(len(current_data_iter)))
                            results = evaluate(args, model, tokenizer)
                            for key, value in results.items():
                                tb_writer.add_scalar('eval_{}'.format(key),
                                                     value, global_step)
                            if results['map'] > best_map:
                                best_map = results['map']
                                output_dir = os.path.join(
                                    args.output_dir,
                                    'checkpoint-best_' + args.run_name)
                                if not os.path.exists(output_dir):
                                    os.makedirs(output_dir)
                                model_to_save = model.module if hasattr(
                                    model, 'module'
                                ) else model  # Take care of distributed/parallel training
                                model_to_save.save_pretrained(output_dir)
                                torch.save(
                                    args,
                                    os.path.join(output_dir,
                                                 'training_args.bin'))
                                logger.info("Saving best model so far to %s",
                                            output_dir)
                            tb_writer.add_scalar('lr',
                                                 scheduler.get_lr()[0],
                                                 global_step)
                            tb_writer.add_scalar('loss',
                                                 (tr_loss - logging_loss) /
                                                 args.logging_steps,
                                                 global_step)
                            logging_loss = tr_loss

                    if args.local_rank in [
                            -1, 0
                    ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                        # Save model checkpoint
                        output_dir = os.path.join(
                            args.output_dir,
                            'checkpoint-{}'.format(global_step))
                        if not os.path.exists(output_dir):
                            os.makedirs(output_dir)
                        model_to_save = model.module if hasattr(
                            model, 'module'
                        ) else model  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        torch.save(
                            args, os.path.join(output_dir,
                                               'training_args.bin'))
                        logger.info("Saving model checkpoint to %s",
                                    output_dir)

                if args.pacing_function != "":
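                    # Pacing: grow the usable fraction of the difficulty-ordered
                    # dataset with global_step. c0 (the starting fraction) is
                    # defined outside this excerpt; min(1, ...) clamps the
                    # fraction once the full dataset is reached.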
                    percentage_curriculum_iter = 0.90
                    curriculum_iterations = (t_total *
                                             args.percentage_data_by_epoch
                                             ) * percentage_curriculum_iter
                    new_data_fraction = min(
                        1, PACING_FUNCTIONS[args.pacing_function](
                            global_step, curriculum_iterations, c0))
                    train_data = ordered_train_dataset[
                        0:int(new_data_fraction * len(ordered_train_dataset))]
                    train_sampler = RandomSampler(
                        train_data
                    ) if args.local_rank == -1 else DistributedSampler(
                        train_data)
                    train_dataloader = DataLoader(
                        train_data,
                        sampler=train_sampler,
                        batch_size=args.train_batch_size)

                # Stop the epoch manually: batches are drawn with iter()/next()
                # above, so without this check the while-loop would never end.
                if step == int(args.percentage_data_by_epoch *
                               (t_total / args.num_train_epochs)):
                    logger.info("Finished epoch with " + str(step) +
                                " iterations.")
                    if args.reset_clf_weights:
                        if type(model) == torch.nn.DataParallel:
                            model.module.classifier.weight.data.normal_(
                                mean=0.0, std=0.02)
                        else:
                            model.classifier.weight.data.normal_(mean=0.0,
                                                                 std=0.02)
                    break
                if args.max_steps > 0 and global_step > args.max_steps:
                    break

                step += 1
                #end of a batch
                if args.debug_mode:
                    break
            if args.max_steps > 0 and global_step > args.max_steps:
                break
            #end of an epoch
            if args.debug_mode:
                break

        #end of a curriculum data shard
        if args.debug_mode:
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
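
The train() above looks up PACING_FUNCTIONS[args.pacing_function] and a starting fraction c0 that are defined outside this excerpt. Purely as a hedged illustration (the names and formulas below are assumptions, not the author's code), a pacing table of that shape could look as follows: each entry maps (current step, curriculum length, initial fraction c0) to the fraction of the difficulty-ordered data to expose.

import math

# Hypothetical stand-ins for the PACING_FUNCTIONS dict referenced above; each
# maps (step, total_curriculum_steps, c0) -> fraction of the ordered data.
PACING_FUNCTIONS_SKETCH = {
    # linear growth from c0 to 1 over the curriculum phase
    'linear': lambda step, total, c0: min(
        1.0, c0 + (1.0 - c0) * step / max(1, total)),
    # square-root growth, as in competence-based curricula
    'root_2': lambda step, total, c0: min(
        1.0, math.sqrt(c0 ** 2 + (1.0 - c0 ** 2) * step / max(1, total))),
    # no curriculum: expose the full dataset from the first step
    'standard_training': lambda step, total, c0: 1.0,
}

if __name__ == '__main__':
    # e.g. a 10k-step curriculum that starts with 33% of the ordered data
    for s in (0, 2500, 5000, 10000):
        print(s, round(PACING_FUNCTIONS_SKETCH['root_2'](s, 10000, 0.33), 3))

With a table of this shape, the new_data_fraction line in the loop slices the first portion of ordered_train_dataset, which is assumed to be sorted from easy to hard.
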
def train(args, train_dataset, model, tokenizer, label_2test_array):
    """ Train the model """

    num_labels = len(label_2test_array)
    print("\nnum_labels {}\n".format(num_labels))

    if args.local_rank in [-1, 0]:
        tb_writer = SummaryWriter()

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)
    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])

    ## track the best loss on the eval set for early stopping
    eval_loss = np.inf
    last_best = 0
    break_early = False

    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)

    for epoch_counter in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            # We do not train a masked LM here, so mask_tokens(batch, tokenizer, args) is not used.
            # The dataset yields each batch in this order:
            #   (attention_mask, examples, label1hot, label_mask, token_type)

            max_len_in_batch = int(torch.max(torch.sum(
                batch[0], 1)))  ## only need max len
            attention_mask = batch[0][:, 0:max_len_in_batch].to(args.device)
            inputs = batch[1][:, 0:max_len_in_batch].to(args.device)
            labels = batch[2].to(
                args.device)  ## already in batch_size x num_label
            labels_mask = batch[3][:, 0:max_len_in_batch].to(
                args.device)  ## label mask, truncated like the inputs (arguably does not need to be on the GPU)
            token_type = batch[4][:, 0:max_len_in_batch].to(args.device)

            model.train()

            # call to the @model
            # def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None,
            #   position_ids=None, head_mask=None, attention_mask_label=None):

            outputs = model(inputs,
                            token_type_ids=token_type,
                            attention_mask=attention_mask,
                            labels=labels,
                            position_ids=None,
                            attention_mask_label=labels_mask
                            )  # if args.mlm else model(inputs, labels=labels)

            loss = outputs[
                0]  # model outputs are always a tuple in pytorch-transformers (see doc)

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            tr_loss += loss.item()
            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (epoch_counter > 0) and args.local_rank in [
                        -1, 0
                ] and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, 'checkpoint-{}'.format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    model_to_save = model.module if hasattr(
                        model, 'module'
                    ) else model  # Take care of distributed/parallel training
                    model_to_save.save_pretrained(output_dir)
                    torch.save(args,
                               os.path.join(output_dir, 'training_args.bin'))
                    logger.info("Saving model checkpoint to %s", output_dir)

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer,
                                           label_2test_array)
                        for key, value in results.items():
                            tb_writer.add_scalar('eval_{}'.format(key), value,
                                                 global_step)

                    tb_writer.add_scalar('lr',
                                         scheduler.get_lr()[0], global_step)
                    tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                         args.logging_steps, global_step)
                    logging_loss = tr_loss

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break

        ## end 1 epoch
        results = evaluate(args, model, tokenizer, label_2test_array)
        if results['eval_loss'] < eval_loss:
            eval_loss = results['eval_loss']
            last_best = epoch_counter
            break_early = False
            print(
                '\nnew lowest eval loss {} at epoch {}; break_early reset to {}'
                .format(eval_loss, epoch_counter, break_early))
        else:
            if epoch_counter - last_best > 5:  ## stop early after 5 epochs without improvement
                break_early = True
                print(
                    'epoch {}: no improvement for over 5 epochs, setting break_early to {}'
                    .format(epoch_counter, break_early))

        if break_early:
            train_iterator.close()
            print("**** break early ****")
            break

        # print ('\neval on trainset\n')
        # true_label = np.array (true_label)
        # result = evaluation_metric.all_metrics ( np.round(prediction) , true_label, yhat_raw=prediction, k=[5,10,15,20,25]) ## we can pass vector of P@k and R@k
        # evaluation_metric.print_metrics( result )

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank in [-1, 0]:
        tb_writer.close()

    return global_step, tr_loss / global_step
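
The loop above trims every padded tensor to the longest real sequence in the batch (max_len_in_batch) before moving it to the GPU, which saves memory and compute on batches of short sequences. A minimal, self-contained sketch of the same trick (the helper name and toy tensors are illustrative, not from the original code):

import torch

def truncate_batch_to_longest(attention_mask, input_ids, token_type_ids):
    """Trim right-padding so the batch is only as wide as its longest sequence."""
    # attention_mask is 1 on real tokens and 0 on padding, so the row sum is the
    # true length of each sequence; the batch max is the width we need to keep.
    max_len = int(attention_mask.sum(dim=1).max())
    return (attention_mask[:, :max_len],
            input_ids[:, :max_len],
            token_type_ids[:, :max_len])

if __name__ == '__main__':
    # two sequences padded to length 8; the longer one has 5 real tokens
    mask = torch.tensor([[1, 1, 1, 0, 0, 0, 0, 0],
                         [1, 1, 1, 1, 1, 0, 0, 0]])
    ids = torch.arange(16).view(2, 8)
    types = torch.zeros(2, 8, dtype=torch.long)
    m, i, t = truncate_batch_to_longest(mask, ids, types)
    print(m.shape, i.shape, t.shape)  # each torch.Size([2, 5])
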
Example #13
def train(args, train_dataset, model, tokenizer):
    tb_writer = SummaryWriter()
    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    ## DATALOADER
    train_sampler = SequentialSampler(train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)
    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs
    logger.info("***** Running training *****")
    logger.info("  Num examples = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info("  Total train batch size = %d",
                args.train_batch_size * args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    ## OPTIMIZER
    bert_optimizer_grouped_parameters = get_bert_param_groups(model, args)
    bert_optimizer = AdamW(bert_optimizer_grouped_parameters,
                           lr=args.learning_rate,
                           eps=args.adam_epsilon,
                           weight_decay=args.weight_decay)
    scheduler = WarmupLinearSchedule(bert_optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    ## TRAIN
    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    set_seed(args)
    for _ in trange(int(args.num_train_epochs), desc='Epoch'):
        for batch in tqdm(train_dataloader,
                          desc='Iteration',
                          total=len(train_dataloader)):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)

            output = model(batch[0], batch[1], batch[2])
            logits = output[0].squeeze()

            loss = F.mse_loss(logits, batch[3])
            # loss = F.smooth_l1_loss(logits, data_a.label)
            # loss = F.l1_loss(logits, data_a.label)

            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(),
                                           args.max_grad_norm)

            tr_loss += loss.item()
            bert_optimizer.step()
            scheduler.step()  # update the LR schedule after the optimizer step
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:

                tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                tb_writer.add_scalar('loss', (tr_loss - logging_loss) /
                                     args.logging_steps, global_step)
                logging_loss = tr_loss

            if args.save_steps > 0 and global_step % args.save_steps == 0:
                output_dir = os.path.join(args.output_dir,
                                          'checkpoint-{}'.format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                # model_to_save = model.module if hasattr(model, 'module') else model  # Take care of distributed/parallel training
                # model_to_save.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, 'training_args.bin'))
                # logger.info("Saving model checkpoint to %s", output_dir)

        result = evaluate(args, model, tokenizer)
    tb_writer.close()
    return global_step, tr_loss / global_step
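
Example #13 calls a get_bert_param_groups helper that is not shown in the snippet. A plausible minimal sketch, assuming it follows the usual convention of exempting bias and LayerNorm weights from weight decay (the real helper may differ):

def get_bert_param_groups(model, args):
    """Split parameters into decay / no-decay groups (sketch only)."""
    no_decay = ('bias', 'LayerNorm.weight')
    decay_params, no_decay_params = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue  # skip frozen parameters
        if any(nd in name for nd in no_decay):
            no_decay_params.append(param)
        else:
            decay_params.append(param)
    return [
        {'params': decay_params, 'weight_decay': args.weight_decay},
        {'params': no_decay_params, 'weight_decay': 0.0},
    ]

If the groups carry their own 'weight_decay' entries as above, those per-group values take precedence over the weight_decay keyword passed to AdamW, so the optimizer call in Example #13 behaves the same either way.
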
Example #14
def train(args: Union[dict, gd.FancyDict], train_dataset, model, tokenizer):
    """ Train the model """

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    train_sampler = RandomSampler(
        train_dataset) if args.local_rank == -1 else DistributedSampler(
            train_dataset)
    train_dataloader = DataLoader(train_dataset,
                                  sampler=train_sampler,
                                  batch_size=args.train_batch_size)

    if args.max_steps > 0:
        t_total = args.max_steps
        args.num_train_epochs = args.max_steps // (
            len(train_dataloader) // args.gradient_accumulation_steps) + 1
    else:
        t_total = len(
            train_dataloader
        ) // args.gradient_accumulation_steps * args.num_train_epochs

    if args.call_wandb:
        wandb.config['Total optimization steps'] = t_total

    # Prepare optimizer and schedule (linear warmup and decay)
    no_decay = ['bias', 'LayerNorm.weight']
    optimizer_grouped_parameters = [{
        'params': [
            p for n, p in model.named_parameters()
            if not any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        args.weight_decay
    }, {
        'params': [
            p for n, p in model.named_parameters()
            if any(nd in n for nd in no_decay)
        ],
        'weight_decay':
        0.0
    }]
    optimizer = AdamW(optimizer_grouped_parameters,
                      lr=args.learning_rate,
                      eps=args.adam_epsilon)
    scheduler = WarmupLinearSchedule(optimizer,
                                     warmup_steps=args.warmup_steps,
                                     t_total=t_total)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Total optimization steps = %d", len(train_dataset))
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1))
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()

    train_iterator = trange(int(args.num_train_epochs),
                            desc="Epoch",
                            disable=args.local_rank not in [-1, 0])
    set_seed(
        args)  # Added here for reproducibility (even between python 2 and 3)
    previous_accuracy = 0

    for _ in train_iterator:
        epoch_iterator = tqdm(train_dataloader,
                              desc="Iteration",
                              disable=args.local_rank not in [-1, 0])
        for step, batch in enumerate(epoch_iterator):
            model.train()
            batch = tuple(t.to(args.device) for t in batch)
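            # Two modes: 'loss_in_train_loop' takes logits from the model and
            # computes MSE / cross-entropy here in the loop, while
            # 'loss_in_model' expects the model itself to return (loss, logits)
            # given the targets.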
            if args.mode == 'loss_in_train_loop':
                inputs = {
                    'input_ids':
                    batch[0],
                    'attention_mask':
                    batch[1],
                    'token_type_ids':
                    batch[2] if args.model_type in ['bert', 'xlnet'] else None
                }

                outputs = model(**inputs)
                logits = outputs[0]
                inputs['labels'] = batch[3]

                # outputs = (logits,) + outputs[2:]  # add hidden hiddenstates and attention if they are here

                # Calculating loss
                if inputs['labels'] is not None:
                    if args.num_labels == 1:
                        #  We are doing regression
                        loss_fct = MSELoss()
                        loss = loss_fct(logits.view(-1),
                                        inputs['labels'].view(-1))
                    else:
                        loss_fct = CrossEntropyLoss()
                        loss = loss_fct(logits.view(-1, args.num_labels),
                                        inputs['labels'].view(-1))
                    outputs = (loss, ) + outputs

                loss = outputs[
                    0]  # model outputs are always a tuple in pytorch-transformers (see doc)

            elif args.mode == 'loss_in_model':
                inputs = {
                    'input_ids':
                    batch[0],
                    'attention_mask':
                    batch[1],
                    'token_type_ids':
                    batch[2] if args.model_type in ['bert', 'xlnet'] else None,
                    'targets':
                    batch[3]
                }

                loss, logits = model(**inputs)
                inputs['labels'] = batch[3]

            else:
                print(f"mode not recognized. mode found {args.mode}")
                raise IOError

            if args.n_gpu > 1:
                loss = loss.mean(
                )  # mean() to average on multi-gpu parallel training
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)
            iter_loss = loss.item()
            tr_loss += iter_loss

            # logging loss for wandb
            if global_step % args.logging_loss_steps == 0 and args.call_wandb:
                # log the loss here
                wandb.log({'iter_loss': iter_loss})

            if (step + 1) % args.gradient_accumulation_steps == 0:
                optimizer.step()
                scheduler.step()  # Update learning rate schedule (after the optimizer step)
                if args.pruner:
                    args.pruner.step()
                model.zero_grad()
                global_step += 1

                if args.local_rank in [
                        -1, 0
                ] and args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    # Log metrics
                    if args.local_rank == -1 and args.evaluate_during_training:  # Only evaluate when single GPU otherwise metrics may not average well
                        results = evaluate(args, model, tokenizer)

                        if args.call_wandb:
                            wandb.log({k: v for k, v in results.items()})

                        supported_metrics = ('acc', 'mcc', 'corr', 'f1',
                                             'acc_and_f1', 'pearson',
                                             'spearmanr')
                        for k in results:
                            if k not in supported_metrics:
                                raise gd.UnknownAccuracyMetric(
                                    f"The current training loop only supports"
                                    f" acc, mcc, corr, acc_and_f1, f1, pearson,"
                                    f" and spearmanr. Found {k}")
                            key = k

                        if previous_accuracy < results[key]:  # acc, mcc, corr.
                            # Note that previous_accuracy could be acc, mcc, corr, etc.
                            previous_accuracy = results[key]
                            if args.call_wandb:
                                wandb.log({'best_acc': previous_accuracy})
                            # save the model here
                            if args.save:
                                gd.save_model(model=model,
                                              output_dir=args.output_dir,
                                              model_name=args.task_name +
                                              args.output_name,
                                              accuracy=results[key],
                                              config={"mode": args.mode})

                        # for key, value in results.items():
                        #     tb_writer.add_scalar('eval_{}'.format(key), value, global_step)
                    if args.call_wandb:
                        wandb.log({'lr': scheduler.get_lr()[0]})
                        wandb.log({'loss': tr_loss - logging_loss})

                    # tb_writer.add_scalar('lr', scheduler.get_lr()[0], global_step)
                    # tb_writer.add_scalar('loss', (tr_loss - logging_loss) / args.logging_steps, global_step)
                    logging.info(
                        f"the current training loss is {tr_loss - logging_loss}"
                    )
                    logging_loss = tr_loss

            if args.max_steps > 0 and global_step > args.max_steps:
                epoch_iterator.close()
                break
        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.call_wandb:
        wandb.config['global_step'] = global_step
        wandb.config['global_loss'] = tr_loss / global_step
    return global_step, tr_loss / global_step
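
Every example above derives the number of optimization steps (t_total) and logs the effective batch size from the same arithmetic. A small illustrative helper (not part of any of the snippets; the names mirror the argument names used above) that makes the calculation explicit:

def plan_training(num_batches_per_epoch, num_train_epochs,
                  gradient_accumulation_steps, per_gpu_train_batch_size,
                  n_gpu=1, world_size=1, max_steps=0):
    """Return (t_total, num_train_epochs, effective_batch_size) -- illustrative only."""
    steps_per_epoch = num_batches_per_epoch // gradient_accumulation_steps
    if max_steps > 0:
        t_total = max_steps
        num_train_epochs = max_steps // steps_per_epoch + 1
    else:
        t_total = steps_per_epoch * num_train_epochs
    effective_batch = (per_gpu_train_batch_size * max(1, n_gpu)
                       * gradient_accumulation_steps * world_size)
    return t_total, num_train_epochs, effective_batch

if __name__ == '__main__':
    # 1000 batches/epoch, 3 epochs, accumulate over 4 steps, 8 examples per GPU, 2 GPUs
    print(plan_training(1000, 3, 4, 8, n_gpu=2))  # (750, 3, 64)
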